package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"maps"
"mime"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/util"
)
// StreamableHTTPOption configures a StreamableHTTPServer. Options are applied
// in order by NewStreamableHTTPServer, so a later option overrides an earlier
// one that sets the same field.
type StreamableHTTPOption func(*StreamableHTTPServer)

// WithEndpointPath sets the path Start mounts the handler on. Surrounding
// slashes are trimmed and exactly one leading slash is re-added, so "mcp",
// "/mcp" and "/mcp/" all yield "/mcp".
func WithEndpointPath(endpointPath string) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		srv.endpointPath = "/" + strings.Trim(endpointPath, "/")
	}
}

// WithStateLess installs a StatelessSessionIdManager when stateLess is true.
// Passing false leaves the currently configured resolver untouched.
func WithStateLess(stateLess bool) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		if !stateLess {
			return
		}
		srv.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
	}
}

// WithSessionIdManager wraps manager in a DefaultSessionIdManagerResolver.
// A nil manager falls back to a StatelessSessionIdManager.
func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		var chosen SessionIdManager = &StatelessSessionIdManager{}
		if manager != nil {
			chosen = manager
		}
		srv.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(chosen)
	}
}

// WithSessionIdManagerResolver installs a resolver that may choose a
// SessionIdManager per request. A nil resolver falls back to a default
// resolver wrapping a StatelessSessionIdManager.
func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		if resolver != nil {
			srv.sessionIdManagerResolver = resolver
			return
		}
		srv.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
	}
}

// WithStateful installs an InsecureStatefulSessionIdManager when stateful is
// true. Passing false leaves the currently configured resolver untouched.
func WithStateful(stateful bool) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		if !stateful {
			return
		}
		srv.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
	}
}
// WithHeartbeatInterval sets how often handleGet pushes a ping request down
// each listening GET stream. A zero or negative interval disables heartbeats.
func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) { srv.listenHeartbeatInterval = interval }
}

// WithDisableStreaming makes GET requests fail with 405 when disable is true.
func WithDisableStreaming(disable bool) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) { srv.disableStreaming = disable }
}

// WithHTTPContextFunc registers fn to derive the request context for each
// POSTed message from the incoming *http.Request.
func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) { srv.contextFunc = fn }
}

// WithStreamableHTTPServer makes Start use srv instead of building its own
// http.Server. srv.Addr must be empty or equal to the address given to Start.
func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
	return func(target *StreamableHTTPServer) { target.httpServer = srv }
}

// WithLogger replaces the default logger.
func WithLogger(logger util.Logger) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) { srv.logger = logger }
}

// WithTLSCert makes Start serve HTTPS using the given certificate and key
// files. Start rejects a configuration where only one of the two is set.
func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) {
		srv.tlsCertFile, srv.tlsKeyFile = certFile, keyFile
	}
}

