package server

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"os/signal"
	"sync"
	"sync/atomic"
	"syscall"

	"github.com/mark3labs/mcp-go/mcp"
)

// StdioContextFunc is a function that takes an existing context and returns
// a potentially modified context.
// This can be used to inject context values from environment variables,
// for example. Listen calls it once, after the stdio session has been attached
// to the context, and uses the returned context for every message it handles.
type StdioContextFunc func(ctx context.Context) context.Context

// StdioServer wraps a MCPServer and handles stdio communication.
// It provides a simple way to create command-line MCP servers that
// communicate via standard input/output streams using JSON-RPC messages.
//
// "tools/call" requests are handed to a bounded pool of worker goroutines
// (a tool may itself wait on a request sent back to the client, e.g. sampling,
// which must not stall the input reader); every other message is handled
// synchronously on the goroutine that reads input.
type StdioServer struct {
	server      *MCPServer
	errLogger   *log.Logger
	contextFunc StdioContextFunc

	// Thread-safe tool call processing.
	// toolCallQueue is (re)created by Listen with capacity queueSize and
	// consumed by workerPoolSize goroutines tracked by workerWg.
	toolCallQueue  chan *toolCallWork
	workerWg       sync.WaitGroup
	workerPoolSize int
	queueSize      int
	writeMu        sync.Mutex // Protects concurrent writes
}

// toolCallWork represents a queued tool call request: the raw JSON-RPC
// message, the context it is handled with, and the writer its response is
// written to once a worker has processed it.
type toolCallWork struct {
	ctx     context.Context
	message json.RawMessage
	writer  io.Writer
}

// StdioOption defines a function type for configuring StdioServer.
// ServeStdio applies options in the order given, after NewStdioServer has
// installed the defaults.
type StdioOption func(*StdioServer)

// WithErrorLogger sets the error logger for the server
func ( *log.Logger) StdioOption {
	return func( *StdioServer) {
		.errLogger = 
	}
}

// WithStdioContextFunc sets a function that will be called to customise the context
// to the server. Note that the stdio server uses the same context for all requests,
// so this function will only be called once per server instance.
func ( StdioContextFunc) StdioOption {
	return func( *StdioServer) {
		.contextFunc = 
	}
}

// WithWorkerPoolSize sets the number of workers for processing tool calls
func ( int) StdioOption {
	return func( *StdioServer) {
		const  = 100
		if  > 0 &&  <=  {
			.workerPoolSize = 
		} else if  >  {
			.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", , )
			.workerPoolSize = 
		}
	}
}

// WithQueueSize sets the size of the tool call queue
func ( int) StdioOption {
	return func( *StdioServer) {
		const  = 10000
		if  > 0 &&  <=  {
			.queueSize = 
		} else if  >  {
			.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", , )
			.queueSize = 
		}
	}
}

// stdioSession is a static client session, since stdio has only one client.
// A single package-level instance (stdioSessionInstance) backs every
// StdioServer in the process. Server-initiated requests (sampling,
// elicitation, list roots) are tracked by numeric request ID in the pending*
// maps until the matching response is routed back.
type stdioSession struct {
	notifications       chan mcp.JSONRPCNotification
	initialized         atomic.Bool
	loggingLevel        atomic.Value
	clientInfo          atomic.Value                        // stores session-specific client info
	clientCapabilities  atomic.Value                        // stores session-specific client capabilities
	writer              io.Writer                           // for sending requests to client
	requestID           atomic.Int64                        // for generating unique request IDs
	mu                  sync.RWMutex                        // protects writer
	pendingRequests     map[int64]chan *samplingResponse    // for tracking pending sampling requests
	pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests
	pendingRoots        map[int64]chan *rootsResponse       // for tracking pending list roots requests
	pendingMu           sync.RWMutex                        // protects pendingRequests, pendingElicitations and pendingRoots
}

