package server

import (
	
	
	
	
	
	
	
	
	
	
	
	

	

	
)

// sseSession holds the per-connection state of one SSE client. handleSSE
// creates it when a stream is opened.
type sseSession struct {
	// Identity and lifecycle.
	sessionID string
	done      chan struct{} // closed to tell the session's goroutines to stop

	// Outbound traffic.
	eventQueue          chan string // Channel for queuing pre-formatted SSE frames
	notificationChannel chan mcp.JSONRPCNotification
	requestID           atomic.Int64 // counter for IDs of server-initiated requests (e.g. ping)

	// Handshake and logging state.
	initialized  atomic.Bool
	loggingLevel atomic.Value // holds an mcp.LoggingLevel once set

	// Session-scoped registries, keyed by string.
	tools             sync.Map // ServerTool values
	resources         sync.Map // ServerResource values
	resourceTemplates sync.Map // ServerResourceTemplate values

	// Client metadata stored via SetClientInfo / SetClientCapabilities.
	clientInfo         atomic.Value // mcp.Implementation
	clientCapabilities atomic.Value // mcp.ClientCapabilities
}

type (
	// SSEContextFunc derives the context used to handle a client message from
	// the incoming HTTP request. It receives the current context and the request
	// and returns a possibly modified context, for example one carrying values
	// taken from request headers.
	SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

	// DynamicBasePathFunc computes the base path for a given request and session
	// ID. Use it when the mount point is not known when the SSE server is built,
	// e.g. behind a reverse proxy or on per-tenant routes. It should return the
	// base path, such as "/mcp/tenant123".
	DynamicBasePathFunc func(r *http.Request, sessionID string) string

	// SessionIDGenFunc produces the session ID for a new SSE connection from the
	// request context and HTTP request, or reports an error.
	SessionIDGenFunc func(ctx context.Context, r *http.Request) (string, error)
)

func ( *sseSession) () string {
	return .sessionID
}

func ( *sseSession) () chan<- mcp.JSONRPCNotification {
	return .notificationChannel
}

func ( *sseSession) () {
	// set default logging level
	.loggingLevel.Store(mcp.LoggingLevelError)
	.initialized.Store(true)
}

func ( *sseSession) () bool {
	return .initialized.Load()
}

func ( *sseSession) ( mcp.LoggingLevel) {
	.loggingLevel.Store()
}

func ( *sseSession) () mcp.LoggingLevel {
	 := .loggingLevel.Load()
	if  == nil {
		return mcp.LoggingLevelError
	}
	return .(mcp.LoggingLevel)
}

func ( *sseSession) () map[string]ServerResource {
	 := make(map[string]ServerResource)
	.resources.Range(func(,  any) bool {
		if ,  := .(ServerResource);  {
			[.(string)] = 
		}
		return true
	})
	return 
}

func ( *sseSession) ( map[string]ServerResource) {
	// Clear existing resources
	.resources.Clear()

	// Set new resources
	for ,  := range  {
		.resources.Store(, )
	}
}

func ( *sseSession) () map[string]ServerResourceTemplate {
	 := make(map[string]ServerResourceTemplate)
	.resourceTemplates.Range(func(,  any) bool {
		if ,  := .(ServerResourceTemplate);  {
			[.(string)] = 
		}
		return true
	})
	return 
}

func ( *sseSession) ( map[string]ServerResourceTemplate) {
	// Clear existing templates
	.resourceTemplates.Clear()

	// Set new templates
	for ,  := range  {
		.resourceTemplates.Store(, )
	}
}

func ( *sseSession) () map[string]ServerTool {
	 := make(map[string]ServerTool)
	.tools.Range(func(,  any) bool {
		if ,  := .(ServerTool);  {
			[.(string)] = 
		}
		return true
	})
	return 
}

func ( *sseSession) ( map[string]ServerTool) {
	// Clear existing tools
	.tools.Clear()

	// Set new tools
	for ,  := range  {
		.tools.Store(, )
	}
}

func ( *sseSession) () mcp.Implementation {
	if  := .clientInfo.Load();  != nil {
		if ,  := .(mcp.Implementation);  {
			return 
		}
	}
	return mcp.Implementation{}
}