// WithSessionIdleTTL enables a background sweeper that expires sessions idle
// for at least ttl. A zero or negative ttl leaves the sweeper off.
func WithSessionIdleTTL(ttl time.Duration) StreamableHTTPOption {
	return func(srv *StreamableHTTPServer) { srv.sessionIdleTTL = ttl }
}
// StreamableHTTPServer serves an MCPServer over the MCP Streamable HTTP
// transport. POST carries client messages, GET opens a server-to-client SSE
// stream, and DELETE terminates a session (see ServeHTTP). Build one with
// NewStreamableHTTPServer.
type StreamableHTTPServer struct {
	// server receives every POSTed JSON-RPC message via HandleMessage.
	server *MCPServer
	// Per-session stores, keyed by session ID and shared with every session.
	sessionTools             *sessionToolsStore
	sessionResources         *sessionResourcesStore
	sessionResourceTemplates *sessionResourceTemplatesStore
	// sessionRequestIDs maps session ID -> *atomic.Int64 used by
	// nextRequestID (heartbeat ping IDs).
	sessionRequestIDs sync.Map
	// activeSessions maps session ID -> *streamableHttpSession tracked by
	// this transport.
	activeSessions sync.Map
	// httpServer is the server used by Start/Shutdown; guarded by mu.
	httpServer *http.Server
	mu         sync.RWMutex
	// endpointPath is the path Start mounts the handler on.
	endpointPath string
	// contextFunc, when non-nil, derives the per-POST request context.
	contextFunc              HTTPContextFunc
	sessionIdManagerResolver SessionIdManagerResolver
	// sessionIdManager caches the manager of a *DefaultSessionIdManagerResolver
	// (set by NewStreamableHTTPServer) for the idle sweeper; nil otherwise.
	sessionIdManager SessionIdManager
	// listenHeartbeatInterval enables GET-stream pings when > 0.
	listenHeartbeatInterval time.Duration
	logger                  util.Logger
	sessionLogLevels        *sessionLogLevelsStore
	// disableStreaming makes GET requests fail with 405.
	disableStreaming bool
	// TLS material; Start uses ListenAndServeTLS when either is set.
	tlsCertFile string
	tlsKeyFile  string
	// sessionIdleTTL enables the idle-session sweeper when > 0.
	sessionIdleTTL time.Duration
	// sessionLastActive maps session ID -> *atomic.Int64 holding the last
	// activity time in Unix nanoseconds.
	sessionLastActive sync.Map
	// sweeperCancel stops the sweeper goroutine; nil if none was started.
	sweeperCancel context.CancelFunc
}
// NewStreamableHTTPServer builds a Streamable HTTP transport for server.
//
// Defaults are the "/mcp" endpoint, a StatelessGeneratingSessionIdManager,
// the util default logger and empty per-session stores. opts are then applied
// in order.
//
// When the resulting resolver is a *DefaultSessionIdManagerResolver, its
// manager is cached for the idle-session sweeper. The sweeper goroutine is
// started here only if a positive idle TTL was configured; Shutdown stops it.
func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
	srv := &StreamableHTTPServer{
		server:                   server,
		endpointPath:             "/mcp",
		logger:                   util.DefaultLogger(),
		sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}),
		sessionTools:             newSessionToolsStore(),
		sessionResources:         newSessionResourcesStore(),
		sessionResourceTemplates: newSessionResourceTemplatesStore(),
		sessionLogLevels:         newSessionLogLevelsStore(),
	}
	for i := range opts {
		opts[i](srv)
	}
	if defaultResolver, isDefault := srv.sessionIdManagerResolver.(*DefaultSessionIdManagerResolver); isDefault {
		srv.sessionIdManager = defaultResolver.manager
	}
	if srv.sessionIdleTTL <= 0 {
		return srv
	}
	sweepCtx, stop := context.WithCancel(context.Background())
	srv.sweeperCancel = stop
	srv.startSessionSweeper(sweepCtx)
	return srv
}
// ServeHTTP dispatches on the HTTP method:
//   - POST delivers a client JSON-RPC message;
//   - GET opens the server-to-client SSE stream;
//   - DELETE terminates a session.
//
// Any other method is answered by http.NotFound.
func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var handle func(http.ResponseWriter, *http.Request)
	switch r.Method {
	case http.MethodPost:
		handle = s.handlePost
	case http.MethodGet:
		handle = s.handleGet
	case http.MethodDelete:
		handle = s.handleDelete
	default:
		handle = http.NotFound
	}
	handle(w, r)
}
// Start listens on addr and serves the streamable HTTP endpoint. Like
// http.Server.ListenAndServe, it blocks until the server stops.
//
// Server selection:
//   - If no http.Server was supplied via WithStreamableHTTPServer, one is
//     created whose mux routes s.endpointPath to s.
//   - A supplied server with an empty Addr adopts addr.
//   - A supplied server with a different non-empty Addr is an error.
//
// When a TLS certificate or key is configured, both must be set and must
// exist on disk, and the server is started with ListenAndServeTLS.
func (s *StreamableHTTPServer) Start(addr string) error {
	s.mu.Lock()
	if s.httpServer == nil {
		mux := http.NewServeMux()
		mux.Handle(s.endpointPath, s)
		s.httpServer = &http.Server{
			Addr:    addr,
			Handler: mux,
		}
	} else if s.httpServer.Addr == "" {
		s.httpServer.Addr = addr
	} else if s.httpServer.Addr != addr {
		existing := s.httpServer.Addr
		// Unlock before returning. Leaving s.mu held here would block every
		// later Start and Shutdown forever.
		s.mu.Unlock()
		return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", existing, addr)
	}
	srv := s.httpServer
	s.mu.Unlock()

	if s.tlsCertFile == "" && s.tlsKeyFile == "" {
		return srv.ListenAndServe()
	}
	if s.tlsCertFile == "" || s.tlsKeyFile == "" {
		return fmt.Errorf("both TLS cert and key must be provided")
	}
	if _, err := os.Stat(s.tlsCertFile); err != nil {
		return fmt.Errorf("failed to find TLS certificate file: %w", err)
	}
	if _, err := os.Stat(s.tlsKeyFile); err != nil {
		return fmt.Errorf("failed to find TLS key file: %w", err)
	}
	return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile)
}
// Shutdown stops the idle-session sweeper, if one was started. It then
// gracefully shuts down the http.Server used by Start, bounded by ctx.
// It returns nil when no http.Server has been set.
func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
	if stop := s.sweeperCancel; stop != nil {
		stop()
	}
	s.mu.RLock()
	srv := s.httpServer
	s.mu.RUnlock()
	if srv == nil {
		return nil
	}
	return srv.Shutdown(ctx)
}
// handlePost processes one JSON-RPC message POSTed by the client.
//
// Steps:
//  1. Reject bodies whose Content-Type media type is not application/json.
//  2. Classify the message. Ping replies get 202 and empty responses are
//     ignored. Responses to server-initiated requests go to
//     handleSamplingResponse. Requests and notifications continue below.
//  3. Resolve the session. initialize generates a new ID; any other message
//     must carry a session ID header that validates and is not terminated.
//  4. Run the message through MCPServer.HandleMessage. Meanwhile a goroutine
//     forwards notifications as SSE events, upgrading the response to
//     text/event-stream on the first one.
//  5. Write the result as JSON, or as a final SSE event if upgraded. After a
//     successful initialize, register the session.
func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
	// mime.ParseMediaType lets parameters such as charset through.
	contentType := r.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil || mediaType != "application/json" {
		http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
		return
	}
	rawData, err := io.ReadAll(r.Body)
	if err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
		return
	}

	// Decode only the envelope fields needed to classify the message. The
	// full body is handed to HandleMessage below.
	var jsonMessage struct {
		ID     json.RawMessage `json:"id"`
		Result json.RawMessage `json:"result,omitempty"`
		Error  json.RawMessage `json:"error,omitempty"`
		Method mcp.MCPMethod   `json:"method,omitempty"`
	}
	if err := json.Unmarshal(rawData, &jsonMessage); err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
		return
	}

	isEmptyResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
		(isJSONEmpty(jsonMessage.Result) && isJSONEmpty(jsonMessage.Error))
	isPingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
		isExplicitEmptyObject(jsonMessage.Result) && len(bytes.TrimSpace(jsonMessage.Error)) == 0
	// A response whose result isExplicitEmptyObject and which has no error is
	// treated as a ping reply: acknowledge it and stop.
	if isPingResponse {
		w.WriteHeader(http.StatusAccepted)
		return
	}
	if isEmptyResponse {
		return
	}

	isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
		(jsonMessage.Result != nil || jsonMessage.Error != nil)
	isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize
	if isSamplingResponse {
		// handleSamplingResponse has already written an HTTP error whenever it
		// returns a non-nil error, so only log here. Another http.Error would
		// be a superfluous WriteHeader and would append a second body.
		if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil {
			s.logger.Errorf("Failed to handle sampling response: %v", err)
		}
		return
	}

	var sessionID string
	sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
	if isInitializeRequest {
		sessionID = sessionIdManager.Generate()
	} else {
		sessionID = r.Header.Get(HeaderKeySessionID)
		isTerminated, err := sessionIdManager.Validate(sessionID)
		if err != nil {
			http.Error(w, "Invalid session ID", http.StatusNotFound)
			return
		}
		if isTerminated {
			http.Error(w, "Session terminated", http.StatusNotFound)
			return
		}
	}
	s.touchSession(sessionID)

	// Prefer the session registered with the MCPServer, then one tracked by
	// this transport. Create a fresh one only as a last resort.
	var session *streamableHttpSession
	if !isInitializeRequest {
		if sessionValue, ok := s.server.sessions.Load(sessionID); ok {
			if existingSession, ok := sessionValue.(*streamableHttpSession); ok {
				session = existingSession
			}
		}
	}
	if session == nil {
		if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
			if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
				session = persistentSession
			}
		}
	}
	if session == nil {
		session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels)
	}

	ctx := s.server.WithContext(r.Context(), session)
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx, r)
	}

	if isInitializeRequest && sessionID != "" {
		// Set this before the notification goroutine can commit headers. The
		// client then receives its session ID whether the reply is plain JSON
		// or an upgraded SSE stream.
		w.Header().Set(HeaderKeySessionID, sessionID)
	}

	// mu serializes writes to w and access to upgradedHeader between this
	// handler and the notification goroutine. Closing done stops the goroutine.
	mu := sync.Mutex{}
	upgradedHeader := false
	done := make(chan struct{})
	ctx = context.WithValue(ctx, requestHeader, r.Header)
	go func() {
		for {
			select {
			case nt := <-session.notificationChannel:
				func() {
					mu.Lock()
					defer mu.Unlock()
					// The handler may have finished while we waited for mu.
					select {
					case <-done:
						return
					default:
					}
					defer func() {
						flusher, ok := w.(http.Flusher)
						if ok {
							flusher.Flush()
						}
					}()
					if !upgradedHeader {
						w.Header().Set("Content-Type", "text/event-stream")
						w.Header().Set("Connection", "keep-alive")
						w.Header().Set("Cache-Control", "no-cache")
						w.WriteHeader(http.StatusOK)
						upgradedHeader = true
					}
					err := writeSSEEvent(w, nt)
					if err != nil {
						s.logger.Errorf("Failed to write SSE event: %v", err)
						return
					}
				}()
			case <-done:
				return
			case <-ctx.Done():
				return
			}
		}
	}()

	response := s.server.HandleMessage(ctx, rawData)
	if response == nil {
		// Notification: no result to send. Acknowledge with 202 unless an SSE
		// stream has already been started.
		mu.Lock()
		close(done)
		upgraded := upgradedHeader
		mu.Unlock()
		if !upgraded {
			w.WriteHeader(http.StatusAccepted)
		}
		return
	}

	// Flush notifications that are still queued so they precede the final
	// response on the wire.
	mu.Lock()