// samplingResponse represents a response to a sampling request.
// Exactly one of result or err is set by handleSamplingResponse.
type samplingResponse struct {
	result *mcp.CreateMessageResult
	err    error
}

// elicitationResponse represents a response to an elicitation request.
// Exactly one of result or err is set by handleElicitationResponse.
type elicitationResponse struct {
	result *mcp.ElicitationResult
	err    error
}

// rootsResponse represents a response to a list roots request.
// Exactly one of result or err is set by handleListRootsResponse.
type rootsResponse struct {
	result *mcp.ListRootsResult
	err    error
}

func ( *stdioSession) () string {
	return "stdio"
}

func ( *stdioSession) () chan<- mcp.JSONRPCNotification {
	return .notifications
}

func ( *stdioSession) () {
	// set default logging level
	.loggingLevel.Store(mcp.LoggingLevelError)
	.initialized.Store(true)
}

func ( *stdioSession) () bool {
	return .initialized.Load()
}

func ( *stdioSession) () mcp.Implementation {
	if  := .clientInfo.Load();  != nil {
		if ,  := .(mcp.Implementation);  {
			return 
		}
	}
	return mcp.Implementation{}
}

func ( *stdioSession) ( mcp.Implementation) {
	.clientInfo.Store()
}

func ( *stdioSession) () mcp.ClientCapabilities {
	if  := .clientCapabilities.Load();  != nil {
		if ,  := .(mcp.ClientCapabilities);  {
			return 
		}
	}
	return mcp.ClientCapabilities{}
}

func ( *stdioSession) ( mcp.ClientCapabilities) {
	.clientCapabilities.Store()
}

func ( *stdioSession) ( mcp.LoggingLevel) {
	.loggingLevel.Store()
}

func ( *stdioSession) () mcp.LoggingLevel {
	 := .loggingLevel.Load()
	if  == nil {
		return mcp.LoggingLevelError
	}
	return .(mcp.LoggingLevel)
}

// RequestSampling sends a sampling request to the client and waits for the response.
func ( *stdioSession) ( context.Context,  mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
	.mu.RLock()
	 := .writer
	.mu.RUnlock()

	if  == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID
	 := .requestID.Add(1)

	// Create a response channel for this request
	 := make(chan *samplingResponse, 1)
	.pendingMu.Lock()
	.pendingRequests[] = 
	.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	 := func() {
		.pendingMu.Lock()
		delete(.pendingRequests, )
		.pendingMu.Unlock()
	}
	defer ()

	// Create the JSON-RPC request
	 := struct {
		 string                  `json:"jsonrpc"`
		      int64                   `json:"id"`
		  string                  `json:"method"`
		  mcp.CreateMessageParams `json:"params"`
	}{
		: mcp.JSONRPC_VERSION,
		:      ,
		:  string(mcp.MethodSamplingCreateMessage),
		:  .CreateMessageParams,
	}

	// Marshal and send the request
	,  := json.Marshal()
	if  != nil {
		return nil, fmt.Errorf("failed to marshal sampling request: %w", )
	}
	 = append(, '\n')

	if ,  := .Write();  != nil {
		return nil, fmt.Errorf("failed to write sampling request: %w", )
	}

	// Wait for the response or context cancellation
	select {
	case <-.Done():
		return nil, .Err()
	case  := <-:
		if .err != nil {
			return nil, .err
		}
		return .result, nil
	}
}

