package server
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
)
// sseSession holds the state of a single SSE client connection.
type sseSession struct {
	// sessionID is the identifier handed to the client in the endpoint event.
	sessionID string

	// done is closed to stop the goroutines serving this session;
	// eventQueue carries fully formatted SSE frames waiting to be written.
	done       chan struct{}
	eventQueue chan string

	// notificationChannel receives server notifications, which handleSSE
	// converts into SSE "message" events.
	notificationChannel chan mcp.JSONRPCNotification

	// requestID numbers the keep-alive ping requests sent on this session.
	requestID atomic.Int64

	initialized  atomic.Bool
	loggingLevel atomic.Value // holds mcp.LoggingLevel

	// Session-scoped registries.
	tools             sync.Map // string -> ServerTool
	resources         sync.Map // string -> ServerResource
	resourceTemplates sync.Map // string -> ServerResourceTemplate

	clientInfo         atomic.Value // holds mcp.Implementation
	clientCapabilities atomic.Value // holds mcp.ClientCapabilities
}
type (
	// SSEContextFunc derives the context used to handle a POSTed message from
	// the incoming HTTP request (applied in handleMessage).
	SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

	// DynamicBasePathFunc computes, per request and session, the base path of
	// the message endpoint advertised to a client.
	DynamicBasePathFunc func(r *http.Request, sessionID string) string

	// SessionIDGenFunc produces the ID for a new SSE session. An error or an
	// empty ID makes handleSSE reject the connection.
	SessionIDGenFunc func(ctx context.Context, r *http.Request) (string, error)
)
// SessionID reports the identifier assigned to this session at creation time.
func (s *sseSession) SessionID() string { return s.sessionID }
// NotificationChannel exposes a send-only view of the channel that feeds
// server notifications into this session's SSE stream.
func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	var sendOnly chan<- mcp.JSONRPCNotification = s.notificationChannel
	return sendOnly
}
// Initialize resets the session's log level to error and then marks the
// session as initialized.
func (s *sseSession) Initialize() {
	// The level is stored first so a reader that observes initialized==true
	// also observes a level.
	s.loggingLevel.Store(mcp.LoggingLevelError)
	s.initialized.Store(true)
}
// Initialized reports whether Initialize has been called on this session.
func (s *sseSession) Initialized() bool { return s.initialized.Load() }
// SetLogLevel records the minimum logging level requested for this session.
func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) }
// GetLogLevel returns the session's logging level, or mcp.LoggingLevelError
// when no level has been stored yet.
func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
	// Only mcp.LoggingLevel values are ever stored, so the assertion fails
	// solely for the "nothing stored" (nil) case.
	if current, ok := s.loggingLevel.Load().(mcp.LoggingLevel); ok {
		return current
	}
	return mcp.LoggingLevelError
}
// GetSessionResources returns a snapshot copy of the session-scoped resources,
// keyed by name. Entries that are not ServerResource values are skipped.
func (s *sseSession) GetSessionResources() map[string]ServerResource {
	snapshot := map[string]ServerResource{}
	collect := func(k, v any) bool {
		if res, ok := v.(ServerResource); ok {
			snapshot[k.(string)] = res
		}
		return true
	}
	s.resources.Range(collect)
	return snapshot
}
// SetSessionResources replaces all session-scoped resources with the given map.
// The clear-then-store sequence is not atomic with respect to concurrent readers.
func (s *sseSession) SetSessionResources(resources map[string]ServerResource) {
	s.resources.Clear()
	for key := range resources {
		s.resources.Store(key, resources[key])
	}
}
// GetSessionResourceTemplates returns a snapshot copy of the session-scoped
// resource templates, keyed by URI template. Non-ServerResourceTemplate
// entries are skipped.
func (s *sseSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate {
	snapshot := map[string]ServerResourceTemplate{}
	collect := func(k, v any) bool {
		if tmpl, ok := v.(ServerResourceTemplate); ok {
			snapshot[k.(string)] = tmpl
		}
		return true
	}
	s.resourceTemplates.Range(collect)
	return snapshot
}
// SetSessionResourceTemplates replaces all session-scoped resource templates
// with the given map, keyed by URI template.
func (s *sseSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) {
	s.resourceTemplates.Clear()
	for key := range templates {
		s.resourceTemplates.Store(key, templates[key])
	}
}
// GetSessionTools returns a snapshot copy of the session-scoped tools, keyed
// by name. Entries that are not ServerTool values are skipped.
func (s *sseSession) GetSessionTools() map[string]ServerTool {
	snapshot := map[string]ServerTool{}
	collect := func(k, v any) bool {
		if tool, ok := v.(ServerTool); ok {
			snapshot[k.(string)] = tool
		}
		return true
	}
	s.tools.Range(collect)
	return snapshot
}
// SetSessionTools replaces all session-scoped tools with the given map.
func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
	s.tools.Clear()
	for key := range tools {
		s.tools.Store(key, tools[key])
	}
}
// GetClientInfo returns the stored client implementation info, or the zero
// mcp.Implementation if none has been set.
func (s *sseSession) GetClientInfo() mcp.Implementation {
	// A failed assertion (including on a nil Load) yields the zero value.
	info, _ := s.clientInfo.Load().(mcp.Implementation)
	return info
}
// SetClientInfo records the client's implementation info for this session.
func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) }
// SetClientCapabilities records the capabilities the client declared.
func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
	s.clientCapabilities.Store(clientCapabilities)
}
// GetClientCapabilities returns the stored client capabilities, or the zero
// mcp.ClientCapabilities if none have been set.
func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
	// A failed assertion (including on a nil Load) yields the zero value.
	caps, _ := s.clientCapabilities.Load().(mcp.ClientCapabilities)
	return caps
}
// Compile-time assertions that *sseSession implements each session interface.
var (
	_ ClientSession                = (*sseSession)(nil)
	_ SessionWithClientInfo        = (*sseSession)(nil)
	_ SessionWithLogging           = (*sseSession)(nil)
	_ SessionWithResourceTemplates = (*sseSession)(nil)
	_ SessionWithResources         = (*sseSession)(nil)
	_ SessionWithTools             = (*sseSession)(nil)
)
// SSEServer serves an MCPServer over HTTP using Server-Sent Events. Clients
// open a GET stream on the SSE endpoint and POST JSON-RPC messages to the
// message endpoint. Build it with NewSSEServer, which installs the defaults
// (endpoints, session ID generator) that the handlers rely on.
type SSEServer struct {
	server *MCPServer

	// Endpoint layout and the message URL advertised to clients.
	baseURL                      string
	basePath                     string
	sseEndpoint                  string
	messageEndpoint              string
	appendQueryToMessageEndpoint bool
	useFullURLForMessageEndpoint bool
	dynamicBasePathFunc          DynamicBasePathFunc

	// sessions maps session ID -> *sseSession for active streams.
	sessions sync.Map

	// srv is read and written under mu in Start and Shutdown.
	srv *http.Server
	mu  sync.RWMutex

	// Per-request hooks.
	contextFunc      SSEContextFunc
	sessionIDGenFunc SessionIDGenFunc

	// Keep-alive ping configuration.
	keepAlive         bool
	keepAliveInterval time.Duration
}
// SSEOption mutates an SSEServer during construction; NewSSEServer applies
// options in the order they are passed.
type SSEOption = func(*SSEServer)
// WithBaseURL sets the absolute base URL (scheme and host) used when building
// full endpoint URLs. A trailing slash is trimmed. A non-empty value is
// silently ignored unless it parses as an http/https URL with a non-empty
// host that does not start with ':' and has no query string. An empty
// value clears the base URL.
func WithBaseURL(baseURL string) SSEOption {
	return func(s *SSEServer) {
		if baseURL != "" {
			u, err := url.Parse(baseURL)
			// Cases are evaluated in order, so u is only dereferenced after a
			// successful parse.
			switch {
			case err != nil:
				return
			case u.Scheme != "http" && u.Scheme != "https":
				return
			case u.Host == "" || strings.HasPrefix(u.Host, ":"):
				return
			case len(u.Query()) > 0:
				return
			}
		}
		s.baseURL = strings.TrimSuffix(baseURL, "/")
	}
}
// WithStaticBasePath sets a fixed path prefix shared by the SSE and message
// endpoints. The value is passed through normalizeURLPath.
func WithStaticBasePath(basePath string) SSEOption {
	return func(srv *SSEServer) {
		normalized := normalizeURLPath(basePath)
		srv.basePath = normalized
	}
}
// WithBasePath is equivalent to WithStaticBasePath.
func WithBasePath(basePath string) SSEOption {
	opt := WithStaticBasePath(basePath)
	return opt
}
// WithDynamicBasePath installs fn to compute the message-endpoint base path
// per request and session; its result is passed through normalizeURLPath.
// A nil fn leaves the server unchanged.
func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
	return func(s *SSEServer) {
		if fn == nil {
			return
		}
		s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
			return normalizeURLPath(fn(r, sid))
		}
	}
}
// WithMessageEndpoint sets the path (relative to the base path) that clients
// POST messages to.
func WithMessageEndpoint(endpoint string) SSEOption {
	return func(srv *SSEServer) { srv.messageEndpoint = endpoint }
}
// WithAppendQueryToMessageEndpoint makes handleSSE append the SSE request's
// raw query string to the advertised message endpoint.
func WithAppendQueryToMessageEndpoint() SSEOption {
	return func(srv *SSEServer) { srv.appendQueryToMessageEndpoint = true }
}
// WithUseFullURLForMessageEndpoint controls whether the advertised message
// endpoint is prefixed with the base URL (when one is configured).
func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
	return func(srv *SSEServer) { srv.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint }
}
// WithSSEEndpoint sets the path (relative to the base path) of the SSE stream.
func WithSSEEndpoint(endpoint string) SSEOption {
	return func(srv *SSEServer) { srv.sseEndpoint = endpoint }
}
// WithHTTPServer supplies the *http.Server used by Start and Shutdown instead
// of letting Start create one.
func WithHTTPServer(srv *http.Server) SSEOption {
	return func(s *SSEServer) { s.srv = srv }
}
// WithKeepAliveInterval enables keep-alive pings on SSE streams and sets how
// often they are sent.
//
// A non-positive interval still enables keep-alive but leaves the previously
// configured interval in place (NewSSEServer defaults it to 10s). Passing it
// through would make time.NewTicker panic inside handleSSE's keep-alive
// goroutine, and that would crash the whole process.
func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
	return func(s *SSEServer) {
		s.keepAlive = true
		if keepAliveInterval > 0 {
			s.keepAliveInterval = keepAliveInterval
		}
	}
}
// WithKeepAlive turns periodic keep-alive pings on SSE streams on or off.
func WithKeepAlive(keepAlive bool) SSEOption {
	return func(srv *SSEServer) { srv.keepAlive = keepAlive }
}
// WithSSEContextFunc installs fn to derive the context used when handling
// POSTed messages. A nil fn disables the hook.
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
	return func(srv *SSEServer) { srv.contextFunc = fn }
}
// WithSessionIDGenerator replaces the session ID generator. A nil fn keeps
// the current generator.
func WithSessionIDGenerator(fn SessionIDGenFunc) SSEOption {
	return func(srv *SSEServer) {
		if fn == nil {
			return
		}
		srv.sessionIDGenFunc = fn
	}
}
// NewSSEServer creates an SSEServer for server.
//
// Defaults, applied before opts:
//   - SSE endpoint "/sse" and message endpoint "/message";
//   - full-URL message endpoints enabled;
//   - keep-alive disabled, with a 10s interval once enabled;
//   - random UUID session IDs.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
	uuidSessionID := func(context.Context, *http.Request) (string, error) {
		return uuid.New().String(), nil
	}
	s := &SSEServer{
		server:                       server,
		sseEndpoint:                  "/sse",
		messageEndpoint:              "/message",
		useFullURLForMessageEndpoint: true,
		keepAliveInterval:            10 * time.Second,
		sessionIDGenFunc:             uuidSessionID,
	}
	for i := range opts {
		opts[i](s)
	}
	return s
}
// NewTestServer starts an httptest.Server that serves a new SSEServer. It
// sets the SSEServer's base URL to the test server's URL so advertised
// endpoints are reachable.
func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
	handler := NewSSEServer(server, opts...)
	ts := httptest.NewServer(handler)
	handler.baseURL = ts.URL
	return ts
}
// Start listens on addr and serves the SSE transport, blocking until the
// server stops.
//
// If no *http.Server was supplied via WithHTTPServer, one is created with s as
// its handler. If one was supplied and its Addr is empty, Addr is set to addr.
// A supplied server whose Addr differs from addr produces a conflict error.
// Otherwise the result of (*http.Server).ListenAndServe is returned.
func (s *SSEServer) Start(addr string) error {
	srv, err := func() (*http.Server, error) {
		s.mu.Lock()
		// The deferred unlock also covers the conflict path. Returning while
		// still holding mu would deadlock any later Shutdown call.
		defer s.mu.Unlock()
		switch {
		case s.srv == nil:
			s.srv = &http.Server{
				Addr:    addr,
				Handler: s,
			}
		case s.srv.Addr == "":
			s.srv.Addr = addr
		case s.srv.Addr != addr:
			return nil, fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr)
		}
		return s.srv, nil
	}()
	if err != nil {
		return err
	}
	return srv.ListenAndServe()
}
// Shutdown closes every active SSE session and then gracefully shuts down the
// underlying *http.Server. If no *http.Server has been configured (by Start or
// WithHTTPServer), it returns nil without touching sessions.
//
// Each session's map entry is removed with LoadAndDelete before its done
// channel is closed. handleSSE removes entries with CompareAndDelete under the
// same rule: whichever side removes the entry closes done, exactly once. This
// prevents a "close of closed channel" panic when a client disconnects while
// Shutdown is running.
func (s *SSEServer) Shutdown(ctx context.Context) error {
	s.mu.RLock()
	srv := s.srv
	s.mu.RUnlock()
	if srv == nil {
		return nil
	}
	s.sessions.Range(func(key, _ any) bool {
		if value, loaded := s.sessions.LoadAndDelete(key); loaded {
			if session, ok := value.(*sseSession); ok {
				close(session.done)
			}
		}
		return true
	})
	return srv.Shutdown(ctx)
}

