package server
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"github.com/mark3labs/mcp-go/mcp"
)
type (
	// StdioContextFunc transforms the context used while serving; Listen applies
	// it once to the serving context after the session has been attached.
	StdioContextFunc func(ctx context.Context) context.Context

	// StdioServer serves an MCPServer as newline-delimited JSON-RPC over a pair
	// of streams (ServeStdio uses os.Stdin and os.Stdout).
	StdioServer struct {
		server         *MCPServer
		errLogger      *log.Logger
		contextFunc    StdioContextFunc
		toolCallQueue  chan *toolCallWork // pending "tools/call" requests for the worker pool
		workerWg       sync.WaitGroup     // tracks running tool-call workers
		workerPoolSize int
		queueSize      int
		writeMu        sync.Mutex // serializes writes performed by writeResponse
	}

	// toolCallWork is one queued "tools/call" request together with the context
	// and writer it must be handled and answered with.
	toolCallWork struct {
		ctx     context.Context
		message json.RawMessage
		writer  io.Writer
	}

	// StdioOption configures a StdioServer (see ServeStdio).
	StdioOption func(*StdioServer)
)
// WithErrorLogger returns an option that replaces the logger the server uses
// to report errors.
func WithErrorLogger(logger *log.Logger) StdioOption {
	opt := func(srv *StdioServer) {
		srv.errLogger = logger
	}
	return opt
}

// WithStdioContextFunc returns an option that installs fn as the
// StdioContextFunc applied to the serving context in Listen.
func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
	opt := func(srv *StdioServer) {
		srv.contextFunc = fn
	}
	return opt
}
// WithWorkerPoolSize returns an option setting how many goroutines process
// "tools/call" requests. Non-positive sizes leave the current value untouched;
// sizes above 100 are clamped to 100 and a warning is logged.
func WithWorkerPoolSize(size int) StdioOption {
	return func(srv *StdioServer) {
		const maxWorkerPoolSize = 100
		switch {
		case size > maxWorkerPoolSize:
			srv.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize)
			srv.workerPoolSize = maxWorkerPoolSize
		case size > 0:
			srv.workerPoolSize = size
		}
	}
}

// WithQueueSize returns an option setting the capacity of the tool-call queue.
// Non-positive sizes leave the current value untouched; sizes above 10000 are
// clamped to 10000 and a warning is logged.
func WithQueueSize(size int) StdioOption {
	return func(srv *StdioServer) {
		const maxQueueSize = 10000
		switch {
		case size > maxQueueSize:
			srv.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize)
			srv.queueSize = maxQueueSize
		case size > 0:
			srv.queueSize = size
		}
	}
}
type (
	// stdioSession holds per-session state for the stdio transport: client
	// metadata, logging level, and the bookkeeping for requests the server
	// sends to the client (sampling, elicitation, roots).
	stdioSession struct {
		notifications       chan mcp.JSONRPCNotification
		initialized         atomic.Bool
		loggingLevel        atomic.Value // holds an mcp.LoggingLevel
		clientInfo          atomic.Value // holds an mcp.Implementation
		clientCapabilities  atomic.Value // holds an mcp.ClientCapabilities
		writer              io.Writer    // destination for server-initiated requests; guarded by mu
		requestID           atomic.Int64 // last ID handed out to a server-initiated request
		mu                  sync.RWMutex
		pendingRequests     map[int64]chan *samplingResponse
		pendingElicitations map[int64]chan *elicitationResponse
		pendingRoots        map[int64]chan *rootsResponse
		pendingMu           sync.RWMutex // guards the three pending* maps
	}

	// samplingResponse carries the client's answer to a sampling request.
	samplingResponse struct {
		result *mcp.CreateMessageResult
		err    error
	}

	// elicitationResponse carries the client's answer to an elicitation request.
	elicitationResponse struct {
		result *mcp.ElicitationResult
		err    error
	}

	// rootsResponse carries the client's answer to a list-roots request.
	rootsResponse struct {
		result *mcp.ListRootsResult
		err    error
	}
)
// SessionID returns the constant identifier of the stdio session.
func (s *stdioSession) SessionID() string { return "stdio" }