// ListRoots sends an list roots request to the client and waits for the response.
func ( *stdioSession) ( context.Context,  mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
	.mu.RLock()
	 := .writer
	.mu.RUnlock()

	if  == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID
	 := .requestID.Add(1)

	// Create a response channel for this request
	 := make(chan *rootsResponse, 1)
	.pendingMu.Lock()
	.pendingRoots[] = 
	.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	 := func() {
		.pendingMu.Lock()
		delete(.pendingRoots, )
		.pendingMu.Unlock()
	}
	defer ()

	// Create the JSON-RPC request
	 := struct {
		 string `json:"jsonrpc"`
		      int64  `json:"id"`
		  string `json:"method"`
	}{
		: mcp.JSONRPC_VERSION,
		:      ,
		:  string(mcp.MethodListRoots),
	}

	// Marshal and send the request
	,  := json.Marshal()
	if  != nil {
		return nil, fmt.Errorf("failed to marshal list roots request: %w", )
	}
	 = append(, '\n')

	if ,  := .Write();  != nil {
		return nil, fmt.Errorf("failed to write list roots request: %w", )
	}

	// Wait for the response or context cancellation
	select {
	case <-.Done():
		return nil, .Err()
	case  := <-:
		if .err != nil {
			return nil, .err
		}
		return .result, nil
	}
}

// RequestElicitation sends an elicitation request to the client and waits for the response.
func ( *stdioSession) ( context.Context,  mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
	.mu.RLock()
	 := .writer
	.mu.RUnlock()

	if  == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID
	 := .requestID.Add(1)

	// Create a response channel for this request
	 := make(chan *elicitationResponse, 1)
	.pendingMu.Lock()
	.pendingElicitations[] = 
	.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	 := func() {
		.pendingMu.Lock()
		delete(.pendingElicitations, )
		.pendingMu.Unlock()
	}
	defer ()

	// Create the JSON-RPC request
	 := struct {
		 string                `json:"jsonrpc"`
		      int64                 `json:"id"`
		  string                `json:"method"`
		  mcp.ElicitationParams `json:"params"`
	}{
		: mcp.JSONRPC_VERSION,
		:      ,
		:  string(mcp.MethodElicitationCreate),
		:  .Params,
	}

	// Marshal and send the request
	,  := json.Marshal()
	if  != nil {
		return nil, fmt.Errorf("failed to marshal elicitation request: %w", )
	}
	 = append(, '\n')

	if ,  := .Write();  != nil {
		return nil, fmt.Errorf("failed to write elicitation request: %w", )
	}

	// Wait for the response or context cancellation
	select {
	case <-.Done():
		return nil, .Err()
	case  := <-:
		if .err != nil {
			return nil, .err
		}
		return .result, nil
	}
}

// SetWriter sets the writer for sending requests to the client.
func ( *stdioSession) ( io.Writer) {
	.mu.Lock()
	defer .mu.Unlock()
	.writer = 
}

// Compile-time assertions that *stdioSession implements every session
// interface the stdio transport supports.
var (
	_ ClientSession          = (*stdioSession)(nil)
	_ SessionWithLogging     = (*stdioSession)(nil)
	_ SessionWithClientInfo  = (*stdioSession)(nil)
	_ SessionWithSampling    = (*stdioSession)(nil)
	_ SessionWithElicitation = (*stdioSession)(nil)
	_ SessionWithRoots       = (*stdioSession)(nil)
)

// stdioSessionInstance is the single process-wide stdio session. Listen
// registers it with the MCPServer and attaches it to the request context.
// The pending-request maps are created here because the Request*/ListRoots
// methods insert into them directly. Up to 100 notifications can be queued
// before senders on the notification channel block.
var stdioSessionInstance = stdioSession{
	notifications:       make(chan mcp.JSONRPCNotification, 100),
	pendingRequests:     make(map[int64]chan *samplingResponse),
	pendingElicitations: make(map[int64]chan *elicitationResponse),
	pendingRoots:        make(map[int64]chan *rootsResponse),
}

// NewStdioServer creates a new stdio server wrapper around an MCPServer.
// It initializes the server with a default error logger that discards all output.
func ( *MCPServer) *StdioServer {
	return &StdioServer{
		server: ,
		errLogger: log.New(
			os.Stderr,
			"",
			log.LstdFlags,
		), // Default to discarding logs
		workerPoolSize: 5,   // Default worker pool size
		queueSize:      100, // Default queue size
	}
}