func ( *sseSession) ( mcp.Implementation) {
	.clientInfo.Store()
}

func ( *sseSession) ( mcp.ClientCapabilities) {
	.clientCapabilities.Store()
}

func ( *sseSession) () mcp.ClientCapabilities {
	if  := .clientCapabilities.Load();  != nil {
		if ,  := .(mcp.ClientCapabilities);  {
			return 
		}
	}
	return mcp.ClientCapabilities{}
}

// Compile-time proof that *sseSession implements each session interface below.
var _ ClientSession = (*sseSession)(nil)
var _ SessionWithTools = (*sseSession)(nil)
var _ SessionWithResources = (*sseSession)(nil)
var _ SessionWithResourceTemplates = (*sseSession)(nil)
var _ SessionWithLogging = (*sseSession)(nil)
var _ SessionWithClientInfo = (*sseSession)(nil)

// SSEServer is an MCP transport built on Server-Sent Events: clients open a
// long-lived SSE stream and POST JSON-RPC messages to a companion endpoint.
type SSEServer struct {
	server *MCPServer // processes the JSON-RPC messages

	// Endpoint and URL configuration.
	baseURL                      string
	basePath                     string
	sseEndpoint                  string
	messageEndpoint              string
	dynamicBasePathFunc          DynamicBasePathFunc
	appendQueryToMessageEndpoint bool
	useFullURLForMessageEndpoint bool

	// Per-request hooks.
	contextFunc      SSEContextFunc
	sessionIDGenFunc SessionIDGenFunc

	// Keep-alive ping configuration.
	keepAlive         bool
	keepAliveInterval time.Duration

	// Live sessions keyed by session ID; values are *sseSession.
	sessions sync.Map

	// srv is the underlying HTTP server; mu guards access to it.
	mu  sync.RWMutex
	srv *http.Server
}

// SSEOption is a functional option that configures an SSEServer; NewSSEServer
// applies the options in the order given.
type SSEOption func(s *SSEServer)

// WithBaseURL sets the base URL for the SSE server
func ( string) SSEOption {
	return func( *SSEServer) {
		if  != "" {
			,  := url.Parse()
			if  != nil {
				return
			}
			if .Scheme != "http" && .Scheme != "https" {
				return
			}
			// Check if the host is empty or only contains a port
			if .Host == "" || strings.HasPrefix(.Host, ":") {
				return
			}
			if len(.Query()) > 0 {
				return
			}
		}
		.baseURL = strings.TrimSuffix(, "/")
	}
}

// WithStaticBasePath adds a new option for setting a static base path
func ( string) SSEOption {
	return func( *SSEServer) {
		.basePath = normalizeURLPath()
	}
}

// WithBasePath adds a new option for setting a static base path.
//
// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version.
//
//go:deprecated
func ( string) SSEOption {
	return WithStaticBasePath()
}

// WithDynamicBasePath accepts a function for generating the base path. This is
// useful for cases where the base path is not known at the time of SSE server
// creation, such as when using a reverse proxy or when the server is mounted
// at a dynamic path.
func ( DynamicBasePathFunc) SSEOption {
	return func( *SSEServer) {
		if  != nil {
			.dynamicBasePathFunc = func( *http.Request,  string) string {
				 := (, )
				return normalizeURLPath()
			}
		}
	}
}

// WithMessageEndpoint sets the message endpoint path
func ( string) SSEOption {
	return func( *SSEServer) {
		.messageEndpoint = 
	}
}

// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's
// query parameters to the message endpoint URL that is sent to clients during the SSE connection
// initialization. This is useful when you need to preserve query parameters from the initial
// SSE connection request and carry them over to subsequent message requests, maintaining
// context or authentication details across the communication channel.
func () SSEOption {
	return func( *SSEServer) {
		.appendQueryToMessageEndpoint = true
	}
}

// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
// or just the path portion for the message endpoint. Set to false when clients will concatenate
// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
func ( bool) SSEOption {
	return func( *SSEServer) {
		.useFullURLForMessageEndpoint = 
	}
}

