package webtransport

import (
	
	
	
	
	
	
	
	
	
	
	
	

	
	
	

	
)

// Header names used when negotiating a WebTransport application protocol.
// WebTransport stream-type prefixes follow: bidirectional WebTransport streams
// start with webTransportFrameType, unidirectional ones with
// webTransportUniStreamType, each followed by the session ID.
const (
	wtAvailableProtocolsHeader = "WT-Available-Protocols"
	wtProtocolHeader           = "WT-Protocol"

	webTransportFrameType     = 0x41
	webTransportUniStreamType = 0x54
)

// quicConnKeyType is the private context-key type under which a request's
// QUIC connection is stored; being unexported, it cannot collide with keys
// from other packages.
type quicConnKeyType struct{}

var quicConnKey quicConnKeyType

func ( *http3.Server) {
	if .AdditionalSettings == nil {
		.AdditionalSettings = make(map[uint64]uint64, 1)
	}
	.AdditionalSettings[settingsEnableWebtransport] = 1
	.EnableDatagrams = true
	 := .ConnContext
	.ConnContext = func( context.Context,  *quic.Conn) context.Context {
		if  != nil {
			 = (, )
		}
		 = context.WithValue(, quicConnKey, )
		return 
	}
}

// Server is a WebTransport server built on top of the HTTP/3 server in H3.
//
// Internal state (context, connection map, CheckOrigin default) is set up
// lazily the first time one of the serving methods (e.g. ServeQUICConn) runs.
type Server struct {
	H3 *http3.Server

	// ApplicationProtocols is a list of application protocols that can be negotiated,
	// see section 3.3 of https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14 for details.
	// The first protocol in the client's WT-Available-Protocols list that also
	// appears here is selected and returned in the WT-Protocol response header.
	ApplicationProtocols []string

	// ReorderingTimeout is the maximum time an incoming WebTransport stream that cannot be associated
	// with a session is buffered. It is also the maximum time a WebTransport connection request is
	// blocked waiting for the client's SETTINGS to be received.
	// Buffering can happen if the CONNECT request (that creates a new session) is reordered, and arrives
	// after the first WebTransport stream(s) for that session.
	// Defaults to 5 seconds (used when the field is zero).
	ReorderingTimeout time.Duration

	// CheckOrigin is used to validate the request origin, thereby preventing cross-site request forgery.
	// CheckOrigin returns true if the request Origin header is acceptable.
	// If unset, a safe default is used: if the Origin header is set, it must parse
	// as a URL whose host matches the request's Host header (ASCII case-insensitively).
	CheckOrigin func(r *http.Request) bool

	// ctx bounds the accept loops and the connections served by this server.
	ctx       context.Context // cancelled when Close is called
	ctxCancel context.CancelFunc
	// refCount counts the per-connection goroutines started by the listening
	// loop; Close waits for them to finish.
	refCount sync.WaitGroup

	// initOnce and initErr guard the lazy, one-time initialization of ctx,
	// ctxCancel, conns and the CheckOrigin default.
	initOnce sync.Once
	initErr  error

	// conns maps each served QUIC connection to its session manager.
	// It is guarded by connsMx and set to nil by Close.
	connsMx sync.Mutex
	conns   map[*quic.Conn]*sessionManager
}

func ( *Server) () error {
	.initOnce.Do(func() {
		.initErr = .init()
	})
	return .initErr
}

func ( *Server) () time.Duration {
	 := .ReorderingTimeout
	if  == 0 {
		return 5 * time.Second
	}
	return 
}

func ( *Server) () error {
	.ctx, .ctxCancel = context.WithCancel(context.Background())

	.conns = make(map[*quic.Conn]*sessionManager)
	if .CheckOrigin == nil {
		.CheckOrigin = checkSameOrigin
	}
	return nil
}