// SetErrorLogger configures where error messages from the StdioServer are logged.
// The provided logger will receive all error messages generated during server operation.
func ( *StdioServer) ( *log.Logger) {
	.errLogger = 
}

// SetContextFunc sets a function that will be called to customise the context
// to the server. Note that the stdio server uses the same context for all requests,
// so this function will only be called once per server instance.
func ( *StdioServer) ( StdioContextFunc) {
	.contextFunc = 
}

// handleNotifications continuously processes notifications from the session's notification channel
// and writes them to the provided output. It runs until the context is cancelled.
// Any errors encountered while writing notifications are logged but do not stop the handler.
func ( *StdioServer) ( context.Context,  io.Writer) {
	for {
		select {
		case  := <-stdioSessionInstance.notifications:
			if  := .writeResponse(, );  != nil {
				.errLogger.Printf("Error writing notification: %v", )
			}
		case <-.Done():
			return
		}
	}
}

// processInputStream continuously reads and processes messages from the input stream.
// It handles EOF gracefully as a normal termination condition.
// The function returns when either:
// - The context is cancelled (returns context.Err())
// - EOF is encountered (returns nil)
// - An error occurs while reading or processing messages (returns the error)
func ( *StdioServer) ( context.Context,  *bufio.Reader,  io.Writer) error {
	for {
		if  := .Err();  != nil {
			return 
		}

		,  := .readNextLine(, )
		if  != nil {
			if  == io.EOF {
				return nil
			}
			.errLogger.Printf("Error reading input: %v", )
			return 
		}

		if  := .processMessage(, , );  != nil {
			if  == io.EOF {
				return nil
			}
			.errLogger.Printf("Error handling message: %v", )
			return 
		}
	}
}

// toolCallWorker processes tool calls from the queue
func ( *StdioServer) ( context.Context) {
	defer .workerWg.Done()

	for {
		select {
		case ,  := <-.toolCallQueue:
			if ! {
				// Channel closed, exit worker
				return
			}
			// Process the tool call
			 := .server.HandleMessage(.ctx, .message)
			if  != nil {
				if  := .writeResponse(, .writer);  != nil {
					.errLogger.Printf("Error writing tool response: %v", )
				}
			}
		case <-.Done():
			return
		}
	}
}

// readNextLine reads a single line from the input reader in a context-aware manner.
// It uses channels to make the read operation cancellable via context.
// Returns the read line and any error encountered. If the context is cancelled,
// returns an empty string and the context's error. EOF is returned when the input
// stream is closed.
func ( *StdioServer) ( context.Context,  *bufio.Reader) (string, error) {
	type  struct {
		 string
		  error
	}

	 := make(chan , 1)

	go func() {
		,  := .ReadString('\n')
		 <- {: , : }
	}()

	select {
	case <-.Done():
		return "", nil
	case  := <-:
		return ., .
	}
}

// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
// It runs until the context is cancelled or an error occurs.
// Returns an error if there are issues with reading input or writing output.
func ( *StdioServer) (
	 context.Context,
	 io.Reader,
	 io.Writer,
) error {
	// Initialize the tool call queue
	.toolCallQueue = make(chan *toolCallWork, .queueSize)

	// Set a static client context since stdio only has one client
	if  := .server.RegisterSession(, &stdioSessionInstance);  != nil {
		return fmt.Errorf("register session: %w", )
	}
	defer .server.UnregisterSession(, stdioSessionInstance.SessionID())
	 = .server.WithContext(, &stdioSessionInstance)

	// Set the writer for sending requests to the client
	stdioSessionInstance.SetWriter()

	// Add in any custom context.
	if .contextFunc != nil {
		 = .contextFunc()
	}

	 := bufio.NewReader()

	// Start worker pool for tool calls
	for  := 0;  < .workerPoolSize; ++ {
		.workerWg.Add(1)
		go .toolCallWorker()
	}

	// Start notification handler
	go .handleNotifications(, )

	// Process input stream
	 := .processInputStream(, , )

	// Shutdown workers gracefully
	close(.toolCallQueue)
	.workerWg.Wait()

	return 
}