drainLoop:
	for {
		select {
		case nt := <-session.notificationChannel:
			if !upgradedHeader {
				w.Header().Set("Content-Type", "text/event-stream")
				w.Header().Set("Connection", "keep-alive")
				w.Header().Set("Cache-Control", "no-cache")
				w.WriteHeader(http.StatusOK)
				upgradedHeader = true
			}
			if err := writeSSEEvent(w, nt); err != nil {
				s.logger.Errorf("Failed to write SSE event during drain: %v", err)
			}
			if flusher, ok := w.(http.Flusher); ok {
				flusher.Flush()
			}
		default:
			break drainLoop
		}
	}
	close(done)
	mu.Unlock()

	if ctx.Err() != nil {
		return
	}
	if session.upgradeToSSE.Load() || upgradedHeader {
		if !upgradedHeader {
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Connection", "keep-alive")
			w.Header().Set("Cache-Control", "no-cache")
			w.WriteHeader(http.StatusOK)
			upgradedHeader = true
		}
		if err := writeSSEEvent(w, response); err != nil {
			s.logger.Errorf("Failed to write final SSE response event: %v", err)
		}
	} else {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		err := json.NewEncoder(w).Encode(response)
		if err != nil {
			s.logger.Errorf("Failed to write response: %v", err)
		}
	}

	if isInitializeRequest && sessionID != "" {
		if _, exists := s.server.sessions.Load(sessionID); !exists {
			s.activeSessions.Store(sessionID, session)
			if err := s.server.RegisterSession(ctx, session); err != nil {
				s.logger.Errorf("Failed to register POST session: %v", err)
				s.activeSessions.Delete(sessionID)
			}
		}
	}
}
// handleGet serves the long-lived "listening" SSE stream. On it the server
// pushes session notifications, sampling/elicitation/roots requests and,
// when a heartbeat interval is configured, periodic ping requests.
//
// The session ID comes from the request header. When absent, the resolved
// SessionIdManager generates one. If no session with that ID is tracked yet,
// a new one is created and registered with the MCPServer for the lifetime of
// this request, then removed when the stream ends. A session that already
// exists in activeSessions is reused and left registered.
//
// NOTE(review): unlike handlePost, a client-supplied session ID is not
// checked with SessionIdManager.Validate here.
// NOTE(review): heartbeat ping IDs come from nextRequestID, while
// sampling/elicitation/roots IDs come from session.requestIDCounter. The two
// sequences can therefore emit the same JSON-RPC ID on one stream.
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
	if s.disableStreaming {
		s.logger.Infof("Rejected GET request: streaming is disabled (session: %s)", r.Header.Get(HeaderKeySessionID))
		http.Error(w, "Streaming is disabled on this server", http.StatusMethodNotAllowed)
		return
	}
	// SSE needs an explicit flush after every event.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusMethodNotAllowed)
		return
	}
	sessionID := r.Header.Get(HeaderKeySessionID)
	if sessionID == "" {
		sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
		sessionID = sessionIdManager.Generate()
	}
	var session *streamableHttpSession
	newSession := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels)
	// LoadOrStore lets concurrent requests for the same ID share one session.
	// Only the request that stored it owns registration and cleanup.
	actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession)
	session = actual.(*streamableHttpSession)
	if !loaded {
		if err := s.server.RegisterSession(r.Context(), session); err != nil {
			s.activeSessions.Delete(sessionID)
			http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
			return
		}
		// Deferred calls run LIFO: the session leaves activeSessions first,
		// then it is unregistered from the MCPServer.
		defer s.server.UnregisterSession(r.Context(), sessionID)
		defer s.activeSessions.Delete(sessionID)
	}
	s.touchSession(sessionID)
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.WriteHeader(http.StatusOK)
	flusher.Flush()
	// done is closed when this handler returns, which stops the producer
	// goroutines below. All outgoing messages funnel through writeChan, so
	// only the loop at the bottom of this function writes to w.
	done := make(chan struct{})
	defer close(done)
	writeChan := make(chan any, 16)
	// Producer: forward notifications as-is. Wrap each queued
	// server-to-client request in a JSON-RPC request envelope that carries
	// its pending request ID.
	go func() {
		for {
			select {
			case nt := <-session.notificationChannel:
				select {
				case writeChan <- &nt:
				case <-done:
					return
				}
			case samplingReq := <-session.samplingRequestChan:
				jsonrpcRequest := mcp.JSONRPCRequest{
					JSONRPC: "2.0",
					ID:      mcp.NewRequestId(samplingReq.requestID),
					Request: mcp.Request{
						Method: string(mcp.MethodSamplingCreateMessage),
					},
					Params: samplingReq.request.CreateMessageParams,
				}
				select {
				case writeChan <- jsonrpcRequest:
				case <-done:
					return
				}
			case elicitationReq := <-session.elicitationRequestChan:
				jsonrpcRequest := mcp.JSONRPCRequest{
					JSONRPC: "2.0",
					ID:      mcp.NewRequestId(elicitationReq.requestID),
					Request: mcp.Request{
						Method: string(mcp.MethodElicitationCreate),
					},
					Params: elicitationReq.request.Params,
				}
				select {
				case writeChan <- jsonrpcRequest:
				case <-done:
					return
				}
			case rootsReq := <-session.rootsRequestChan:
				jsonrpcRequest := mcp.JSONRPCRequest{
					JSONRPC: "2.0",
					ID:      mcp.NewRequestId(rootsReq.requestID),
					Request: mcp.Request{
						Method: string(mcp.MethodListRoots),
					},
				}
				select {
				case writeChan <- jsonrpcRequest:
				case <-done:
					return
				}
			case <-done:
				return
			}
		}
	}()
	// Optional heartbeat: enqueue a "ping" request once per interval.
	if s.listenHeartbeatInterval > 0 {
		go func() {
			ticker := time.NewTicker(s.listenHeartbeatInterval)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					message := mcp.JSONRPCRequest{
						JSONRPC: "2.0",
						ID:      mcp.NewRequestId(s.nextRequestID(sessionID)),
						Request: mcp.Request{
							Method: "ping",
						},
					}
					select {
					case writeChan <- message:
					case <-done:
						return
					}
				case <-done:
					return
				}
			}
		}()
	}
	// Single writer loop. It exits when the client disconnects or a write
	// fails. Each successful write counts as session activity.
	for {
		select {
		case data := <-writeChan:
			if data == nil {
				continue
			}
			if err := writeSSEEvent(w, data); err != nil {
				s.logger.Errorf("Failed to write SSE event: %v", err)
				return
			}
			flusher.Flush()
			s.touchSession(sessionID)
		case <-r.Context().Done():
			return
		}
	}
}
// handleDelete asks the resolved SessionIdManager to terminate the session
// named in the session ID header.
//
// Responses:
//   - Terminate error: 500;
//   - termination refused: 405;
//   - success: all per-session state is dropped, then 200.
func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
	sessionID := r.Header.Get(HeaderKeySessionID)
	manager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
	notAllowed, err := manager.Terminate(sessionID)
	switch {
	case err != nil:
		http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
	case notAllowed:
		http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
	default:
		s.cleanupSessionState(r.Context(), sessionID)
		w.WriteHeader(http.StatusOK)
	}
}
// writeSSEEvent JSON-encodes data and writes it to w as one Server-Sent Event
// of type "message", terminated by a blank line. Nothing is written if
// marshaling fails.
func writeSSEEvent(w io.Writer, data any) error {
	payload, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal data: %w", marshalErr)
	}
	if _, writeErr := fmt.Fprintf(w, "event: message\ndata: %s\n\n", payload); writeErr != nil {
		return fmt.Errorf("failed to write SSE event: %w", writeErr)
	}
	return nil
}
// handleSamplingResponse routes a client's reply to a server-initiated
// request (sampling, elicitation or roots) to the goroutine waiting for it.
//
// The session ID header must be present, must validate, and must not be
// terminated. The JSON-RPC id must be an integer. Every error return is
// preceded by an HTTP error response, so callers must not write again. On
// success it replies 202 Accepted.
//
// A JSON `null` "error" member is treated as absent. json.RawMessage keeps
// the literal bytes `null`, which would otherwise turn a successful reply
// into a bogus "sampling error 0".
func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct {
	ID     json.RawMessage `json:"id"`
	Result json.RawMessage `json:"result,omitempty"`
	Error  json.RawMessage `json:"error,omitempty"`
	Method mcp.MCPMethod   `json:"method,omitempty"`
}) error {
	sessionID := r.Header.Get(HeaderKeySessionID)
	if sessionID == "" {
		http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest)
		return fmt.Errorf("missing session ID")
	}
	sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
	isTerminated, err := sessionIdManager.Validate(sessionID)
	if err != nil {
		http.Error(w, "Invalid session ID", http.StatusNotFound)
		return err
	}
	if isTerminated {
		http.Error(w, "Session terminated", http.StatusNotFound)
		return fmt.Errorf("session terminated")
	}
	var requestID int64
	if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil {
		http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest)
		return err
	}

	response := samplingResponseItem{
		requestID: requestID,
	}
	errorPayload := bytes.TrimSpace(responseMessage.Error)
	hasError := len(errorPayload) > 0 && !bytes.Equal(errorPayload, []byte("null"))
	switch {
	case hasError:
		var jsonrpcError struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
		}
		if err := json.Unmarshal(errorPayload, &jsonrpcError); err != nil {
			response.err = fmt.Errorf("failed to parse error: %w", err)
		} else {
			response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message)
		}
	case responseMessage.Result != nil:
		response.result = responseMessage.Result
	default:
		response.err = fmt.Errorf("sampling response has neither result nor error")
	}

	if err := s.deliverSamplingResponse(sessionID, response); err != nil {
		s.logger.Errorf("Failed to deliver sampling response: %v", err)
		http.Error(w, "Failed to deliver response", http.StatusInternalServerError)
		return err
	}
	w.WriteHeader(http.StatusAccepted)
	return nil
}
// deliverSamplingResponse hands response to the caller waiting on
// response.requestID in the active session sessionID. That caller is one of
// RequestSampling, RequestElicitation or ListRoots.
//
// Delivery never blocks. An error is returned if the session or the pending
// request is unknown, if either stored value has an unexpected type, or if
// the waiter's channel cannot take the item immediately.
func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error {
	raw, found := s.activeSessions.Load(sessionID)
	if !found {
		return fmt.Errorf("no active session found for session %s", sessionID)
	}
	session, isHTTPSession := raw.(*streamableHttpSession)
	if !isHTTPSession {
		return fmt.Errorf("invalid session type for session %s", sessionID)
	}
	pending, found := session.samplingRequests.Load(response.requestID)
	if !found {
		return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID)
	}
	waiter, isChan := pending.(chan samplingResponseItem)
	if !isChan {
		return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID)
	}
	select {
	case waiter <- response:
	default:
		return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID)
	}
	s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID)
	return nil
}
// writeJSONRPCError replies with HTTP 400 and a JSON body holding the
// JSON-RPC error built by createErrorResponse from id, code and message.
// An encoding failure is only logged, since the status is already sent.
func (s *StreamableHTTPServer) writeJSONRPCError(
	w http.ResponseWriter,
	id any,
	code int,
	message string,
) {
	payload := createErrorResponse(id, code, message)
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	if encodeErr := json.NewEncoder(w).Encode(payload); encodeErr != nil {
		s.logger.Errorf("Failed to write JSONRPCError: %v", encodeErr)
	}
}
// nextRequestID returns the next value of a per-session counter that starts
// at 1. The counter is created on first use; atomic increments keep it safe
// for concurrent callers.
func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
	stored, _ := s.sessionRequestIDs.LoadOrStore(sessionID, &atomic.Int64{})
	return stored.(*atomic.Int64).Add(1)
}
// touchSession stamps the current time (Unix nanoseconds) as the last
// activity of sessionID, which the idle sweeper reads. It does nothing for
// an empty ID or when no positive idle TTL is configured.
func (s *StreamableHTTPServer) touchSession(sessionID string) {
	if s.sessionIdleTTL > 0 && sessionID != "" {
		stampNow := time.Now().UnixNano()
		stamp, _ := s.sessionLastActive.LoadOrStore(sessionID, &atomic.Int64{})
		stamp.(*atomic.Int64).Store(stampNow)
	}
}
// cleanupSessionState unregisters sessionID from the MCPServer. It then
// drops every piece of per-session state this transport keeps:
//   - the active session;
//   - tools, resources, resource templates and log level;
//   - the request-ID counter and the last-activity stamp.
func (s *StreamableHTTPServer) cleanupSessionState(ctx context.Context, sessionID string) {
	s.server.UnregisterSession(ctx, sessionID)
	s.activeSessions.Delete(sessionID)
	stores := []interface{ delete(string) }{
		s.sessionTools,
		s.sessionResources,
		s.sessionResourceTemplates,
		s.sessionLogLevels,
	}
	for _, store := range stores {
		store.delete(sessionID)
	}
	for _, m := range []*sync.Map{&s.sessionRequestIDs, &s.sessionLastActive} {
		m.Delete(sessionID)
	}
}
// startSessionSweeper launches a goroutine that runs sweepExpiredSessions
// every half idle-TTL, but never more often than once per second. The
// goroutine exits when ctx is cancelled.
func (s *StreamableHTTPServer) startSessionSweeper(ctx context.Context) {
	period := s.sessionIdleTTL / 2
	if period < time.Second {
		period = time.Second
	}
	go func() {
		ticker := time.NewTicker(period)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				s.sweepExpiredSessions()
			case <-ctx.Done():
				return
			}
		}
	}()
}
// sweepExpiredSessions terminates and cleans up every session whose last
// recorded activity is at least sessionIdleTTL old. Malformed entries in
// sessionLastActive are dropped.
//
// Termination uses the cached sessionIdManager. If none is cached, the
// resolver is asked with a nil *http.Request. Termination results are
// ignored; local state is cleaned up regardless.
func (s *StreamableHTTPServer) sweepExpiredSessions() {
	// A stamp later than cutoff means the session was active within the TTL.
	cutoff := time.Now().UnixNano() - s.sessionIdleTTL.Nanoseconds()
	s.sessionLastActive.Range(func(key, value any) bool {
		sessionID, keyOK := key.(string)
		stamp, valueOK := value.(*atomic.Int64)
		if !keyOK || !valueOK {
			s.sessionLastActive.Delete(key)
			return true
		}
		seen := stamp.Load()
		// The re-read narrows, but does not close, the window in which a
		// concurrent touchSession could revive the session.
		if seen > cutoff || stamp.Load() != seen {
			return true
		}
		s.logger.Infof("Sweeping expired session: %s", sessionID)
		manager := s.sessionIdManager
		if manager == nil {
			manager = s.sessionIdManagerResolver.ResolveSessionIdManager(nil)
		}
		_, _ = manager.Terminate(sessionID) // best effort; state is dropped below anyway
		s.cleanupSessionState(context.Background(), sessionID)
		return true
	})
}
// sessionLogLevelsStore maps session IDs to the logging level each session
// selected. Every method takes mu, so the store may be shared across
// goroutines.
type sessionLogLevelsStore struct {
	mu   sync.RWMutex
	logs map[string]mcp.LoggingLevel
}