// NotificationChannel returns a send-only view of the session's notification
// channel.
func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	var sendOnly chan<- mcp.JSONRPCNotification = s.notifications
	return sendOnly
}

// Initialize resets the logging level to mcp.LoggingLevelError and then marks
// the session as initialized.
func (s *stdioSession) Initialize() {
	s.loggingLevel.Store(mcp.LoggingLevelError)
	s.initialized.Store(true)
}

// Initialized reports whether Initialize has been called.
func (s *stdioSession) Initialized() bool { return s.initialized.Load() }
// GetClientInfo returns the stored client implementation info, or the zero
// mcp.Implementation when none has been stored.
func (s *stdioSession) GetClientInfo() mcp.Implementation {
	info, _ := s.clientInfo.Load().(mcp.Implementation)
	return info
}

// SetClientInfo stores the client implementation info.
func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
	s.clientInfo.Store(clientInfo)
}

// GetClientCapabilities returns the stored client capabilities, or the zero
// mcp.ClientCapabilities when none have been stored.
func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities {
	caps, _ := s.clientCapabilities.Load().(mcp.ClientCapabilities)
	return caps
}

// SetClientCapabilities stores the client capabilities.
func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
	s.clientCapabilities.Store(clientCapabilities)
}
// SetLogLevel stores level as the session's logging level.
func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
	s.loggingLevel.Store(level)
}

