package ssh
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
gossh "golang.org/x/crypto/ssh"
)
var (
	// ErrServerClosed is returned by Serve (and therefore ListenAndServe) when
	// Accept fails after the server has been stopped with Close or Shutdown.
	ErrServerClosed = errors.New("ssh: Server closed")
)
// SubsystemHandler is invoked with the Session on which a subsystem was
// requested.
type SubsystemHandler func(s Session)

// DefaultSubsystemHandlers is copied into Server.SubsystemHandlers by Serve
// when the server's map is nil. It starts out empty.
var DefaultSubsystemHandlers = make(map[string]SubsystemHandler)
// RequestHandler handles a connection-level request. Its results are passed
// straight to req.Reply: ok reports success and payload is the optional
// reply body.
type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)

// DefaultRequestHandlers is copied into Server.RequestHandlers by Serve when
// the server's map is nil. It starts out empty.
var DefaultRequestHandlers = make(map[string]RequestHandler)
// ChannelHandler handles a newly opened channel on an established connection.
// HandleConn runs each one in its own goroutine.
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)

// DefaultChannelHandlers is copied into Server.ChannelHandlers by Serve when
// the server's map is nil. Out of the box only "session" channels are served,
// by DefaultSessionHandler.
var DefaultChannelHandlers = map[string]ChannelHandler{
	"session": DefaultSessionHandler,
}
// Server holds the configuration and runtime bookkeeping of an SSH server.
// Fields left unset are filled in by Serve where defaults exist: a generated
// host key, copies of the Default*Handlers maps, and DefaultHandler.
type Server struct {
	// Addr is the TCP address used by ListenAndServe; ":22" when empty.
	Addr string
	// Handler is set by Handle; Serve installs DefaultHandler when it is nil.
	// NOTE(review): presumably invoked per session by the session channel
	// handler, which is not part of this file — confirm.
	Handler Handler
	// HostSigners are added as host keys to every connection's config. Serve
	// generates one when the slice is empty; AddHostKey replaces by key type.
	HostSigners []Signer
	// Version, when non-empty, is advertised as "SSH-2.0-" + Version.
	Version string
	// Banner is a static banner sent to clients when non-empty.
	Banner string
	// BannerHandler produces a per-connection banner and takes precedence
	// over Banner when both are set.
	BannerHandler BannerHandler
	// KeyboardInteractiveHandler, PasswordHandler and PublicKeyHandler decide
	// client authentication. When all three are nil, NoClientAuth is enabled.
	KeyboardInteractiveHandler KeyboardInteractiveHandler
	PasswordHandler PasswordHandler
	PublicKeyHandler PublicKeyHandler
	// PtyCallback is not consulted in this file; presumably used by the
	// session handling code to allow/deny PTY requests — TODO confirm.
	PtyCallback PtyCallback
	// ConnCallback may wrap the raw net.Conn before the handshake; returning
	// nil rejects (and closes) the connection.
	ConnCallback ConnCallback
	// LocalPortForwardingCallback and ReversePortForwardingCallback are not
	// consulted in this file; presumably used by port-forwarding handlers
	// elsewhere in the package — TODO confirm.
	LocalPortForwardingCallback LocalPortForwardingCallback
	ReversePortForwardingCallback ReversePortForwardingCallback
	// ServerConfigCallback, when set, supplies the base gossh.ServerConfig for
	// each connection before host keys and auth callbacks are added.
	ServerConfigCallback ServerConfigCallback
	// SessionRequestCallback is not consulted in this file; presumably used by
	// the session handler — TODO confirm.
	SessionRequestCallback SessionRequestCallback
	// ConnectionFailedCallback is called with the connection and error when
	// the SSH handshake fails.
	ConnectionFailedCallback ConnectionFailedCallback
	// IdleTimeout is passed to the per-connection wrapper as its idle timeout
	// (semantics implemented by serverConn, outside this file).
	IdleTimeout time.Duration
	// MaxTimeout, when > 0, gives each connection an absolute deadline
	// MaxTimeout after HandleConn starts.
	MaxTimeout time.Duration
	// ChannelHandlers maps channel type to handler; the "default" entry is the
	// fallback, and channels with no handler are rejected as UnknownChannelType.
	ChannelHandlers map[string]ChannelHandler
	// RequestHandlers maps connection-level request type to handler; the
	// "default" entry is the fallback, and unhandled requests get Reply(false).
	RequestHandlers map[string]RequestHandler
	// SubsystemHandlers is filled from DefaultSubsystemHandlers when nil.
	SubsystemHandlers map[string]SubsystemHandler
	// listenerWg counts running Serve loops (see trackListener).
	listenerWg sync.WaitGroup
	// mu guards listeners, conns, doneChan, HostSigners, Handler and the
	// lazy initialization of the handler maps.
	mu sync.RWMutex
	// listeners is the set of listeners currently being served.
	listeners map[net.Listener]struct{}
	// conns is the set of connections that completed the SSH handshake.
	conns map[*gossh.ServerConn]struct{}
	// connWg counts tracked connections; Shutdown waits on it.
	connWg sync.WaitGroup
	// doneChan is closed by Close/Shutdown; lazily created by getDoneChanLocked.
	doneChan chan struct{}
}
// ensureHostSigner guarantees the server has at least one host key. When
// HostSigners is empty it generates one with generateSigner and returns any
// error from that generation; otherwise it does nothing.
func (srv *Server) ensureHostSigner() error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if len(srv.HostSigners) > 0 {
		return nil
	}
	generated, err := generateSigner()
	if err != nil {
		return err
	}
	srv.HostSigners = append(srv.HostSigners, generated)
	return nil
}
// ensureHandlers replaces each nil handler map on srv with a fresh copy of the
// matching package-level default map. A map the caller already assigned, even
// an empty one, is left untouched. Copies are used so later edits to a server's
// map do not leak into the package defaults.
func (srv *Server) ensureHandlers() {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.RequestHandlers == nil {
		reqs := make(map[string]RequestHandler, len(DefaultRequestHandlers))
		for name, h := range DefaultRequestHandlers {
			reqs[name] = h
		}
		srv.RequestHandlers = reqs
	}

	if srv.ChannelHandlers == nil {
		chans := make(map[string]ChannelHandler, len(DefaultChannelHandlers))
		for name, h := range DefaultChannelHandlers {
			chans[name] = h
		}
		srv.ChannelHandlers = chans
	}

	if srv.SubsystemHandlers == nil {
		subs := make(map[string]SubsystemHandler, len(DefaultSubsystemHandlers))
		for name, h := range DefaultSubsystemHandlers {
			subs[name] = h
		}
		srv.SubsystemHandlers = subs
	}
}
// config builds the gossh.ServerConfig used to handshake one connection.
//
// ctx is the per-connection context. The banner and authentication callbacks
// installed here record connection metadata into it (applyConnMetadata) before
// consulting the user-supplied handlers. On successful public-key auth the key
// is stored under ContextKeyPublicKey.
//
// The base config comes from ServerConfigCallback when it is set. A nil result
// from that callback falls back to an empty config instead of panicking in
// AddHostKey. All Server fields are read while srv.mu is held, and the values
// the callbacks need are captured in locals. The callbacks therefore never read
// Server fields later without synchronization.
//
// Client authentication is disabled (NoClientAuth) when none of
// PasswordHandler, PublicKeyHandler or KeyboardInteractiveHandler is set.
// BannerHandler takes precedence over the static Banner.
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
	srv.mu.RLock()
	defer srv.mu.RUnlock()

	var config *gossh.ServerConfig
	if srv.ServerConfigCallback != nil {
		config = srv.ServerConfigCallback(ctx)
	}
	if config == nil {
		config = &gossh.ServerConfig{}
	}
	for _, signer := range srv.HostSigners {
		config.AddHostKey(signer)
	}

	// Snapshot everything the callbacks use while the lock is held.
	var (
		banner              = srv.Banner
		bannerHandler       = srv.BannerHandler
		passwordHandler     = srv.PasswordHandler
		publicKeyHandler    = srv.PublicKeyHandler
		keyboardInteractive = srv.KeyboardInteractiveHandler
	)

	if passwordHandler == nil && publicKeyHandler == nil && keyboardInteractive == nil {
		config.NoClientAuth = true
	}
	if srv.Version != "" {
		config.ServerVersion = "SSH-2.0-" + srv.Version
	}

	switch {
	case bannerHandler != nil:
		config.BannerCallback = func(conn gossh.ConnMetadata) string {
			applyConnMetadata(ctx, conn)
			return bannerHandler(ctx)
		}
	case banner != "":
		config.BannerCallback = func(_ gossh.ConnMetadata) string {
			return banner
		}
	}

	if passwordHandler != nil {
		config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := passwordHandler(ctx, string(password)); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			return ctx.Permissions().Permissions, nil
		}
	}
	if publicKeyHandler != nil {
		config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := publicKeyHandler(ctx, key); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			ctx.SetValue(ContextKeyPublicKey, key)
			return ctx.Permissions().Permissions, nil
		}
	}
	if keyboardInteractive != nil {
		config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := keyboardInteractive(ctx, challenger); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			return ctx.Permissions().Permissions, nil
		}
	}
	return config
}
// Handle installs fn as the server's Handler, taking the server lock so the
// assignment is synchronized with other locked accesses.
func (srv *Server) Handle(fn Handler) {
	srv.mu.Lock()
	srv.Handler = fn
	srv.mu.Unlock()
}
// Close stops the server immediately. It does three things:
//   - signals the done channel, so Serve returns ErrServerClosed;
//   - closes every tracked listener;
//   - closes every tracked connection.
//
// It returns the first listener Close error, if any. Errors from closing
// connections are ignored.
func (srv *Server) Close() error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	srv.closeDoneChanLocked()
	lnErr := srv.closeListenersLocked()
	for sc := range srv.conns {
		// Best effort: a failure to close one connection must not stop the rest.
		_ = sc.Close()
		delete(srv.conns, sc)
	}
	return lnErr
}
// Shutdown stops accepting new connections, then waits for every Serve loop
// and every tracked connection to finish. It does not close active
// connections itself.
//
// When the wait completes, it returns the first listener Close error (or nil).
// If ctx ends first, it returns ctx.Err() instead.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.mu.Lock()
	lnErr := srv.closeListenersLocked()
	srv.closeDoneChanLocked()
	srv.mu.Unlock()

	// Buffered so the waiter can always deliver its signal and exit, even
	// after Shutdown has returned early because ctx was done.
	finished := make(chan struct{}, 1)
	go func() {
		srv.listenerWg.Wait()
		srv.connWg.Wait()
		finished <- struct{}{}
	}()

	select {
	case <-finished:
		return lnErr
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Serve accepts connections on l and handles each one in a new goroutine via
// HandleConn. l is closed when Serve returns.
//
// Before accepting, Serve does three things:
//   - fills in the default handler maps;
//   - generates a host key if none is configured;
//   - installs DefaultHandler when Handler is nil, under srv.mu so this does
//     not race with Handle.
//
// Serve always returns a non-nil error:
//   - a host-key generation error;
//   - ErrServerClosed once Close or Shutdown has been called;
//   - otherwise the first non-temporary Accept error.
//
// Temporary Accept errors are retried with exponential backoff from 5ms up to
// 1s. The backoff resets after every successful Accept, and a pending backoff
// is cut short by Close/Shutdown.
func (srv *Server) Serve(l net.Listener) error {
	srv.ensureHandlers()
	defer l.Close()
	if err := srv.ensureHostSigner(); err != nil {
		return err
	}

	srv.mu.Lock()
	if srv.Handler == nil {
		srv.Handler = DefaultHandler
	}
	srv.mu.Unlock()

	var tempDelay time.Duration
	srv.trackListener(l, true)
	defer srv.trackListener(l, false)
	for {
		conn, err := l.Accept()
		if err != nil {
			select {
			case <-srv.getDoneChan():
				return ErrServerClosed
			default:
			}
			// Temporary is deprecated in package net, but it is still the
			// signal used to tell transient accept failures (e.g. EMFILE)
			// from fatal ones.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				select {
				case <-srv.getDoneChan():
					return ErrServerClosed
				case <-time.After(tempDelay):
				}
				continue
			}
			return err
		}
		// A successful Accept ends the error burst; start the next one at 5ms.
		tempDelay = 0
		go srv.HandleConn(conn)
	}
}
// HandleConn performs the SSH handshake on newConn and serves it until the
// client's channel stream ends. The steps are:
//   - ConnCallback, if set, may wrap the connection or reject it by returning
//     nil, in which case the connection is closed and its context cancelled;
//   - the connection is wrapped with the configured idle timeout and, when
//     MaxTimeout > 0, an absolute deadline;
//   - a failed handshake is reported to ConnectionFailedCallback;
//   - connection-level requests are served on a separate goroutine;
//   - each new channel is dispatched to ChannelHandlers by type, falling back
//     to the "default" entry and otherwise rejecting it as UnknownChannelType.
func (srv *Server) HandleConn(newConn net.Conn) {
	ctx, cancel := newContext(srv)
	if srv.ConnCallback != nil {
		cbConn := srv.ConnCallback(ctx, newConn)
		if cbConn == nil {
			// Rejected before serverConn took ownership of cancel: release the
			// per-connection context here so it (and anything waiting on it)
			// is not leaked.
			cancel()
			newConn.Close()
			return
		}
		newConn = cbConn
	}
	conn := &serverConn{
		Conn:          newConn,
		idleTimeout:   srv.IdleTimeout,
		closeCanceler: cancel,
	}
	if srv.MaxTimeout > 0 {
		conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
	}
	defer conn.Close()
	sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
	if err != nil {
		if srv.ConnectionFailedCallback != nil {
			srv.ConnectionFailedCallback(conn, err)
		}
		return
	}

	srv.trackConn(sshConn, true)
	defer srv.trackConn(sshConn, false)

	ctx.SetValue(ContextKeyConn, sshConn)
	applyConnMetadata(ctx, sshConn)
	go srv.handleRequests(ctx, reqs)
	for ch := range chans {
		handler := srv.ChannelHandlers[ch.ChannelType()]
		if handler == nil {
			handler = srv.ChannelHandlers["default"]
		}
		if handler == nil {
			ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
			continue
		}
		go handler(srv, sshConn, ch, ctx)
	}
}
// handleRequests answers connection-level requests arriving on in until the
// channel is closed. Each request goes to RequestHandlers[req.Type], or to the
// "default" entry when that is nil. Without any handler the request is
// answered with Reply(false, nil). Reply errors are ignored.
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
	for req := range in {
		var (
			ok      bool
			payload []byte
		)
		handler := srv.RequestHandlers[req.Type]
		if handler == nil {
			handler = srv.RequestHandlers["default"]
		}
		if handler != nil {
			ok, payload = handler(ctx, srv, req)
		}
		req.Reply(ok, payload)
	}
}
// ListenAndServe opens a TCP listener on srv.Addr (":22" when Addr is empty)
// and hands it to Serve. It returns the listen error, if any, or whatever
// Serve returns.
func (srv *Server) ListenAndServe() error {
	addr := ":22"
	if srv.Addr != "" {
		addr = srv.Addr
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	return srv.Serve(ln)
}
// AddHostKey adds key to the server's host keys. If a signer whose public key
// has the same algorithm type is already present, it is replaced in place.
// Otherwise key is appended.
func (srv *Server) AddHostKey(key Signer) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	for idx := range srv.HostSigners {
		existing := srv.HostSigners[idx]
		if existing.PublicKey().Type() != key.PublicKey().Type() {
			continue
		}
		srv.HostSigners[idx] = key
		return
	}
	srv.HostSigners = append(srv.HostSigners, key)
}
// SetOption applies a functional option to the server and reports the error
// the option returned, if any.
func (srv *Server) SetOption(option Option) error {
	if err := option(srv); err != nil {
		return err
	}
	return nil
}
// getDoneChan returns the channel that Close/Shutdown close, creating it on
// first use. It takes the write lock because the channel may be allocated here.
func (srv *Server) getDoneChan() <-chan struct{} {
	srv.mu.Lock()
	done := srv.getDoneChanLocked()
	srv.mu.Unlock()
	return done
}
// getDoneChanLocked returns srv.doneChan, allocating it if it is still nil.
// The caller must hold srv.mu.
func (srv *Server) getDoneChanLocked() chan struct{} {
	if srv.doneChan != nil {
		return srv.doneChan
	}
	srv.doneChan = make(chan struct{})
	return srv.doneChan
}
// closeDoneChanLocked closes the done channel unless it is already closed, so
// calling Close/Shutdown more than once does not panic. The caller must hold
// srv.mu.
func (srv *Server) closeDoneChanLocked() {
	done := srv.getDoneChanLocked()
	select {
	case <-done:
		// Already closed; a second close would panic.
		return
	default:
	}
	close(done)
}
// closeListenersLocked closes every tracked listener and removes it from the
// tracking set. It returns the first Close error encountered; map iteration
// order is unspecified and later errors are dropped. The caller must hold
// srv.mu.
func (srv *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range srv.listeners {
		closeErr := ln.Close()
		if firstErr == nil && closeErr != nil {
			firstErr = closeErr
		}
		delete(srv.listeners, ln)
	}
	return firstErr
}
// trackListener registers (add=true) or unregisters (add=false) a listener
// being served, keeping listenerWg in step so Shutdown can wait for Serve
// loops. Every add must be paired with exactly one remove.
func (srv *Server) trackListener(ln net.Listener, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listeners == nil {
		srv.listeners = make(map[net.Listener]struct{})
	}
	if !add {
		delete(srv.listeners, ln)
		srv.listenerWg.Done()
		return
	}
	// With nothing tracked, the server is (re)starting: drop any done channel
	// closed by an earlier Close/Shutdown so this Serve loop is not treated
	// as already closed.
	if len(srv.listeners) == 0 && len(srv.conns) == 0 {
		srv.doneChan = nil
	}
	srv.listeners[ln] = struct{}{}
	srv.listenerWg.Add(1)
}
// trackConn registers (add=true) or unregisters (add=false) an established
// SSH connection, so Close can close it and Shutdown can wait on it through
// connWg. Every add must be paired with exactly one remove.
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.conns == nil {
		srv.conns = make(map[*gossh.ServerConn]struct{})
	}
	if !add {
		delete(srv.conns, c)
		srv.connWg.Done()
		return
	}
	srv.conns[c] = struct{}{}
	srv.connWg.Add(1)
}
// These pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
// Golds is a Go 101 project developed by Tapir Liu.
// PRs and bug reports are welcome and can be submitted to the issue list.
// Please follow @zigo_101 (reachable from the left QR code) to get the latest news of Golds.