// newSessionLogLevelsStore returns an empty store ready for use.
func newSessionLogLevelsStore() *sessionLogLevelsStore {
	store := new(sessionLogLevelsStore)
	store.logs = map[string]mcp.LoggingLevel{}
	return store
}

// get returns the level stored for sessionID, or mcp.LoggingLevelError if
// none has been set.
func (st *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel {
	st.mu.RLock()
	level, found := st.logs[sessionID]
	st.mu.RUnlock()
	if !found {
		return mcp.LoggingLevelError
	}
	return level
}

// set records level for sessionID, replacing any previous value.
func (st *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) {
	st.mu.Lock()
	st.logs[sessionID] = level
	st.mu.Unlock()
}

// delete forgets the level stored for sessionID, if any.
func (st *sessionLogLevelsStore) delete(sessionID string) {
	st.mu.Lock()
	delete(st.logs, sessionID)
	st.mu.Unlock()
}
// sessionResourcesStore keeps each session's resources, keyed by session ID
// and then by resource key. Maps are copied on the way in and out, so callers
// never share storage with the store. Every method takes mu.
type sessionResourcesStore struct {
	mu        sync.RWMutex
	resources map[string]map[string]ServerResource
}

// newSessionResourcesStore returns an empty store ready for use.
func newSessionResourcesStore() *sessionResourcesStore {
	return &sessionResourcesStore{resources: map[string]map[string]ServerResource{}}
}

// get returns a copy of sessionID's resources. The copy is an empty, non-nil
// map if the session has none.
func (st *sessionResourcesStore) get(sessionID string) map[string]ServerResource {
	st.mu.RLock()
	defer st.mu.RUnlock()
	current := st.resources[sessionID]
	snapshot := make(map[string]ServerResource, len(current))
	maps.Copy(snapshot, current)
	return snapshot
}

// set replaces sessionID's resources with a copy of resources.
func (st *sessionResourcesStore) set(sessionID string, resources map[string]ServerResource) {
	st.mu.Lock()
	defer st.mu.Unlock()
	snapshot := make(map[string]ServerResource, len(resources))
	maps.Copy(snapshot, resources)
	st.resources[sessionID] = snapshot
}

// delete drops all resources stored for sessionID.
func (st *sessionResourcesStore) delete(sessionID string) {
	st.mu.Lock()
	defer st.mu.Unlock()
	delete(st.resources, sessionID)
}
// sessionResourceTemplatesStore keeps each session's resource templates,
// keyed by session ID and then by template key. Maps are copied on the way
// in and out, so callers never share storage with the store. Every method
// takes mu.
type sessionResourceTemplatesStore struct {
	mu        sync.RWMutex
	templates map[string]map[string]ServerResourceTemplate
}

// newSessionResourceTemplatesStore returns an empty store ready for use.
func newSessionResourceTemplatesStore() *sessionResourceTemplatesStore {
	return &sessionResourceTemplatesStore{templates: map[string]map[string]ServerResourceTemplate{}}
}

// get returns a copy of sessionID's templates. The copy is an empty, non-nil
// map if the session has none.
func (st *sessionResourceTemplatesStore) get(sessionID string) map[string]ServerResourceTemplate {
	st.mu.RLock()
	defer st.mu.RUnlock()
	current := st.templates[sessionID]
	snapshot := make(map[string]ServerResourceTemplate, len(current))
	maps.Copy(snapshot, current)
	return snapshot
}

// set replaces sessionID's templates with a copy of templates.
func (st *sessionResourceTemplatesStore) set(sessionID string, templates map[string]ServerResourceTemplate) {
	st.mu.Lock()
	defer st.mu.Unlock()
	snapshot := make(map[string]ServerResourceTemplate, len(templates))
	maps.Copy(snapshot, templates)
	st.templates[sessionID] = snapshot
}

// delete drops all templates stored for sessionID.
func (st *sessionResourceTemplatesStore) delete(sessionID string) {
	st.mu.Lock()
	defer st.mu.Unlock()
	delete(st.templates, sessionID)
}
// sessionToolsStore keeps each session's tools, keyed by session ID and then
// by tool key. Maps are copied on the way in and out, so callers never share
// storage with the store. Every method takes mu.
type sessionToolsStore struct {
	mu    sync.RWMutex
	tools map[string]map[string]ServerTool
}

// newSessionToolsStore returns an empty store ready for use.
func newSessionToolsStore() *sessionToolsStore {
	return &sessionToolsStore{tools: map[string]map[string]ServerTool{}}
}

// get returns a copy of sessionID's tools. The copy is an empty, non-nil map
// if the session has none.
func (st *sessionToolsStore) get(sessionID string) map[string]ServerTool {
	st.mu.RLock()
	defer st.mu.RUnlock()
	current := st.tools[sessionID]
	snapshot := make(map[string]ServerTool, len(current))
	maps.Copy(snapshot, current)
	return snapshot
}

// set replaces sessionID's tools with a copy of tools.
func (st *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
	st.mu.Lock()
	defer st.mu.Unlock()
	snapshot := make(map[string]ServerTool, len(tools))
	maps.Copy(snapshot, tools)
	st.tools[sessionID] = snapshot
}

// delete drops all tools stored for sessionID.
func (st *sessionToolsStore) delete(sessionID string) {
	st.mu.Lock()
	defer st.mu.Unlock()
	delete(st.tools, sessionID)
}
// Queued server-to-client requests and the replies that resolve them. Each
// request item carries the channel its caller waits on. All three request
// kinds use samplingResponseItem for the reply.
type (
	// samplingRequestItem is a pending sampling (create message) request.
	samplingRequestItem struct {
		requestID int64
		request   mcp.CreateMessageRequest
		response  chan samplingResponseItem
	}

	// samplingResponseItem is a client's reply: a raw JSON result or an
	// error, tagged with the ID of the request it answers.
	samplingResponseItem struct {
		requestID int64
		result    json.RawMessage
		err       error
	}

	// elicitationRequestItem is a pending elicitation request.
	elicitationRequestItem struct {
		requestID int64
		request   mcp.ElicitationRequest
		response  chan samplingResponseItem
	}

	// rootsRequestItem is a pending list-roots request.
	rootsRequestItem struct {
		requestID int64
		request   mcp.ListRootsRequest
		response  chan samplingResponseItem
	}
)
// streamableHttpSession is the ClientSession used by the streamable HTTP
// transport.
//
// Tools, resources, templates and the log level live in shared stores keyed
// by sessionID. Server-initiated requests (sampling, elicitation, roots) are
// queued on buffered channels. Their waiting response channels are tracked
// in samplingRequests by request ID.
type streamableHttpSession struct {
	sessionID           string
	notificationChannel chan mcp.JSONRPCNotification // buffered; drained by POST/GET handlers
	tools               *sessionToolsStore
	resources           *sessionResourcesStore
	resourceTemplates   *sessionResourceTemplatesStore
	upgradeToSSE        atomic.Bool // when set, POST replies are sent as SSE
	logLevels           *sessionLogLevelsStore
	clientInfo          atomic.Value // holds mcp.Implementation once set
	clientCapabilities  atomic.Value // holds mcp.ClientCapabilities once set
	// Outbound request queues, consumed by the GET stream.
	samplingRequestChan    chan samplingRequestItem
	elicitationRequestChan chan elicitationRequestItem
	rootsRequestChan       chan rootsRequestItem
	samplingRequests       sync.Map     // request ID (int64) -> chan samplingResponseItem
	requestIDCounter       atomic.Int64 // source of outbound request IDs
}

// newStreamableHttpSession creates a session bound to the given shared
// stores. Notifications get a 100-slot buffer; each request queue gets 10.
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, templatesStore *sessionResourceTemplatesStore, levels *sessionLogLevelsStore) *streamableHttpSession {
	return &streamableHttpSession{
		sessionID:              sessionID,
		tools:                  toolStore,
		resources:              resourcesStore,
		resourceTemplates:      templatesStore,
		logLevels:              levels,
		notificationChannel:    make(chan mcp.JSONRPCNotification, 100),
		samplingRequestChan:    make(chan samplingRequestItem, 10),
		elicitationRequestChan: make(chan elicitationRequestItem, 10),
		rootsRequestChan:       make(chan rootsRequestItem, 10),
	}
}
// SessionID returns the ID the session was created with.
func (s *streamableHttpSession) SessionID() string { return s.sessionID }

