package multistream
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"runtime/debug"
"sync"
"github.com/multiformats/go-varint"
)
// ErrTooLarge is returned by the token readers when an incoming
// length-prefixed message announces a length greater than 1024 bytes.
// The payload is not read in that case.
var ErrTooLarge = errors.New("incoming message was too large")

// ProtocolID is the multistream-select protocol identifier. Negotiate sends it
// first and requires the peer's first token to equal it.
const ProtocolID = "/multistream/1.0.0"
// writerPool recycles the bufio.Writers used by delimWriteBuffered.
// Fresh writers are created detached (nil destination) with the default
// bufio buffer size; callers attach a destination via getWriter.
var writerPool = sync.Pool{
	New: func() any {
		detached := bufio.NewWriter(nil)
		return detached
	},
}
// StringLike is the type constraint for protocol identifiers: any type whose
// underlying type is string.
type StringLike interface {
	~string
}
// HandlerFunc serves a stream once negotiation has selected it. It receives
// the negotiated protocol token and the stream it was negotiated on.
type HandlerFunc[T StringLike] func(protocol T, rwc io.ReadWriteCloser) error
// Handler pairs a protocol matcher with the function that serves matching
// streams.
type Handler[T StringLike] struct {
	// MatchFunc reports whether an incoming protocol token selects this
	// handler; findHandler calls it during negotiation.
	MatchFunc func(T) bool
	// Handle is the function Negotiate returns when MatchFunc accepts a token.
	Handle HandlerFunc[T]
	// AddName is the name the handler was registered under. Protocols
	// reports it, and adding or removing a handler matches on it.
	AddName T
}
// MultistreamMuxer dispatches incoming streams to registered protocol
// handlers. handlers is guarded by handlerlock. findHandler consults handlers
// in slice order, and new registrations are appended at the end.
type MultistreamMuxer[T StringLike] struct {
	handlerlock sync.RWMutex
	handlers    []Handler[T]
}
// NewMultistreamMuxer returns an empty muxer with no registered handlers.
func NewMultistreamMuxer[T StringLike]() *MultistreamMuxer[T] {
	return &MultistreamMuxer[T]{}
}
// LazyConn is an io.ReadWriteCloser that also exposes Flush.
// NOTE(review): LazyConn is not referenced in the visible code; presumably it
// is returned by lazily-negotiating client code elsewhere in the package, with
// Flush forcing out any pending handshake data — confirm against callers.
type LazyConn interface {
	io.ReadWriteCloser
	Flush() error
}
// writeUvarint encodes i as an unsigned varint and writes it to w in a single
// Write call, returning that call's error.
func writeUvarint(w io.Writer, i uint64) error {
	// 16 bytes comfortably exceeds the longest possible uvarint encoding.
	var scratch [16]byte
	encodedLen := varint.PutUvarint(scratch[:], i)
	_, err := w.Write(scratch[:encodedLen])
	return err
}
// delimWriteBuffered writes mes to w as one delimited message. The message
// goes through a pooled bufio.Writer, so the length prefix, payload, and
// newline are collected in the buffer and pushed to w by Flush. The pooled
// writer is always returned to the pool, including on error.
func delimWriteBuffered(w io.Writer, mes []byte) error {
	buffered := getWriter(w)
	defer putWriter(buffered)

	if err := delimWrite(buffered, mes); err != nil {
		return err
	}
	return buffered.Flush()
}
// delitmWriteAll writes each of messages to w, in order, as a delimited
// message. It stops at the first failure.
//
// The returned error names the message that could not be written. It wraps
// the underlying error with %w, so callers can still inspect the cause with
// errors.Is / errors.As.
//
// The name looks like a misspelling of "delimWriteAll". It is kept unchanged
// so existing callers continue to compile.
func delitmWriteAll(w io.Writer, messages ...[]byte) error {
	for _, mes := range messages {
		if err := delimWrite(w, mes); err != nil {
			return fmt.Errorf("failed to write messages %s, err: %w", string(mes), err)
		}
	}
	return nil
}
// delimWrite writes mes to w framed as a multistream message:
//   - a uvarint length that counts the payload plus its trailing newline,
//   - the payload itself,
//   - a single '\n'.
//
// The parts are issued as separate Write calls and the first error is
// returned. Callers that want the frame coalesced should pass a buffered
// writer, as delimWriteBuffered does.
func delimWrite(w io.Writer, mes []byte) error {
	if err := writeUvarint(w, uint64(len(mes)+1)); err != nil {
		return err
	}
	if _, err := w.Write(mes); err != nil {
		return err
	}
	_, err := w.Write([]byte{'\n'})
	return err
}
// fulltextMatch returns a predicate that accepts only tokens exactly equal to
// s. There is no prefix or case-insensitive matching.
func fulltextMatch[T StringLike](s T) func(T) bool {
	isExact := func(candidate T) bool {
		return candidate == s
	}
	return isExact
}
// AddHandler registers handler for the exact protocol name protocol. Any
// handler previously registered under that name is replaced.
func (msm *MultistreamMuxer[T]) AddHandler(protocol T, handler HandlerFunc[T]) {
	exact := fulltextMatch(protocol)
	msm.AddHandlerWithFunc(protocol, exact, handler)
}
// AddHandlerWithFunc registers handler under the name protocol. The handler
// is selected for any incoming token for which match returns true.
//
// An existing registration with the same name is removed first. The new
// entry is then appended at the end of the search order.
func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) bool, handler HandlerFunc[T]) {
	entry := Handler[T]{
		AddName:   protocol,
		MatchFunc: match,
		Handle:    handler,
	}

	msm.handlerlock.Lock()
	defer msm.handlerlock.Unlock()

	msm.removeHandler(protocol)
	msm.handlers = append(msm.handlers, entry)
}
// RemoveHandler unregisters the handler added under the name protocol.
// It is a no-op if no such handler exists.
func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) {
	lock := &msm.handlerlock
	lock.Lock()
	defer lock.Unlock()

	msm.removeHandler(protocol)
}
// removeHandler deletes the first handler registered under the name protocol
// and preserves the order of the remaining handlers. It does nothing if no
// handler has that name.
//
// The caller must hold handlerlock for writing.
func (msm *MultistreamMuxer[T]) removeHandler(protocol T) {
	for i, h := range msm.handlers {
		if h.AddName != protocol {
			continue
		}
		last := len(msm.handlers) - 1
		copy(msm.handlers[i:], msm.handlers[i+1:])
		// Zero the vacated tail slot. Otherwise the backing array keeps a
		// stale copy of a handler (and its closures) alive after truncation.
		msm.handlers[last] = Handler[T]{}
		msm.handlers = msm.handlers[:last]
		return
	}
}
// Protocols returns the names of all registered handlers, in search order.
// The result is a fresh slice owned by the caller. It is nil when no handlers
// are registered.
func (msm *MultistreamMuxer[T]) Protocols() []T {
	msm.handlerlock.RLock()
	defer msm.handlerlock.RUnlock()

	if len(msm.handlers) == 0 {
		return nil
	}
	// The number of names is known up front, so allocate once.
	out := make([]T, 0, len(msm.handlers))
	for _, h := range msm.handlers {
		out = append(out, h.AddName)
	}
	return out
}
// ErrIncorrectVersion is returned by Negotiate when the peer's first token is
// not ProtocolID. Negotiate closes the stream before returning it.
var ErrIncorrectVersion = errors.New("client connected with incorrect version")
// findHandler returns a copy of the first registered handler whose MatchFunc
// accepts proto, or nil if none does. The read lock is held only for the
// scan. The returned copy does not alias the muxer's slice.
func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] {
	msm.handlerlock.RLock()
	defer msm.handlerlock.RUnlock()

	for idx := range msm.handlers {
		if msm.handlers[idx].MatchFunc(proto) {
			match := msm.handlers[idx]
			return &match
		}
	}
	return nil
}
// Negotiate performs the listener side of multistream-select on rwc. It
// returns the agreed protocol together with its handler but does not invoke
// the handler.
//
// The exchange is:
//  1. send ProtocolID;
//  2. read the peer's header token, which must equal ProtocolID; otherwise
//     rwc is closed and ErrIncorrectVersion is returned;
//  3. read protocol proposals, answering "na" to each one that no handler
//     matches, until one matches; that token is echoed back to accept it.
//
// Read errors, and errors writing "na", are returned as-is. A panic raised
// during negotiation (for example inside a handler's MatchFunc) is recovered,
// logged with its stack trace to stderr, and returned as an error.
func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, handler HandlerFunc[T], err error) {
	defer func() {
		if caught := recover(); caught != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", caught, debug.Stack())
			err = fmt.Errorf("panic in multistream negotiation: %s", caught)
		}
	}()

	// The write error is deliberately ignored so that negotiation can still
	// read a request the peer sent before closing its write direction. A
	// broken connection will typically also fail the read below.
	_ = delimWriteBuffered(rwc, []byte(ProtocolID))

	header, err := ReadNextToken[T](rwc)
	if err != nil {
		return "", nil, err
	}
	if header != ProtocolID {
		rwc.Close()
		return "", nil, ErrIncorrectVersion
	}

	for {
		tok, err := ReadNextToken[T](rwc)
		if err != nil {
			return "", nil, err
		}

		if h := msm.findHandler(tok); h != nil {
			// Echo the token to accept it. As with the header, a failed write
			// is ignored so the selected handler still gets the stream.
			_ = delimWriteBuffered(rwc, []byte(tok))
			return tok, h.Handle, nil
		}

		// No handler matched: reject this proposal and wait for the next one.
		if err := delimWriteBuffered(rwc, []byte("na")); err != nil {
			return "", nil, err
		}
	}
}
// Handle negotiates a protocol on rwc and then runs the selected handler with
// the negotiated protocol and the same stream, returning the handler's error.
// If negotiation fails, that error is returned and no handler runs.
func (msm *MultistreamMuxer[T]) Handle(rwc io.ReadWriteCloser) error {
	selected, serve, negErr := msm.Negotiate(rwc)
	if negErr != nil {
		return negErr
	}
	return serve(selected, rwc)
}
// ReadNextToken reads the next delimited token from r and converts it to T.
// On error it returns the zero value of T (the empty string) and the error.
func ReadNextToken[T StringLike](r io.Reader) (T, error) {
	raw, err := ReadNextTokenBytes(r)
	if err != nil {
		var zero T
		return zero, err
	}
	return T(raw), nil
}
// ReadNextTokenBytes reads the next length-prefixed, newline-terminated token
// from r and returns it without the trailing newline.
//
// It returns ErrTooLarge when the announced length exceeds lpReadBuf's
// 1024-byte limit. Other read and decoding errors are propagated unchanged.
// On error the returned slice is nil.
func ReadNextTokenBytes(r io.Reader) ([]byte, error) {
	// lpReadBuf already returns (nil, err) on every failure path, so the
	// former per-error switch was a no-op and is not needed.
	return lpReadBuf(r)
}
// lpReadBuf reads one length-prefixed message from r. The message is a
// uvarint length followed by that many bytes, the last of which must be '\n'.
// lpReadBuf returns the payload without that newline.
//
// Error cases:
//   - a length above 1024 is rejected with ErrTooLarge before any payload is
//     read;
//   - a stream that ends before the payload is complete yields
//     io.ErrUnexpectedEOF;
//   - a payload that is empty or lacks the trailing newline is an error.
func lpReadBuf(r io.Reader) ([]byte, error) {
	// Decode the prefix through r itself when it can already read bytes.
	// Otherwise use unbuffered one-byte reads, so that no payload bytes are
	// consumed ahead of the io.ReadFull below.
	prefixReader, isByteReader := r.(io.ByteReader)
	if !isByteReader {
		prefixReader = &byteReader{r}
	}

	size, err := varint.ReadUvarint(prefixReader)
	if err != nil {
		return nil, err
	}
	if size > 1024 {
		return nil, ErrTooLarge
	}

	msg := make([]byte, size)
	if _, err := io.ReadFull(r, msg); err != nil {
		// io.ReadFull reports io.EOF only when nothing was read. The length
		// prefix promised data, so that is still a truncated message.
		if err == io.EOF {
			return nil, io.ErrUnexpectedEOF
		}
		return nil, err
	}

	last := len(msg) - 1
	if last < 0 || msg[last] != '\n' {
		return nil, errors.New("message did not have trailing newline")
	}
	return msg[:last], nil
}
// byteReader adapts an io.Reader into an io.ByteReader by issuing one-byte
// Read calls. It does no buffering, so every byte it has not returned is still
// available from the underlying reader.
type byteReader struct {
	io.Reader
}

// ReadByte reads a single byte from the wrapped reader.
//
// A one-byte read returns that byte with a nil error, even if the reader also
// reported an error; the next call will observe that error. A (0, nil) read
// is reported as io.ErrNoProgress rather than retried. A reader that claims
// to have read more than one byte into the one-byte buffer, with no error,
// violates the io.Reader contract and causes a panic.
func (br *byteReader) ReadByte() (byte, error) {
	var single [1]byte
	switch n, err := br.Read(single[:]); {
	case n == 1:
		return single[0], nil
	case err != nil:
		return 0, err
	case n != 0:
		panic("read more bytes than buffer size")
	default:
		return 0, io.ErrNoProgress
	}
}
// getWriter takes a bufio.Writer from writerPool and points it at w.
// Reset also clears any buffered data and error left from previous use.
func getWriter(w io.Writer) *bufio.Writer {
	pooled := writerPool.Get().(*bufio.Writer)
	pooled.Reset(w)
	return pooled
}

// putWriter detaches bw from its destination and returns it to writerPool.
// Reset discards any unflushed data. Detaching keeps the pool from holding a
// reference to the previous destination.
func putWriter(bw *bufio.Writer) {
	bw.Reset(nil)
	writerPool.Put(bw)
}
// The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.