// processMessage handles a single JSON-RPC message and writes the response.
// It parses the message, processes it through the wrapped MCPServer, and writes any response.
// Returns an error if there are issues with message processing or response writing.
func ( *StdioServer) (
	 context.Context,
	 string,
	 io.Writer,
) error {
	// If line is empty, likely due to ctx cancellation
	if len() == 0 {
		return nil
	}

	// Parse the message as raw JSON
	var  json.RawMessage
	if  := json.Unmarshal([]byte(), &);  != nil {
		 := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
		return .writeResponse(, )
	}

	// Check if this is a response to a sampling request
	if .handleSamplingResponse() {
		return nil
	}

	// Check if this is a response to an elicitation request
	if .handleElicitationResponse() {
		return nil
	}

	// Check if this is a response to an list roots request
	if .handleListRootsResponse() {
		return nil
	}

	// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
	var  struct {
		 string `json:"method"`
	}
	if json.Unmarshal(, &) == nil && . == "tools/call" {
		// Queue tool calls for processing by workers
		select {
		case .toolCallQueue <- &toolCallWork{
			ctx:     ,
			message: ,
			writer:  ,
		}:
			return nil
		case <-.Done():
			return .Err()
		default:
			// Queue is full, process synchronously as fallback
			.errLogger.Printf("Tool call queue full, processing synchronously")
			 := .server.HandleMessage(, )
			if  != nil {
				return .writeResponse(, )
			}
			return nil
		}
	}

	// Handle other messages synchronously
	 := .server.HandleMessage(, )

	// Only write response if there is one (not for notifications)
	if  != nil {
		if  := .writeResponse(, );  != nil {
			return fmt.Errorf("failed to write response: %w", )
		}
	}

	return nil
}

// handleSamplingResponse checks if the message is a response to a sampling request
// and routes it to the appropriate pending request channel.
func ( *StdioServer) ( json.RawMessage) bool {
	return stdioSessionInstance.handleSamplingResponse()
}

// handleSamplingResponse handles incoming sampling responses for this session
func ( *stdioSession) ( json.RawMessage) bool {
	// Try to parse as a JSON-RPC response
	var  struct {
		 string          `json:"jsonrpc"`
		      json.Number     `json:"id"`
		  json.RawMessage `json:"result,omitempty"`
		   *struct {
			    int    `json:"code"`
			 string `json:"message"`
		} `json:"error,omitempty"`
	}

	if  := json.Unmarshal(, &);  != nil {
		return false
	}
	// Parse the ID as int64
	,  := ..Int64()
	if  != nil || (. == nil && . == nil) {
		return false
	}

	// Look for a pending request with this ID
	.pendingMu.RLock()
	,  := .pendingRequests[]
	.pendingMu.RUnlock()

	if ! {
		return false
	} // Parse and send the response
	 := &samplingResponse{}

	if . != nil {
		.err = fmt.Errorf("sampling request failed: %s", ..)
	} else {
		var  mcp.CreateMessageResult
		if  := json.Unmarshal(., &);  != nil {
			.err = fmt.Errorf("failed to unmarshal sampling response: %w", )
		} else {
			// Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent)
			if ,  := .Content.(map[string]any);  {
				,  := mcp.ParseContent()
				if  != nil {
					.err = fmt.Errorf("failed to parse sampling response content: %w", )
				} else {
					.Content = 
					.result = &
				}
			} else {
				.result = &
			}
		}
	}

	// Send the response (non-blocking)
	select {
	case  <- :
	default:
		// Channel is full or closed, ignore
	}

	return true
}

// handleElicitationResponse checks if the message is a response to an elicitation request
// and routes it to the appropriate pending request channel.
func ( *StdioServer) ( json.RawMessage) bool {
	return stdioSessionInstance.handleElicitationResponse()
}