// WithSSEEndpoint sets the SSE endpoint path
func ( string) SSEOption {
	return func( *SSEServer) {
		.sseEndpoint = 
	}
}

// WithHTTPServer sets the HTTP server instance.
// NOTE: When providing a custom HTTP server, you must handle routing yourself
// If routing is not set up, the server will start but won't handle any MCP requests.
func ( *http.Server) SSEOption {
	return func( *SSEServer) {
		.srv = 
	}
}

func ( time.Duration) SSEOption {
	return func( *SSEServer) {
		.keepAlive = true
		.keepAliveInterval = 
	}
}

func ( bool) SSEOption {
	return func( *SSEServer) {
		.keepAlive = 
	}
}

// WithSSEContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
func ( SSEContextFunc) SSEOption {
	return func( *SSEServer) {
		.contextFunc = 
	}
}

// WithSessionIDGenerator sets a custom session ID generator. If fn == nil the call is ignored.
func ( SessionIDGenFunc) SSEOption {
	return func( *SSEServer) {
		if  != nil {
			.sessionIDGenFunc = 
		}
	}
}

// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func ( *MCPServer,  ...SSEOption) *SSEServer {
	 := &SSEServer{
		server:                       ,
		sseEndpoint:                  "/sse",
		messageEndpoint:              "/message",
		useFullURLForMessageEndpoint: true,
		keepAlive:                    false,
		keepAliveInterval:            10 * time.Second,
		sessionIDGenFunc: func( context.Context,  *http.Request) (string, error) {
			return uuid.New().String(), nil
		},
	}

	// Apply all options
	for ,  := range  {
		()
	}

	return 
}

// NewTestServer creates a test server for testing purposes
func ( *MCPServer,  ...SSEOption) *httptest.Server {
	 := NewSSEServer(, ...)

	 := httptest.NewServer()
	.baseURL = .URL
	return 
}

// Start begins serving SSE connections on the specified address.
// It sets up HTTP handlers for SSE and message endpoints.
func ( *SSEServer) ( string) error {
	.mu.Lock()
	if .srv == nil {
		.srv = &http.Server{
			Addr:    ,
			Handler: ,
		}
	} else {
		if .srv.Addr == "" {
			.srv.Addr = 
		} else if .srv.Addr !=  {
			return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", .srv.Addr, )
		}
	}
	 := .srv
	.mu.Unlock()

	return .ListenAndServe()
}

// Shutdown gracefully stops the SSE server, closing all active sessions
// and shutting down the HTTP server.
func ( *SSEServer) ( context.Context) error {
	.mu.RLock()
	 := .srv
	.mu.RUnlock()

	if  != nil {
		.sessions.Range(func(,  any) bool {
			if ,  := .(*sseSession);  {
				close(.done)
			}
			.sessions.Delete()
			return true
		})

		return .Shutdown()
	}
	return nil
}