// GetLogLevel returns the session's logging level, or mcp.LoggingLevelError
// when none has been stored.
//
// The checked type assertion matches GetClientInfo/GetClientCapabilities: an
// unexpected stored value falls back to the default instead of panicking.
func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
	if level, ok := s.loggingLevel.Load().(mcp.LoggingLevel); ok {
		return level
	}
	return mcp.LoggingLevelError
}
// RequestSampling sends a sampling request to the client over the session
// writer and blocks until the matching response arrives or ctx is done.
//
// The pending entry is registered before the request is written, so a fast
// reply cannot be missed. It is removed again on every return path.
func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
	s.mu.RLock()
	w := s.writer
	s.mu.RUnlock()
	if w == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	reqID := s.requestID.Add(1)
	replyCh := make(chan *samplingResponse, 1)

	s.pendingMu.Lock()
	s.pendingRequests[reqID] = replyCh
	s.pendingMu.Unlock()
	defer func() {
		s.pendingMu.Lock()
		delete(s.pendingRequests, reqID)
		s.pendingMu.Unlock()
	}()

	type samplingRequest struct {
		JSONRPC string                  `json:"jsonrpc"`
		ID      int64                   `json:"id"`
		Method  string                  `json:"method"`
		Params  mcp.CreateMessageParams `json:"params"`
	}
	payload, err := json.Marshal(samplingRequest{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      reqID,
		Method:  string(mcp.MethodSamplingCreateMessage),
		Params:  request.CreateMessageParams,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
	}
	if _, err = w.Write(append(payload, '\n')); err != nil {
		return nil, fmt.Errorf("failed to write sampling request: %w", err)
	}

	select {
	case reply := <-replyCh:
		if reply.err != nil {
			return nil, reply.err
		}
		return reply.result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// ListRoots sends a list-roots request to the client over the session writer
// and blocks until the matching response arrives or ctx is done. The request
// argument is not used to build the wire message, which carries no params.
func (s *stdioSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
	s.mu.RLock()
	w := s.writer
	s.mu.RUnlock()
	if w == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	reqID := s.requestID.Add(1)
	replyCh := make(chan *rootsResponse, 1)

	s.pendingMu.Lock()
	s.pendingRoots[reqID] = replyCh
	s.pendingMu.Unlock()
	defer func() {
		s.pendingMu.Lock()
		delete(s.pendingRoots, reqID)
		s.pendingMu.Unlock()
	}()

	type listRootsRequest struct {
		JSONRPC string `json:"jsonrpc"`
		ID      int64  `json:"id"`
		Method  string `json:"method"`
	}
	payload, err := json.Marshal(listRootsRequest{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      reqID,
		Method:  string(mcp.MethodListRoots),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal list roots request: %w", err)
	}
	if _, err = w.Write(append(payload, '\n')); err != nil {
		return nil, fmt.Errorf("failed to write list roots request: %w", err)
	}

	select {
	case reply := <-replyCh:
		if reply.err != nil {
			return nil, reply.err
		}
		return reply.result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// RequestElicitation sends an elicitation request to the client over the
// session writer and blocks until the matching response arrives or ctx is done.
func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
	s.mu.RLock()
	w := s.writer
	s.mu.RUnlock()
	if w == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	reqID := s.requestID.Add(1)
	replyCh := make(chan *elicitationResponse, 1)

	s.pendingMu.Lock()
	s.pendingElicitations[reqID] = replyCh
	s.pendingMu.Unlock()
	defer func() {
		s.pendingMu.Lock()
		delete(s.pendingElicitations, reqID)
		s.pendingMu.Unlock()
	}()

	type elicitationRequest struct {
		JSONRPC string                `json:"jsonrpc"`
		ID      int64                 `json:"id"`
		Method  string                `json:"method"`
		Params  mcp.ElicitationParams `json:"params"`
	}
	payload, err := json.Marshal(elicitationRequest{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      reqID,
		Method:  string(mcp.MethodElicitationCreate),
		Params:  request.Params,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal elicitation request: %w", err)
	}
	if _, err = w.Write(append(payload, '\n')); err != nil {
		return nil, fmt.Errorf("failed to write elicitation request: %w", err)
	}

	select {
	case reply := <-replyCh:
		if reply.err != nil {
			return nil, reply.err
		}
		return reply.result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// SetWriter sets the writer used for server-initiated requests, under mu.
func (s *stdioSession) SetWriter(writer io.Writer) {
	s.mu.Lock()
	s.writer = writer
	s.mu.Unlock()
}
var (
	// Compile-time assertions that *stdioSession satisfies each of these
	// session interfaces.
	_ ClientSession          = (*stdioSession)(nil)
	_ SessionWithLogging     = (*stdioSession)(nil)
	_ SessionWithClientInfo  = (*stdioSession)(nil)
	_ SessionWithSampling    = (*stdioSession)(nil)
	_ SessionWithElicitation = (*stdioSession)(nil)
	_ SessionWithRoots       = (*stdioSession)(nil)

	// stdioSessionInstance is the package-level session registered by
	// StdioServer.Listen. Its notification channel buffers up to 100 entries.
	stdioSessionInstance = stdioSession{
		pendingRequests:     make(map[int64]chan *samplingResponse),
		pendingElicitations: make(map[int64]chan *elicitationResponse),
		pendingRoots:        make(map[int64]chan *rootsResponse),
		notifications:       make(chan mcp.JSONRPCNotification, 100),
	}
)
// NewStdioServer returns a StdioServer for server with the default settings:
// errors are logged to os.Stderr with log.LstdFlags, and tool calls are
// handled by 5 workers fed from a queue holding up to 100 requests.
func NewStdioServer(server *MCPServer) *StdioServer {
	stderrLogger := log.New(os.Stderr, "", log.LstdFlags)
	return &StdioServer{
		server:         server,
		errLogger:      stderrLogger,
		workerPoolSize: 5,
		queueSize:      100,
	}
}

// SetErrorLogger replaces the logger used to report errors.
func (s *StdioServer) SetErrorLogger(logger *log.Logger) { s.errLogger = logger }

// SetContextFunc installs fn as the StdioContextFunc applied in Listen.
func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { s.contextFunc = fn }
// handleNotifications forwards notifications queued on the shared stdio
// session to stdout until ctx is done. Write failures are logged and do not
// stop the loop.
func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
	done := ctx.Done()
	for {
		var notification mcp.JSONRPCNotification
		select {
		case <-done:
			return
		case notification = <-stdioSessionInstance.notifications:
		}
		if err := s.writeResponse(notification, stdout); err != nil {
			s.errLogger.Printf("Error writing notification: %v", err)
		}
	}
}
// processInputStream reads newline-delimited messages from reader and handles
// each one until ctx is canceled (returns ctx.Err()), the input ends (returns
// nil), or a read or handling error occurs (logged and returned).
func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
	for {
		if err := ctx.Err(); err != nil {
			return err
		}

		line, readErr := s.readNextLine(ctx, reader)
		if readErr != nil && readErr != io.EOF {
			s.errLogger.Printf("Error reading input: %v", readErr)
			return readErr
		}

		// ReadString hands back the bytes read before EOF together with io.EOF,
		// so a final message without a trailing newline must still be processed
		// instead of being silently dropped.
		if line != "" {
			if err := s.processMessage(ctx, line, stdout); err != nil {
				if err == io.EOF {
					return nil
				}
				s.errLogger.Printf("Error handling message: %v", err)
				return err
			}
		}

		if readErr == io.EOF {
			return nil
		}
	}
}
// toolCallWorker handles queued "tools/call" requests until the queue is
// closed or ctx is done, writing each non-nil response to the writer stored
// with the request. It signals workerWg when it exits.
func (s *StdioServer) toolCallWorker(ctx context.Context) {
	defer s.workerWg.Done()
	for {
		var (
			work *toolCallWork
			open bool
		)
		select {
		case <-ctx.Done():
			return
		case work, open = <-s.toolCallQueue:
		}
		if !open {
			return
		}

		resp := s.server.HandleMessage(work.ctx, work.message)
		if resp == nil {
			continue
		}
		if err := s.writeResponse(resp, work.writer); err != nil {
			s.errLogger.Printf("Error writing tool response: %v", err)
		}
	}
}
// readNextLine reads one '\n'-terminated line from reader, returning early
// with ("", nil) if ctx is done first.
//
// The read runs on its own goroutine. On cancellation that goroutine keeps
// blocking in ReadString; the one-slot buffered channel lets it finish without
// blocking once the read completes.
func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
	type readResult struct {
		text string
		err  error
	}
	results := make(chan readResult, 1)
	go func() {
		text, err := reader.ReadString('\n')
		results <- readResult{text: text, err: err}
	}()

	select {
	case r := <-results:
		return r.text, r.err
	case <-ctx.Done():
		return "", nil
	}
}
// stdioSessionWriter serializes the stdio session's own writes (sampling,
// elicitation and roots requests) with the responses and notifications that
// StdioServer.writeResponse emits, by taking the same mutex.
type stdioSessionWriter struct {
	mu *sync.Mutex
	w  io.Writer
}

// Write writes p to the underlying writer while holding the shared mutex.
func (sw *stdioSessionWriter) Write(p []byte) (int, error) {
	sw.mu.Lock()
	defer sw.mu.Unlock()
	return sw.w.Write(p)
}

// Listen serves MCP messages read line-by-line from stdin and writes responses,
// notifications and server-initiated requests to stdout.
//
// The shared stdio session is registered for the duration of the call. Listen
// returns nil when stdin reaches EOF, ctx.Err() once ctx is canceled, or the
// first read/handling error.
func (s *StdioServer) Listen(
	ctx context.Context,
	stdin io.Reader,
	stdout io.Writer,
) error {
	s.toolCallQueue = make(chan *toolCallWork, s.queueSize)

	if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
		return fmt.Errorf("register session: %w", err)
	}
	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())

	// Cancel our own goroutines (notably handleNotifications) when Listen
	// returns, even if the caller's ctx is still live, e.g. after stdin EOF.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	ctx = s.server.WithContext(ctx, &stdioSessionInstance)
	// Session-initiated requests share writeMu with writeResponse so that
	// concurrent messages never interleave on stdout.
	stdioSessionInstance.SetWriter(&stdioSessionWriter{mu: &s.writeMu, w: stdout})
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx)
	}

	reader := bufio.NewReader(stdin)
	for i := 0; i < s.workerPoolSize; i++ {
		s.workerWg.Add(1)
		go s.toolCallWorker(ctx)
	}
	go s.handleNotifications(ctx, stdout)

	err := s.processInputStream(ctx, reader, stdout)

	// Closing the queue lets workers drain already-queued tool calls before
	// exiting; the deferred cancel runs only after they have finished.
	close(s.toolCallQueue)
	s.workerWg.Wait()
	return err
}
// processMessage handles one line read from the input stream:
//   - lines made only of JSON whitespace (including the empty string that
//     readNextLine returns on cancellation) are ignored;
//   - invalid JSON is answered with a JSON-RPC parse error;
//   - replies to server-initiated sampling, elicitation and roots requests are
//     routed to their waiting callers;
//   - "tools/call" requests are queued for the worker pool, or handled inline
//     when the queue is full;
//   - everything else is handled inline.
func (s *StdioServer) processMessage(
	ctx context.Context,
	line string,
	writer io.Writer,
) error {
	// A stray blank line ("\n", "\r\n", ...) carries no message. Answering it
	// with a parse error would only put noise on the wire.
	blank := true
	for i := 0; i < len(line); i++ {
		if c := line[i]; c != ' ' && c != '\t' && c != '\r' && c != '\n' {
			blank = false
			break
		}
	}
	if blank {
		return nil
	}

	var rawMessage json.RawMessage
	if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
		response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
		if err := s.writeResponse(response, writer); err != nil {
			return fmt.Errorf("failed to write response: %w", err)
		}
		return nil
	}

	if s.handleSamplingResponse(rawMessage) {
		return nil
	}
	if s.handleElicitationResponse(rawMessage) {
		return nil
	}
	if s.handleListRootsResponse(rawMessage) {
		return nil
	}

	var baseMessage struct {
		Method string `json:"method"`
	}
	if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
		select {
		case s.toolCallQueue <- &toolCallWork{ctx: ctx, message: rawMessage, writer: writer}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		default:
			// Queue full: fall through and handle the call on this goroutine.
			s.errLogger.Printf("Tool call queue full, processing synchronously")
		}
	}

	response := s.server.HandleMessage(ctx, rawMessage)
	if response != nil {
		if err := s.writeResponse(response, writer); err != nil {
			return fmt.Errorf("failed to write response: %w", err)
		}
	}
	return nil
}
// handleSamplingResponse delegates to the shared stdio session and reports
// whether rawMessage was consumed as a reply to a pending sampling request.
func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
	consumed := stdioSessionInstance.handleSamplingResponse(rawMessage)
	return consumed
}