// handleSSE serves the long-lived SSE stream for one client.
//
// The flow is:
//  1. Reject non-GET requests.
//  2. Create and register a new session.
//  3. Announce the session's message endpoint as an "endpoint" event.
//  4. Write every frame placed on session.eventQueue to the client.
//
// Writing continues until the request context is cancelled or the session's
// done channel is closed (e.g. by Shutdown).
//
// Two helper goroutines feed eventQueue: one turns server notifications into
// SSE "message" events, the other (only when keep-alive is enabled) emits
// periodic JSON-RPC "ping" requests. Both exit on done or request cancellation.
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}
	sessionID, err := s.sessionIDGenFunc(r.Context(), r)
	if err != nil || sessionID == "" {
		http.Error(w, "Failed to create session ID", http.StatusInternalServerError)
		return
	}
	session := &sseSession{
		done:                make(chan struct{}),
		eventQueue:          make(chan string, 100),
		sessionID:           sessionID,
		notificationChannel: make(chan mcp.JSONRPCNotification, 100),
	}
	s.sessions.Store(sessionID, session)
	// Only remove the entry if it still refers to this session. A custom
	// SessionIDGenFunc may have reused the ID for a newer connection.
	defer s.sessions.CompareAndDelete(sessionID, session)
	if err := s.server.RegisterSession(r.Context(), session); err != nil {
		http.Error(
			w,
			fmt.Sprintf("Session registration failed: %v", err),
			http.StatusInternalServerError,
		)
		return
	}
	defer s.server.UnregisterSession(r.Context(), sessionID)

	// Forward server notifications to the stream.
	go func() {
		for {
			select {
			case notification := <-session.notificationChannel:
				eventData, err := json.Marshal(notification)
				if err != nil {
					// Notifications that cannot be serialized are dropped.
					continue
				}
				select {
				case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
				case <-session.done:
					return
				case <-r.Context().Done():
					return
				}
			case <-session.done:
				return
			case <-r.Context().Done():
				return
			}
		}
	}()

	if s.keepAlive {
		// Periodically enqueue a JSON-RPC ping request.
		go func() {
			ticker := time.NewTicker(s.keepAliveInterval)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					message := mcp.JSONRPCRequest{
						JSONRPC: "2.0",
						ID:      mcp.NewRequestId(session.requestID.Add(1)),
						Request: mcp.Request{
							Method: "ping",
						},
					}
					// The marshal error is ignored: the request is built from literal
					// fields only.
					messageBytes, _ := json.Marshal(message)
					pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
					select {
					case session.eventQueue <- pingMsg:
					case <-session.done:
						return
					case <-r.Context().Done():
						return
					}
				case <-session.done:
					return
				case <-r.Context().Done():
					return
				}
			}
		}()
	}

	endpoint := s.GetMessageEndpointForClient(r, sessionID)
	if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 {
		// The endpoint already carries "?sessionId=...", so extra query
		// parameters are joined with '&'.
		endpoint += "&" + r.URL.RawQuery
	}
	fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint)
	flusher.Flush()

	for {
		select {
		case event := <-session.eventQueue:
			fmt.Fprint(w, event)
			flusher.Flush()
		case <-r.Context().Done():
			// Close done only if this call removed the session. If Shutdown
			// already removed it, Shutdown performs the close.
			if s.sessions.CompareAndDelete(sessionID, session) {
				close(session.done)
			}
			return
		case <-session.done:
			return
		}
	}
}
// GetMessageEndpointForClient returns the endpoint a client should POST
// messages to for the given session.
//
// The path is the base path (static, or from the dynamic base-path function,
// which receives the raw session ID) joined with the message endpoint. It is
// prefixed with the base URL when full URLs are enabled and a base URL is set.
// The session ID is query-escaped so that IDs containing reserved characters
// such as '&', '+', '#' or spaces round-trip through handleMessage's
// r.URL.Query().Get("sessionId").
func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string {
	basePath := s.basePath
	if s.dynamicBasePathFunc != nil {
		basePath = s.dynamicBasePathFunc(r, sessionID)
	}
	endpointPath := normalizeURLPath(basePath, s.messageEndpoint)
	if s.useFullURLForMessageEndpoint && s.baseURL != "" {
		endpointPath = s.baseURL + endpointPath
	}
	return endpointPath + "?sessionId=" + url.QueryEscape(sessionID)
}
// handleMessage accepts a JSON-RPC message POSTed for an existing session,
// identified by the sessionId query parameter.
//
// Request-level failures are reported synchronously via writeJSONRPCError:
// wrong method, missing or unknown session, or an unparsable body.
// Otherwise the request is answered with 202 Accepted. The message is handled
// on a separate goroutine, and any response is delivered over the session's
// SSE stream. If the event queue is full the response is dropped and logged.
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
		return
	}
	sessionID := r.URL.Query().Get("sessionId")
	if sessionID == "" {
		s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
		return
	}
	sessionI, ok := s.sessions.Load(sessionID)
	if !ok {
		s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
		return
	}
	session := sessionI.(*sseSession)
	ctx := s.server.WithContext(r.Context(), session)
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx, r)
	}
	var rawMessage json.RawMessage
	if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
		return
	}
	// Handling outlives this HTTP request, so keep the context's values but
	// drop its cancellation.
	detachedCtx := context.WithoutCancel(ctx)
	w.WriteHeader(http.StatusAccepted)
	messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header)
	messageCtx, cancel := context.WithCancel(messageCtx)

	// Sent when the response cannot be marshalled. JSON-RPC 2.0 requires
	// "error" to be an object with "code" and "message"; -32603 is the
	// spec's "Internal error" code.
	const marshalFailureEvent = "event: message\ndata: " +
		`{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"internal error"}}` +
		"\n\n"

	go func(ctx context.Context) {
		defer cancel()
		response := s.server.HandleMessage(ctx, rawMessage)
		if response == nil {
			return // nothing to send back
		}
		var message string
		if eventData, err := json.Marshal(response); err != nil {
			log.Printf("failed to marshal response: %v", err)
			message = marshalFailureEvent
		} else {
			message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
		}
		select {
		case session.eventQueue <- message:
		case <-session.done:
		default:
			log.Printf("Event queue full for session %s", sessionID)
		}
	}(messageCtx)
}
// writeJSONRPCError writes a JSON-RPC error response built by
// createErrorResponse with HTTP status 400 and Content-Type application/json.
//
// The response is marshalled before any header is written. If encoding fails,
// a plain 500 can still be sent. (The previous version called http.Error
// after WriteHeader(400), which triggers a superfluous-WriteHeader warning
// and appends text to an already-started body.)
func (s *SSEServer) writeJSONRPCError(
	w http.ResponseWriter,
	id any,
	code int,
	message string,
) {
	response := createErrorResponse(id, code, message)
	body, err := json.Marshal(response)
	if err != nil {
		http.Error(
			w,
			fmt.Sprintf("Failed to encode response: %v", err),
			http.StatusInternalServerError,
		)
		return
	}
	// json.Encoder terminates each value with '\n'; keep that on the wire.
	body = append(body, '\n')
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	// The status is already sent, so a write error (typically a disconnected
	// client) cannot be reported to the client.
	_, _ = w.Write(body)
}
// SendEventToSession JSON-encodes event and queues it as an SSE "message"
// event for the given session without blocking. It returns an error if the
// session does not exist, the event cannot be marshalled, the session is
// closed, or the session's event queue is full.
func (s *SSEServer) SendEventToSession(
	sessionID string,
	event any,
) error {
	stored, found := s.sessions.Load(sessionID)
	if !found {
		return fmt.Errorf("session not found: %s", sessionID)
	}
	target := stored.(*sseSession)
	payload, err := json.Marshal(event)
	if err != nil {
		return err
	}
	frame := fmt.Sprintf("event: message\ndata: %s\n\n", payload)
	select {
	case target.eventQueue <- frame:
		return nil
	case <-target.done:
		return fmt.Errorf("session closed")
	default:
		return fmt.Errorf("event queue full")
	}
}
// GetUrlPath parses input as a URL and returns only its path component.
func (s *SSEServer) GetUrlPath(input string) (string, error) {
	u, err := url.Parse(input)
	if err != nil {
		return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
	}
	return u.Path, nil
}
// CompleteSseEndpoint returns the base URL followed by the normalized
// base-path + SSE-endpoint path. It fails with *ErrDynamicPathConfig when a
// dynamic base path is configured, since no single answer exists then.
func (s *SSEServer) CompleteSseEndpoint() (string, error) {
	if s.dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
	}
	return s.baseURL + normalizeURLPath(s.basePath, s.sseEndpoint), nil
}
// CompleteSsePath returns just the path part of CompleteSseEndpoint. It falls
// back to the normalized base-path + SSE-endpoint path if the endpoint cannot
// be computed or parsed.
func (s *SSEServer) CompleteSsePath() string {
	fallback := normalizeURLPath(s.basePath, s.sseEndpoint)
	endpoint, err := s.CompleteSseEndpoint()
	if err != nil {
		return fallback
	}
	if p, parseErr := s.GetUrlPath(endpoint); parseErr == nil {
		return p
	}
	return fallback
}
// CompleteMessageEndpoint returns the base URL followed by the normalized
// base-path + message-endpoint path. It fails with *ErrDynamicPathConfig when
// a dynamic base path is configured.
func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
	if s.dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
	}
	return s.baseURL + normalizeURLPath(s.basePath, s.messageEndpoint), nil
}
// CompleteMessagePath returns just the path part of CompleteMessageEndpoint.
// It falls back to the normalized base-path + message-endpoint path if the
// endpoint cannot be computed or parsed.
func (s *SSEServer) CompleteMessagePath() string {
	fallback := normalizeURLPath(s.basePath, s.messageEndpoint)
	endpoint, err := s.CompleteMessageEndpoint()
	if err != nil {
		return fallback
	}
	if p, parseErr := s.GetUrlPath(endpoint); parseErr == nil {
		return p
	}
	return fallback
}
// SSEHandler exposes the SSE stream handler for mounting on a custom mux.
func (s *SSEServer) SSEHandler() http.Handler {
	var h http.HandlerFunc = s.handleSSE
	return h
}
// MessageHandler exposes the message POST handler for mounting on a custom mux.
func (s *SSEServer) MessageHandler() http.Handler {
	var h http.HandlerFunc = s.handleMessage
	return h
}
// ServeHTTP routes requests to the SSE or message handler by exact path match
// and answers 404 for anything else. It refuses to serve (500) when a dynamic
// base path is configured, because the paths are then not known statically.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if s.dynamicBasePathFunc != nil {
		http.Error(
			w,
			(&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
			http.StatusInternalServerError,
		)
		return
	}
	reqPath := r.URL.Path
	matches := func(target string) bool { return target != "" && reqPath == target }
	// Cases are evaluated in order, so the message path is only computed
	// when the SSE path does not match.
	switch {
	case matches(s.CompleteSsePath()):
		s.handleSSE(w, r)
	case matches(s.CompleteMessagePath()):
		s.handleMessage(w, r)
	default:
		http.NotFound(w, r)
	}
}
// normalizeURLPath joins elem into a clean, absolute, slash-separated path.
//
// The result always starts with "/" and has no trailing slash except for the
// root "/". Empty elements are ignored, and "." / ".." segments are resolved
// without escaping the root (e.g. "../x" -> "/x"). Cleaning after prefixing
// "/" is what guarantees this; the former version could return "/." or
// "/../x", and its trailing-slash trim was unreachable because path.Join
// already cleans.
func normalizeURLPath(elem ...string) string {
	return path.Clean("/" + path.Join(elem...))
}
These pages were generated with Golds v0.8.4 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.