// Package multistream implements a simple stream router for the
// multistream-select protocol. The protocol is defined at
// https://github.com/multiformats/multistream-select
package multistream import ( ) // ErrTooLarge is an error to signal that an incoming message was too large var ErrTooLarge = errors.New("incoming message was too large") // ProtocolID identifies the multistream protocol itself and makes sure // the multistream muxers on both sides of a channel can work with each other. const ProtocolID = "/multistream/1.0.0" var writerPool = sync.Pool{ New: func() interface{} { return bufio.NewWriter(nil) }, } // StringLike is an interface that supports all types with underlying type // string type StringLike interface { ~string } // HandlerFunc is a user-provided function used by the MultistreamMuxer to // handle a protocol/stream. type HandlerFunc[ StringLike] func(protocol , rwc io.ReadWriteCloser) error // Handler is a wrapper to HandlerFunc which attaches a name (protocol) and a // match function which can optionally be used to select a handler by other // means than the name. type Handler[ StringLike] struct { MatchFunc func() bool Handle HandlerFunc[] AddName } // MultistreamMuxer is a muxer for multistream. Depending on the stream // protocol tag it will select the right handler and hand the stream off to it. type MultistreamMuxer[ StringLike] struct { handlerlock sync.RWMutex handlers []Handler[] } // NewMultistreamMuxer creates a muxer. func [ StringLike]() *MultistreamMuxer[] { return new(MultistreamMuxer[]) } // LazyConn is the connection type returned by the lazy negotiation functions. type LazyConn interface { io.ReadWriteCloser // Flush flushes the lazy negotiation, if any. 
Flush() error } func writeUvarint( io.Writer, uint64) error { := make([]byte, 16) := varint.PutUvarint(, ) , := .Write([:]) if != nil { return } return nil } func delimWriteBuffered( io.Writer, []byte) error { := getWriter() defer putWriter() := delimWrite(, ) if != nil { return } return .Flush() } func delitmWriteAll( io.Writer, ...[]byte) error { for , := range { if := delimWrite(, ); != nil { return fmt.Errorf("failed to write messages %s, err: %v ", string(), ) } } return nil } func delimWrite( io.Writer, []byte) error { := writeUvarint(, uint64(len()+1)) if != nil { return } _, = .Write() if != nil { return } _, = .Write([]byte{'\n'}) if != nil { return } return nil } func fulltextMatch[ StringLike]( ) func() bool { return func( ) bool { return == } } // AddHandler attaches a new protocol handler to the muxer. func ( *MultistreamMuxer[]) ( , HandlerFunc[]) { .AddHandlerWithFunc(, fulltextMatch(), ) } // AddHandlerWithFunc attaches a new protocol handler to the muxer with a match. // If the match function returns true for a given protocol tag, the protocol // will be selected even if the handler name and protocol tags are different. func ( *MultistreamMuxer[]) ( , func() bool, HandlerFunc[]) { .handlerlock.Lock() defer .handlerlock.Unlock() .removeHandler() .handlers = append(.handlers, Handler[]{ MatchFunc: , Handle: , AddName: , }) } // RemoveHandler removes the handler with the given name from the muxer. func ( *MultistreamMuxer[]) ( ) { .handlerlock.Lock() defer .handlerlock.Unlock() .removeHandler() } func ( *MultistreamMuxer[]) ( ) { for , := range .handlers { if .AddName == { .handlers = append(.handlers[:], .handlers[+1:]...) return } } } // Protocols returns the list of handler-names added to this this muxer. 
func ( *MultistreamMuxer[]) () [] { .handlerlock.RLock() defer .handlerlock.RUnlock() var [] for , := range .handlers { = append(, .AddName) } return } // ErrIncorrectVersion is an error reported when the muxer protocol negotiation // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") func ( *MultistreamMuxer[]) ( ) *Handler[] { .handlerlock.RLock() defer .handlerlock.RUnlock() for , := range .handlers { if .MatchFunc() { return & } } return nil } // Negotiate performs protocol selection and returns the protocol name and // the matching handler function for it (or an error). func ( *MultistreamMuxer[]) ( io.ReadWriteCloser) ( , HandlerFunc[], error) { defer func() { if := recover(); != nil { fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", , debug.Stack()) = fmt.Errorf("panic in multistream negotiation: %s", ) } }() // Send the multistream protocol ID // Ignore the error here. We want the handshake to finish, even if the // other side has closed this rwc for writing. They may have sent us a // message and closed. Future writers will get an error anyways. _ = delimWriteBuffered(, []byte(ProtocolID)) , := ReadNextToken[]() if != nil { return "", nil, } if != ProtocolID { .Close() return "", nil, ErrIncorrectVersion } : for { // Now read and respond to commands until they send a valid protocol id , := ReadNextToken[]() if != nil { return "", nil, } := .findHandler() if == nil { if := delimWriteBuffered(, []byte("na")); != nil { return "", nil, } continue } // Ignore the error here. We want the handshake to finish, even if the // other side has closed this rwc for writing. They may have sent us a // message and closed. Future writers will get an error anyways. _ = delimWriteBuffered(, []byte()) // hand off processing to the sub-protocol handler return , .Handle, nil } } // Handle performs protocol negotiation on a ReadWriteCloser // (i.e. a connection). 
It will find a matching handler for the // incoming protocol and pass the ReadWriteCloser to it. func ( *MultistreamMuxer[]) ( io.ReadWriteCloser) error { , , := .Negotiate() if != nil { return } return (, ) } // ReadNextToken extracts a token from a Reader. It is used during // protocol negotiation and returns a string. func [ StringLike]( io.Reader) (, error) { , := ReadNextTokenBytes() if != nil { return "", } return (), nil } // ReadNextTokenBytes extracts a token from a Reader. It is used // during protocol negotiation and returns a byte slice. func ( io.Reader) ([]byte, error) { , := lpReadBuf() switch { case nil: return , nil case ErrTooLarge: return nil, ErrTooLarge default: return nil, } } func lpReadBuf( io.Reader) ([]byte, error) { , := .(io.ByteReader) if ! { = &byteReader{} } , := varint.ReadUvarint() if != nil { return nil, } if > 1024 { return nil, ErrTooLarge } := make([]byte, ) _, = io.ReadFull(, ) if != nil { if == io.EOF { = io.ErrUnexpectedEOF } return nil, } if len() == 0 || [-1] != '\n' { return nil, errors.New("message did not have trailing newline") } // slice off the trailing newline = [:-1] return , nil } // byteReader implements the ByteReader interface that ReadUVarint requires type byteReader struct { io.Reader } func ( *byteReader) () (byte, error) { var [1]byte , := .Read([:]) if == 1 { return [0], nil } if == nil { if != 0 { panic("read more bytes than buffer size") } = io.ErrNoProgress } return 0, } func getWriter( io.Writer) *bufio.Writer { := writerPool.Get().(*bufio.Writer) .Reset() return } func putWriter( *bufio.Writer) { .Reset(nil) writerPool.Put() }