// handleElicitationResponse handles incoming elicitation responses for this session
func ( *stdioSession) ( json.RawMessage) bool {
	// Try to parse as a JSON-RPC response
	var  struct {
		 string          `json:"jsonrpc"`
		      json.Number     `json:"id"`
		  json.RawMessage `json:"result,omitempty"`
		   *struct {
			    int    `json:"code"`
			 string `json:"message"`
		} `json:"error,omitempty"`
	}

	if  := json.Unmarshal(, &);  != nil {
		return false
	}
	// Parse the ID as int64
	,  := ..Int64()
	if  != nil || (. == nil && . == nil) {
		return false
	}

	// Check if we have a pending elicitation request with this ID
	.pendingMu.RLock()
	,  := .pendingElicitations[]
	.pendingMu.RUnlock()

	if ! {
		return false
	}

	// Parse and send the response
	 := &elicitationResponse{}

	if . != nil {
		.err = fmt.Errorf("elicitation request failed: %s", ..)
	} else {
		var  mcp.ElicitationResult
		if  := json.Unmarshal(., &);  != nil {
			.err = fmt.Errorf("failed to unmarshal elicitation response: %w", )
		} else {
			.result = &
		}
	}

	// Send the response (non-blocking)
	select {
	case  <- :
	default:
		// Channel is full or closed, ignore
	}

	return true
}

// handleListRootsResponse checks if the message is a response to an list roots request
// and routes it to the appropriate pending request channel.
func ( *StdioServer) ( json.RawMessage) bool {
	return stdioSessionInstance.handleListRootsResponse()
}

// handleListRootsResponse handles incoming list root responses for this session
func ( *stdioSession) ( json.RawMessage) bool {
	// Try to parse as a JSON-RPC response
	var  struct {
		 string          `json:"jsonrpc"`
		      json.Number     `json:"id"`
		  json.RawMessage `json:"result,omitempty"`
		   *struct {
			    int    `json:"code"`
			 string `json:"message"`
		} `json:"error,omitempty"`
	}

	if  := json.Unmarshal(, &);  != nil {
		return false
	}
	// Parse the ID as int64
	,  := ..Int64()
	if  != nil || (. == nil && . == nil) {
		return false
	}

	// Check if we have a pending list root request with this ID
	.pendingMu.RLock()
	,  := .pendingRoots[]
	.pendingMu.RUnlock()

	if ! {
		return false
	}

	// Parse and send the response
	 := &rootsResponse{}

	if . != nil {
		.err = fmt.Errorf("list root request failed: %s", ..)
	} else {
		var  mcp.ListRootsResult
		if  := json.Unmarshal(., &);  != nil {
			.err = fmt.Errorf("failed to unmarshal list root response: %w", )
		} else {
			.result = &
		}
	}

	// Send the response (non-blocking)
	select {
	case  <- :
	default:
		// Channel is full or closed, ignore
	}

	return true
}

// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
// Returns an error if marshaling or writing fails.
func ( *StdioServer) (
	 mcp.JSONRPCMessage,
	 io.Writer,
) error {
	,  := json.Marshal()
	if  != nil {
		return 
	}

	// Protect concurrent writes
	.writeMu.Lock()
	defer .writeMu.Unlock()

	// Write response followed by newline
	if ,  := fmt.Fprintf(, "%s\n", );  != nil {
		return 
	}

	return nil
}

// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
// Returns an error if the server encounters any issues during operation.
func ( *MCPServer,  ...StdioOption) error {
	 := NewStdioServer()

	for ,  := range  {
		()
	}

	,  := context.WithCancel(context.Background())
	defer ()

	// Set up signal handling
	 := make(chan os.Signal, 1)
	signal.Notify(, syscall.SIGTERM, syscall.SIGINT)

	go func() {
		<-
		()
	}()

	return .Listen(, os.Stdin, os.Stdout)
}