// NotificationChannel exposes the send side of the session's buffered
// notification queue.
func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	return s.notificationChannel
}

// Initialize is intentionally empty for this transport.
func (s *streamableHttpSession) Initialize() {}

// Initialized always reports true.
func (s *streamableHttpSession) Initialized() bool { return true }

// SetLogLevel stores level for this session in the shared log-level store.
func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) {
	s.logLevels.set(s.sessionID, level)
}

// GetLogLevel reads this session's level from the shared log-level store.
func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel {
	return s.logLevels.get(s.sessionID)
}

var _ ClientSession = (*streamableHttpSession)(nil)
// GetSessionTools returns a copy of this session's tools.
func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
	return s.tools.get(s.sessionID)
}

// SetSessionTools replaces this session's tools with a copy of tools.
func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
	s.tools.set(s.sessionID, tools)
}

// GetSessionResources returns a copy of this session's resources.
func (s *streamableHttpSession) GetSessionResources() map[string]ServerResource {
	return s.resources.get(s.sessionID)
}

// SetSessionResources replaces this session's resources with a copy of
// resources.
func (s *streamableHttpSession) SetSessionResources(resources map[string]ServerResource) {
	s.resources.set(s.sessionID, resources)
}

// GetSessionResourceTemplates returns a copy of this session's templates.
func (s *streamableHttpSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate {
	return s.resourceTemplates.get(s.sessionID)
}

// SetSessionResourceTemplates replaces this session's templates with a copy
// of templates.
func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) {
	s.resourceTemplates.set(s.sessionID, templates)
}