// handleSamplingResponse checks whether rawMessage is a JSON-RPC response
// (numeric id plus a result or an error) whose id matches a pending sampling
// request. If so, it decodes the outcome, delivers it to the waiter, and
// returns true. Otherwise it returns false so other handlers can try.
func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
	var msg struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.Number     `json:"id"`
		Result  json.RawMessage `json:"result,omitempty"`
		Error   *struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
	if json.Unmarshal(rawMessage, &msg) != nil {
		return false
	}
	id, idErr := msg.ID.Int64()
	if idErr != nil {
		return false
	}
	if msg.Result == nil && msg.Error == nil {
		return false
	}

	s.pendingMu.RLock()
	waiter, found := s.pendingRequests[id]
	s.pendingMu.RUnlock()
	if !found {
		return false
	}

	reply := &samplingResponse{}
	if msg.Error != nil {
		reply.err = fmt.Errorf("sampling request failed: %s", msg.Error.Message)
	} else {
		var result mcp.CreateMessageResult
		if err := json.Unmarshal(msg.Result, &result); err != nil {
			reply.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
		} else if rawContent, isMap := result.Content.(map[string]any); !isMap {
			reply.result = &result
		} else if content, err := mcp.ParseContent(rawContent); err != nil {
			reply.err = fmt.Errorf("failed to parse sampling response content: %w", err)
		} else {
			// Content decoded as a generic map; replace it with the typed value.
			result.Content = content
			reply.result = &result
		}
	}

	// Non-blocking send: if the waiter's slot is already filled, this reply
	// is dropped.
	select {
	case waiter <- reply:
	default:
	}
	return true
}
// handleElicitationResponse delegates to the shared stdio session and reports
// whether rawMessage was consumed as a reply to a pending elicitation request.
func (s *StdioServer) handleElicitationResponse(rawMessage json.RawMessage) bool {
	consumed := stdioSessionInstance.handleElicitationResponse(rawMessage)
	return consumed
}