func ( *Server) ( net.PacketConn) error {
	if  := .initialize();  != nil {
		return 
	}
	var  *quic.Config
	if .H3.QUICConfig != nil {
		 = .H3.QUICConfig.Clone()
	} else {
		 = &quic.Config{}
	}
	.EnableDatagrams = true
	.EnableStreamResetPartialDelivery = true
	,  := quic.ListenEarly(, .H3.TLSConfig, )
	if  != nil {
		return 
	}
	defer .Close()

	for {
		,  := .Accept(.ctx)
		if  != nil {
			return 
		}
		.refCount.Add(1)
		go func() {
			defer .refCount.Done()

			if  := .ServeQUICConn();  != nil {
				log.Printf("http3: error serving QUIC connection: %v", )
			}
		}()
	}
}

// ServeQUICConn serves a single QUIC connection.
func ( *Server) ( *quic.Conn) error {
	 := .ConnectionState()
	if !.SupportsDatagrams.Local {
		return errors.New("webtransport: QUIC DATAGRAM support required, enable it via QUICConfig.EnableDatagrams")
	}
	if !.SupportsStreamResetPartialDelivery.Local {
		return errors.New("webtransport: QUIC Stream Resets with Partial Delivery required, enable it via QUICConfig.EnableStreamResetPartialDelivery")
	}
	if  := .initialize();  != nil {
		return 
	}

	.connsMx.Lock()
	,  := .conns[]
	if ! {
		 = newSessionManager(.timeout())
		.conns[] = 
	}
	.connsMx.Unlock()

	// Clean up when connection closes
	context.AfterFunc(.Context(), func() {
		.connsMx.Lock()
		delete(.conns, )
		.connsMx.Unlock()
		.Close()
	})

	,  := .H3.NewRawServerConn()
	if  != nil {
		return 
	}

	// slose the connection when the server context is cancelled.
	go func() {
		select {
		case <-.ctx.Done():
			.CloseWithError(0, "")
		case <-.Context().Done():
			// connection already closed
		}
	}()

	var  sync.WaitGroup
	.Add(2)
	go func() {
		defer .Done()

		for {
			,  := .AcceptStream(.ctx)
			if  != nil {
				return
			}

			.Add(1)
			go func() {
				defer .Done()

				,  := quicvarint.Peek()
				if  != nil {
					return
				}
				if  != webTransportFrameType {
					.HandleRequestStream()
					return
				}
				// read the frame type (already peeked)
				if ,  := quicvarint.Read(quicvarint.NewReader());  != nil {
					return
				}
				// read the session ID
				,  := quicvarint.Read(quicvarint.NewReader())
				if  != nil {
					.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					.CancelWrite(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				.AddStream(, sessionID())
			}()
		}
	}()

	go func() {
		defer .Done()

		for {
			,  := .AcceptUniStream(.ctx)
			if  != nil {
				return
			}

			.Add(1)
			go func() {
				defer .Done()

				,  := quicvarint.Peek()
				if  != nil {
					return
				}
				if  != webTransportUniStreamType {
					.HandleUnidirectionalStream()
					return
				}
				// read the stream type (already peeked) before passing to AddUniStream
				 := quicvarint.NewReader()
				if ,  := quicvarint.Read();  != nil {
					return
				}
				// read the session ID
				,  := quicvarint.Read()
				if  != nil {
					.CancelRead(quic.StreamErrorCode(http3.ErrCodeGeneralProtocolError))
					return
				}
				.AddUniStream(, sessionID())
			}()
		}
	}()

	.Wait()
	return nil
}

func ( *Server) () error {
	 := .H3.Addr
	if  == "" {
		 = ":https"
	}
	,  := net.ResolveUDPAddr("udp", )
	if  != nil {
		return 
	}
	,  := net.ListenUDP("udp", )
	if  != nil {
		return 
	}
	return .Serve()
}

func ( *Server) (,  string) error {
	,  := tls.LoadX509KeyPair(, )
	if  != nil {
		return 
	}
	if .H3.TLSConfig == nil {
		.H3.TLSConfig = &tls.Config{}
	}
	.H3.TLSConfig.Certificates = []tls.Certificate{}
	return .ListenAndServe()
}

func ( *Server) () error {
	// Make sure that ctxCancel is defined.
	// This is expected to be uncommon.
	// It only happens if the server is closed without Serve / ListenAndServe having been called.
	.initOnce.Do(func() {})

	if .ctxCancel != nil {
		.ctxCancel()
	}
	.connsMx.Lock()
	if .conns != nil {
		for ,  := range .conns {
			.Close()
		}
		.conns = nil
	}
	.connsMx.Unlock()

	 := .H3.Close()
	.refCount.Wait()
	return 
}

func ( *Server) ( http.ResponseWriter,  *http.Request) (*Session, error) {
	if  := .initialize();  != nil {
		return nil, 
	}
	if .Method != http.MethodConnect {
		return nil, fmt.Errorf("expected CONNECT request, got %s", .Method)
	}
	if .Proto != protocolHeader {
		return nil, fmt.Errorf("unexpected protocol: %s", .Proto)
	}
	if !.CheckOrigin() {
		return nil, errors.New("webtransport: request origin not allowed")
	}

	 := .Context().Value(quicConnKey)
	if  == nil {
		return nil, errors.New("webtransport: missing QUIC connection")
	}
	 := .(*quic.Conn)

	 := .selectProtocol(.Header[http.CanonicalHeaderKey(wtAvailableProtocolsHeader)])

	// Wait for SETTINGS
	 := .(http3.Settingser)
	 := time.NewTimer(.timeout())
	defer .Stop()
	select {
	case <-.ReceivedSettings():
	case <-.C:
		return nil, errors.New("webtransport: didn't receive the client's SETTINGS on time")
	}
	 := .Settings()
	if !.EnableDatagrams {
		return nil, errors.New("webtransport: missing datagram support")
	}

	if  != "" {
		,  := httpsfv.Marshal(httpsfv.NewItem())
		if  != nil {
			return nil, fmt.Errorf("failed to marshal selected protocol: %w", )
		}
		.Header().Add(wtProtocolHeader, )
	}
	.WriteHeader(http.StatusOK)
	.(http.Flusher).Flush()

	 := .(http3.HTTPStreamer).HTTPStream()
	 := sessionID(.StreamID())

	// The session manager should already exist because ServeQUICConn creates it
	// before any HTTP requests can be processed on this connection.
	.connsMx.Lock()
	defer .connsMx.Unlock()

	,  := .conns[]
	if ! {
		return nil, errors.New("webtransport: connection session manager not found")
	}

	 := newSession(context.WithoutCancel(.Context()), , , , )
	.AddSession(, )
	return , nil
}

func ( *Server) ( []string) string {
	,  := httpsfv.UnmarshalList()
	if  != nil {
		return ""
	}
	 := make([]string, 0, len())
	for ,  := range  {
		,  := .(httpsfv.Item)
		if ! {
			return ""
		}
		,  := .Value.(string)
		if ! {
			return ""
		}
		 = append(, )
	}
	var  string
	for ,  := range  {
		if slices.Contains(.ApplicationProtocols, ) {
			 = 
			break
		}
	}
	return 
}

// copied from https://github.com/gorilla/websocket
func checkSameOrigin( *http.Request) bool {
	 := .Header.Get("Origin")
	if  == "" {
		return true
	}
	,  := url.Parse()
	if  != nil {
		return false
	}
	return equalASCIIFold(.Host, .Host)
}

// equalASCIIFold reports whether s and t are equal when only the ASCII
// letters A-Z are folded to lower case (adapted from
// https://github.com/gorilla/websocket). Non-ASCII runes must match exactly.
// Invalid UTF-8 bytes decode to utf8.RuneError and therefore compare equal to
// each other.
func equalASCIIFold(s, t string) bool {
	lowerASCII := func(r rune) rune {
		if 'A' <= r && r <= 'Z' {
			return r + 'a' - 'A'
		}
		return r
	}

	for _, sr := range s {
		if t == "" {
			return false // s is longer than t
		}
		tr, width := utf8.DecodeRuneInString(t)
		t = t[width:]
		if lowerASCII(sr) != lowerASCII(tr) {
			return false
		}
	}
	// s is exhausted; the strings match only if t is too.
	return t == ""
}