// GetClientInfo returns the stored client info, or the zero value if none
// has been set.
func (s *streamableHttpSession) GetClientInfo() mcp.Implementation {
	info, _ := s.clientInfo.Load().(mcp.Implementation)
	return info
}

// SetClientInfo stores clientInfo.
func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) {
	s.clientInfo.Store(clientInfo)
}

// GetClientCapabilities returns the stored client capabilities, or the zero
// value if none have been set.
func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities {
	caps, _ := s.clientCapabilities.Load().(mcp.ClientCapabilities)
	return caps
}

// SetClientCapabilities stores clientCapabilities.
func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
	s.clientCapabilities.Store(clientCapabilities)
}
var (
_ SessionWithTools = (*streamableHttpSession )(nil )
_ SessionWithResources = (*streamableHttpSession )(nil )
_ SessionWithResourceTemplates = (*streamableHttpSession )(nil )
_ SessionWithLogging = (*streamableHttpSession )(nil )
_ SessionWithClientInfo = (*streamableHttpSession )(nil )
)
func (s *streamableHttpSession ) UpgradeToSSEWhenReceiveNotification () {
s .upgradeToSSE .Store (true )
}
var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession )(nil )
// RequestSampling asks the client to create a message and waits for its
// reply. The request is queued on samplingRequestChan for the GET stream to
// forward. The reply is matched back by request ID through samplingRequests.
//
// Possible errors:
//   - ctx.Err() if ctx ends first;
//   - an "overloaded" error if the queue is full (nothing is queued);
//   - the client's error;
//   - a wrapped decoding error, so errors.Is/As can inspect the cause.
func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
	requestID := s.requestIDCounter.Add(1)
	responseChan := make(chan samplingResponseItem, 1)
	samplingRequest := samplingRequestItem{
		requestID: requestID,
		request:   request,
		response:  responseChan,
	}
	// Register the waiter before queueing so a fast reply always finds it.
	// It is removed however this call returns.
	s.samplingRequests.Store(requestID, responseChan)
	defer s.samplingRequests.Delete(requestID)
	select {
	case s.samplingRequestChan <- samplingRequest:
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
		return nil, fmt.Errorf("sampling request queue is full - server overloaded")
	}
	select {
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		var result mcp.CreateMessageResult
		if err := json.Unmarshal(response.result, &result); err != nil {
			return nil, fmt.Errorf("failed to unmarshal sampling response: %w", err)
		}
		// Generic JSON decoding leaves Content as a map; convert it to a
		// concrete content type.
		if contentMap, ok := result.Content.(map[string]any); ok {
			content, err := mcp.ParseContent(contentMap)
			if err != nil {
				return nil, fmt.Errorf("failed to parse sampling response content: %w", err)
			}
			result.Content = content
		}
		return &result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// ListRoots asks the client for its roots and waits for the reply. The
// request is queued on rootsRequestChan for the GET stream to forward. The
// reply is matched back by request ID through samplingRequests.
//
// Possible errors:
//   - ctx.Err() if ctx ends first;
//   - an "overloaded" error if the queue is full (nothing is queued);
//   - the client's error;
//   - a wrapped decoding error.
func (s *streamableHttpSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
	requestID := s.requestIDCounter.Add(1)
	responseChan := make(chan samplingResponseItem, 1)
	rootsRequest := rootsRequestItem{
		requestID: requestID,
		request:   request,
		response:  responseChan,
	}
	// Register the waiter before queueing so a fast reply always finds it.
	s.samplingRequests.Store(requestID, responseChan)
	defer s.samplingRequests.Delete(requestID)
	select {
	case s.rootsRequestChan <- rootsRequest:
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
		return nil, fmt.Errorf("list roots request queue is full - server overloaded")
	}
	select {
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		var result mcp.ListRootsResult
		if err := json.Unmarshal(response.result, &result); err != nil {
			return nil, fmt.Errorf("failed to unmarshal list roots response: %w", err)
		}
		return &result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// RequestElicitation asks the client to elicit information from its user and
// waits for the reply. The request is queued on elicitationRequestChan for
// the GET stream to forward. The reply is matched back by request ID through
// samplingRequests.
//
// Possible errors:
//   - ctx.Err() if ctx ends first;
//   - an "overloaded" error if the queue is full (nothing is queued);
//   - the client's error;
//   - a wrapped decoding error.
func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
	requestID := s.requestIDCounter.Add(1)
	responseChan := make(chan samplingResponseItem, 1)
	elicitationRequest := elicitationRequestItem{
		requestID: requestID,
		request:   request,
		response:  responseChan,
	}
	// Register the waiter before queueing so a fast reply always finds it.
	s.samplingRequests.Store(requestID, responseChan)
	defer s.samplingRequests.Delete(requestID)
	select {
	case s.elicitationRequestChan <- elicitationRequest:
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
		return nil, fmt.Errorf("elicitation request queue is full - server overloaded")
	}
	select {
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		var result mcp.ElicitationResult
		if err := json.Unmarshal(response.result, &result); err != nil {
			return nil, fmt.Errorf("failed to unmarshal elicitation response: %w", err)
		}
		return &result, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
var _ SessionWithSampling = (*streamableHttpSession )(nil )
var _ SessionWithElicitation = (*streamableHttpSession )(nil )
var _ SessionWithRoots = (*streamableHttpSession )(nil )
type (
	// SessionIdManagerResolver selects the SessionIdManager to use for a given
	// incoming HTTP request.
	SessionIdManagerResolver interface {
		ResolveSessionIdManager(r *http.Request) SessionIdManager
	}

	// SessionIdManager issues, checks and ends session IDs.
	//
	// Generate returns a new session ID. The stateless manager shown here
	// returns "" (presumably meaning "no session"; confirm at the call site).
	//
	// Validate reports whether sessionID belongs to a terminated session
	// (isTerminated), or returns a non-nil err when the ID is malformed or
	// unknown.
	//
	// Terminate ends sessionID. NOTE(review): isNotAllowed presumably lets a
	// manager refuse termination; the implementations visible here always
	// return false.
	SessionIdManager interface {
		Generate() string
		Validate(sessionID string) (isTerminated bool, err error)
		Terminate(sessionID string) (isNotAllowed bool, err error)
	}
)
// DefaultSessionIdManagerResolver is a SessionIdManagerResolver that hands out
// the same SessionIdManager for every request.
type DefaultSessionIdManagerResolver struct {
	manager SessionIdManager
}

// NewDefaultSessionIdManagerResolver returns a resolver that always yields
// manager. A nil manager is replaced by a StatelessSessionIdManager.
func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver {
	if manager == nil {
		manager = &StatelessSessionIdManager{}
	}
	return &DefaultSessionIdManagerResolver{manager: manager}
}

// ResolveSessionIdManager returns the configured manager and ignores the
// request.
//
// A nil resolver, or a zero-value one that was not built with
// NewDefaultSessionIdManagerResolver, falls back to a StatelessSessionIdManager.
// This matches the constructor's nil handling. Without it, the method would
// return a nil interface, and callers would panic on their first method call.
func (r *DefaultSessionIdManagerResolver) ResolveSessionIdManager(_ *http.Request) SessionIdManager {
	if r == nil || r.manager == nil {
		return &StatelessSessionIdManager{}
	}
	return r.manager
}
// StatelessSessionIdManager is a SessionIdManager for servers that keep no
// session state. It never issues an ID, and it accepts every ID in Validate
// and Terminate without error.
type StatelessSessionIdManager struct{}

// Generate always returns the empty string; no session ID is issued.
func (s *StatelessSessionIdManager) Generate() string { return "" }

// Validate accepts any ID: it never reports termination and never fails.
func (s *StatelessSessionIdManager) Validate(_ string) (isTerminated bool, err error) {
	return false, nil
}

// Terminate is a no-op that always succeeds.
func (s *StatelessSessionIdManager) Terminate(_ string) (isNotAllowed bool, err error) {
	return false, nil
}
type StatelessGeneratingSessionIdManager struct {}
func (s *StatelessGeneratingSessionIdManager ) Generate () string {
return idPrefix + uuid .New ().String ()
}
func (s *StatelessGeneratingSessionIdManager ) Validate (sessionID string ) (isTerminated bool , err error ) {
if !strings .HasPrefix (sessionID , idPrefix ) {
return false , fmt .Errorf ("invalid session id: %s" , sessionID )
}
if _ , err := uuid .Parse (sessionID [len (idPrefix ):]); err != nil {
return false , fmt .Errorf ("invalid session id: %s" , sessionID )
}
return false , nil
}
func (s *StatelessGeneratingSessionIdManager ) Terminate (sessionID string ) (isNotAllowed bool , err error ) {
return false , nil
}
// InsecureStatefulSessionIdManager issues random session IDs and tracks them
// in memory, so Validate can tell live, terminated and unknown sessions apart.
//
// NOTE(review): "insecure" presumably refers to the IDs being bare random UUIDs
// with no authentication binding; confirm the intended threat model.
// Terminated IDs are never evicted from the terminated map, so memory grows
// with the number of terminated sessions.
type InsecureStatefulSessionIdManager struct {
	sessions   sync.Map // live session ID -> true
	terminated sync.Map // terminated session ID -> true
}

// idPrefix is prepended to every generated session ID, and Validate rejects
// IDs that lack it.
const idPrefix = "mcp-session-"

// Generate creates a new session ID of the form idPrefix+UUID and records it
// as live.
func (s *InsecureStatefulSessionIdManager) Generate() string {
	sessionID := idPrefix + uuid.New().String()
	s.sessions.Store(sessionID, true)
	return sessionID
}

// Validate checks sessionID against the tracked sessions. It returns:
//   - (false, error) if the ID is malformed or was never issued;
//   - (true, nil) if the session has been terminated;
//   - (false, nil) if the session is live.
func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	rest, ok := strings.CutPrefix(sessionID, idPrefix)
	if !ok {
		return false, fmt.Errorf("invalid session id: %s", sessionID)
	}
	if _, parseErr := uuid.Parse(rest); parseErr != nil {
		return false, fmt.Errorf("invalid session id: %s", sessionID)
	}

	// Check the live set before the terminated set. Terminate stores into
	// terminated before deleting from sessions, so once an ID is seen missing
	// from sessions, its terminated entry is already visible. Checking in the
	// opposite order could miss both maps while a Terminate is in flight, and
	// would misreport a terminated session as "not found".
	if _, live := s.sessions.Load(sessionID); live {
		return false, nil
	}
	if _, ended := s.terminated.Load(sessionID); ended {
		return true, nil
	}
	return false, fmt.Errorf("session not found: %s", sessionID)
}

// Terminate marks a live session as terminated. Unknown and already-terminated
// IDs are ignored. It never refuses and never returns an error.
func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	if _, ended := s.terminated.Load(sessionID); ended {
		return false, nil
	}
	if _, live := s.sessions.Load(sessionID); !live {
		return false, nil
	}
	// Record the termination before removing the live entry, so a concurrent
	// Validate never finds the ID in neither map.
	s.terminated.Store(sessionID, true)
	s.sessions.Delete(sessionID)
	return false, nil
}
// NewTestStreamableHTTPServer starts an httptest.Server whose handler is a
// StreamableHTTPServer wrapping server and configured with opts. As with any
// httptest.Server, the caller should Close it when finished.
func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
	return httptest.NewServer(NewStreamableHTTPServer(server, opts...))
}
// isJSONEmpty reports whether data carries no meaningful JSON value. That is
// the case when data is empty or whitespace-only, is the literal null, or is an
// object or array whose only content between the delimiters is whitespace
// ({} or []).
//
// This is a lexical check, not a parser. Only the first non-space byte after
// the opening delimiter is examined, so malformed input such as "{}x" is also
// reported as empty.
func isJSONEmpty(data json.RawMessage) bool {
	body := bytes.TrimSpace(data)
	if len(body) == 0 {
		return true
	}

	var closer byte
	switch body[0] {
	case 'n':
		return string(body) == "null"
	case '{':
		closer = '}'
	case '[':
		closer = ']'
	default:
		// Strings, numbers, booleans and anything else are never empty.
		return false
	}

	// Skip whitespace byte-by-byte after the opening delimiter. The value is
	// empty only if the next significant byte is the matching closer.
	for _, b := range body[1:] {
		if !unicode.IsSpace(rune(b)) {
			return b == closer
		}
	}
	return false
}
// isExplicitEmptyObject reports whether data, ignoring surrounding whitespace,
// is a well-formed JSON object with no members. Empty input, null, arrays,
// other scalars and malformed objects all yield false.
func isExplicitEmptyObject(data json.RawMessage) bool {
	trimmed := bytes.TrimSpace(data)
	if !bytes.HasPrefix(trimmed, []byte("{")) {
		return false
	}
	// Decoding (rather than scanning) also rejects malformed objects.
	var fields map[string]json.RawMessage
	return json.Unmarshal(trimmed, &fields) == nil && len(fields) == 0
}
The pages are generated with Golds v0.8.4. (GOOS=linux GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.