// handleElicitationResponse checks whether rawMessage is a JSON-RPC response
// (numeric id plus a result or an error) whose id matches a pending
// elicitation request. If so, it decodes the outcome, delivers it to the
// waiter, and returns true. Otherwise it returns false.
func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) bool {
	var msg struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.Number     `json:"id"`
		Result  json.RawMessage `json:"result,omitempty"`
		Error   *struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
	if json.Unmarshal(rawMessage, &msg) != nil {
		return false
	}
	id, idErr := msg.ID.Int64()
	if idErr != nil {
		return false
	}
	if msg.Result == nil && msg.Error == nil {
		return false
	}

	s.pendingMu.RLock()
	waiter, found := s.pendingElicitations[id]
	s.pendingMu.RUnlock()
	if !found {
		return false
	}

	reply := &elicitationResponse{}
	if msg.Error != nil {
		reply.err = fmt.Errorf("elicitation request failed: %s", msg.Error.Message)
	} else {
		var result mcp.ElicitationResult
		if err := json.Unmarshal(msg.Result, &result); err != nil {
			reply.err = fmt.Errorf("failed to unmarshal elicitation response: %w", err)
		} else {
			reply.result = &result
		}
	}

	// Non-blocking send: if the waiter's slot is already filled, this reply
	// is dropped.
	select {
	case waiter <- reply:
	default:
	}
	return true
}
// handleListRootsResponse delegates to the shared stdio session and reports
// whether rawMessage was consumed as a reply to a pending list-roots request.
func (s *StdioServer) handleListRootsResponse(rawMessage json.RawMessage) bool {
	consumed := stdioSessionInstance.handleListRootsResponse(rawMessage)
	return consumed
}