// handleSSE handles incoming SSE connection requests.
// It sets up appropriate headers and creates a new session for the client.
func ( *SSEServer) ( http.ResponseWriter,  *http.Request) {
	if .Method != http.MethodGet {
		http.Error(, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	.Header().Set("Content-Type", "text/event-stream")
	.Header().Set("Cache-Control", "no-cache")
	.Header().Set("Connection", "keep-alive")
	.Header().Set("Access-Control-Allow-Origin", "*")

	,  := .(http.Flusher)
	if ! {
		http.Error(, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	,  := .sessionIDGenFunc(.Context(), )
	if  != nil {
		http.Error(, "Failed to create session ID", http.StatusInternalServerError)
		return
	}
	if  == "" {
		http.Error(, "Failed to create session ID", http.StatusInternalServerError)
		return
	}

	 := &sseSession{
		done:                make(chan struct{}),
		eventQueue:          make(chan string, 100), // Buffer for events
		sessionID:           ,
		notificationChannel: make(chan mcp.JSONRPCNotification, 100),
	}

	.sessions.Store(, )
	defer .sessions.Delete()

	if  := .server.RegisterSession(.Context(), );  != nil {
		http.Error(
			,
			fmt.Sprintf("Session registration failed: %v", ),
			http.StatusInternalServerError,
		)
		return
	}
	defer .server.UnregisterSession(.Context(), )

	// Start notification handler for this session
	go func() {
		for {
			select {
			case  := <-.notificationChannel:
				,  := json.Marshal()
				if  == nil {
					select {
					case .eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", ):
						// Event queued successfully
					case <-.done:
						return
					}
				}
			case <-.done:
				return
			case <-.Context().Done():
				return
			}
		}
	}()

	// Start keep alive : ping
	if .keepAlive {
		go func() {
			 := time.NewTicker(.keepAliveInterval)
			defer .Stop()
			for {
				select {
				case <-.C:
					 := mcp.JSONRPCRequest{
						JSONRPC: "2.0",
						ID:      mcp.NewRequestId(.requestID.Add(1)),
						Request: mcp.Request{
							Method: "ping",
						},
					}
					,  := json.Marshal()
					 := fmt.Sprintf("event: message\ndata:%s\n\n", )
					select {
					case .eventQueue <- :
						// Message sent successfully
					case <-.done:
						return
					}
				case <-.done:
					return
				case <-.Context().Done():
					return
				}
			}
		}()
	}

	// Send the initial endpoint event
	 := .GetMessageEndpointForClient(, )
	if .appendQueryToMessageEndpoint && len(.URL.RawQuery) > 0 {
		 += "&" + .URL.RawQuery
	}
	fmt.Fprintf(, "event: endpoint\ndata: %s\r\n\r\n", )
	.Flush()

	// Main event loop - this runs in the HTTP handler goroutine
	for {
		select {
		case  := <-.eventQueue:
			// Write the event to the response
			fmt.Fprint(, )
			.Flush()
		case <-.Context().Done():
			close(.done)
			return
		case <-.done:
			return
		}
	}
}

// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
// for the given request. This is the canonical way to compute the message endpoint for a client.
// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag.
func ( *SSEServer) ( *http.Request,  string) string {
	 := .basePath
	if .dynamicBasePathFunc != nil {
		 = .dynamicBasePathFunc(, )
	}

	 := normalizeURLPath(, .messageEndpoint)
	if .useFullURLForMessageEndpoint && .baseURL != "" {
		 = .baseURL + 
	}

	return fmt.Sprintf("%s?sessionId=%s", , )
}

// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through the SSE connection and 202 code to HTTP response.
func ( *SSEServer) ( http.ResponseWriter,  *http.Request) {
	if .Method != http.MethodPost {
		.writeJSONRPCError(, nil, mcp.INVALID_REQUEST, "Method not allowed")
		return
	}

	 := .URL.Query().Get("sessionId")
	if  == "" {
		.writeJSONRPCError(, nil, mcp.INVALID_PARAMS, "Missing sessionId")
		return
	}
	,  := .sessions.Load()
	if ! {
		.writeJSONRPCError(, nil, mcp.INVALID_PARAMS, "Invalid session ID")
		return
	}
	 := .(*sseSession)

	// Set the client context before handling the message
	 := .server.WithContext(.Context(), )
	if .contextFunc != nil {
		 = .contextFunc(, )
	}

	// Parse message as raw JSON
	var  json.RawMessage
	if  := json.NewDecoder(.Body).Decode(&);  != nil {
		.writeJSONRPCError(, nil, mcp.PARSE_ERROR, "Parse error")
		return
	}

	// Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled.
	// this is required because the http ctx will be canceled when the client disconnects
	 := context.WithoutCancel()

	// quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
	.WriteHeader(http.StatusAccepted)

	// Create a new context for handling the message that will be canceled when the message handling is done
	 := context.WithValue(, requestHeader, .Header)
	,  := context.WithCancel()

	go func( context.Context) {
		defer ()
		// Use the context that will be canceled when session is done
		// Process message through MCPServer
		 := .server.HandleMessage(, )
		// Only send response if there is one (not for notifications)
		if  != nil {
			var  string
			if ,  := json.Marshal();  != nil {
				// If there is an error marshalling the response, send a generic error response
				log.Printf("failed to marshal response: %v", )
				 = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
			} else {
				 = fmt.Sprintf("event: message\ndata: %s\n\n", )
			}

			// Queue the event for sending via SSE
			select {
			case .eventQueue <- :
				// Event queued successfully
			case <-.done:
				// Session is closed, don't try to queue
			default:
				// Queue is full, log this situation
				log.Printf("Event queue full for session %s", )
			}
		}
	}()
}

// writeJSONRPCError writes a JSON-RPC error response with the given error details.
func ( *SSEServer) (
	 http.ResponseWriter,
	 any,
	 int,
	 string,
) {
	 := createErrorResponse(, , )
	.Header().Set("Content-Type", "application/json")
	.WriteHeader(http.StatusBadRequest)
	if  := json.NewEncoder().Encode();  != nil {
		http.Error(
			,
			fmt.Sprintf("Failed to encode response: %v", ),
			http.StatusInternalServerError,
		)
		return
	}
}

// SendEventToSession sends an event to a specific SSE session identified by sessionID.
// Returns an error if the session is not found or closed.
func ( *SSEServer) (
	 string,
	 any,
) error {
	,  := .sessions.Load()
	if ! {
		return fmt.Errorf("session not found: %s", )
	}
	 := .(*sseSession)

	,  := json.Marshal()
	if  != nil {
		return 
	}

	// Queue the event for sending via SSE
	select {
	case .eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", ):
		return nil
	case <-.done:
		return fmt.Errorf("session closed")
	default:
		return fmt.Errorf("event queue full")
	}
}

func ( *SSEServer) ( string) (string, error) {
	,  := url.Parse()
	if  != nil {
		return "", fmt.Errorf("failed to parse URL %s: %w", , )
	}
	return .Path, nil
}

func ( *SSEServer) () (string, error) {
	if .dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
	}

	 := normalizeURLPath(.basePath, .sseEndpoint)
	return .baseURL + , nil
}

func ( *SSEServer) () string {
	,  := .CompleteSseEndpoint()
	if  != nil {
		return normalizeURLPath(.basePath, .sseEndpoint)
	}
	,  := .GetUrlPath()
	if  != nil {
		return normalizeURLPath(.basePath, .sseEndpoint)
	}
	return 
}

func ( *SSEServer) () (string, error) {
	if .dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
	}
	 := normalizeURLPath(.basePath, .messageEndpoint)
	return .baseURL + , nil
}

func ( *SSEServer) () string {
	,  := .CompleteMessageEndpoint()
	if  != nil {
		return normalizeURLPath(.basePath, .messageEndpoint)
	}
	,  := .GetUrlPath()
	if  != nil {
		return normalizeURLPath(.basePath, .messageEndpoint)
	}
	return 
}

// SSEHandler returns an http.Handler for the SSE endpoint.
//
// This method allows you to mount the SSE handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080")
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func ( *SSEServer) () http.Handler {
	return http.HandlerFunc(.handleSSE)
}

// MessageHandler returns an http.Handler for the message endpoint.
//
// This method allows you to mount the message handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080")
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func ( *SSEServer) () http.Handler {
	return http.HandlerFunc(.handleMessage)
}

// ServeHTTP implements the http.Handler interface.
func ( *SSEServer) ( http.ResponseWriter,  *http.Request) {
	if .dynamicBasePathFunc != nil {
		http.Error(
			,
			(&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
			http.StatusInternalServerError,
		)
		return
	}
	 := .URL.Path
	// Use exact path matching rather than Contains
	 := .CompleteSsePath()
	if  != "" &&  ==  {
		.handleSSE(, )
		return
	}
	 := .CompleteMessagePath()
	if  != "" &&  ==  {
		.handleMessage(, )
		return
	}

	http.NotFound(, )
}

// normalizeURLPath joins slash-separated URL path elements like path.Join and
// guarantees the result starts with "/" and never ends with "/" unless it is
// exactly "/". With no elements, or only empty ones, it returns "/".
func normalizeURLPath(elem ...string) string {
	// path.Join cleans its result, and path.Clean leaves a trailing slash only
	// on the root "/", so the leading slash is the only fix-up required.
	joined := path.Join(elem...)
	if strings.HasPrefix(joined, "/") {
		return joined
	}
	return "/" + joined
}