// handleListRootsResponse checks whether rawMessage is a JSON-RPC response
// (numeric id plus a result or an error) whose id matches a pending list-roots
// request. If so, it decodes the outcome, delivers it to the waiter, and
// returns true. Otherwise it returns false.
func (s *stdioSession) handleListRootsResponse(rawMessage json.RawMessage) bool {
	var msg struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      json.Number     `json:"id"`
		Result  json.RawMessage `json:"result,omitempty"`
		Error   *struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
	if json.Unmarshal(rawMessage, &msg) != nil {
		return false
	}
	id, idErr := msg.ID.Int64()
	if idErr != nil {
		return false
	}
	if msg.Result == nil && msg.Error == nil {
		return false
	}

	s.pendingMu.RLock()
	waiter, found := s.pendingRoots[id]
	s.pendingMu.RUnlock()
	if !found {
		return false
	}

	reply := &rootsResponse{}
	if msg.Error != nil {
		reply.err = fmt.Errorf("list root request failed: %s", msg.Error.Message)
	} else {
		var result mcp.ListRootsResult
		if err := json.Unmarshal(msg.Result, &result); err != nil {
			reply.err = fmt.Errorf("failed to unmarshal list root response: %w", err)
		} else {
			reply.result = &result
		}
	}

	// Non-blocking send: if the waiter's slot is already filled, this reply
	// is dropped.
	select {
	case waiter <- reply:
	default:
	}
	return true
}
// writeResponse marshals response to JSON and writes it followed by '\n' to
// writer in a single Write call, holding writeMu so concurrent writers
// cannot interleave their output.
func (s *StdioServer) writeResponse(
	response mcp.JSONRPCMessage,
	writer io.Writer,
) error {
	payload, err := json.Marshal(response)
	if err != nil {
		return err
	}
	line := append(payload, '\n')

	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	_, err = writer.Write(line)
	return err
}
// ServeStdio builds a StdioServer for server, applies opts in order, and
// serves it on os.Stdin/os.Stdout. It returns when stdin reaches EOF, an
// error occurs, or SIGINT/SIGTERM cancels the serving context.
func ServeStdio(server *MCPServer, opts ...StdioOption) error {
	s := NewStdioServer(server)
	for _, opt := range opts {
		opt(s)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
	// Undo the Notify registration once serving ends.
	defer signal.Stop(sigChan)

	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
			// Listen returned and the deferred cancel ran: exit instead of
			// blocking on sigChan forever.
		}
	}()

	return s.Listen(ctx, os.Stdin, os.Stdout)
}
The pages are generated with Golds v0.8.4 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.