package ssh
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"sync"
"time"
gossh "golang.org/x/crypto/ssh"
)
// ErrServerClosed is returned by Serve once the server has been stopped with
// Close or Shutdown.
var ErrServerClosed = errors.New("ssh: Server closed")

type (
	// SubsystemHandler is the handler type stored in Server.SubsystemHandlers
	// (presumably invoked for a session's subsystem request; that dispatch is
	// not in this file).
	SubsystemHandler func(s Session)

	// RequestHandler answers a request received on a connection's global
	// request channel. The returned ok and payload are passed verbatim to
	// req.Reply.
	RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)

	// ChannelHandler handles one newly opened channel. HandleConn starts each
	// invocation in its own goroutine.
	ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
)

var (
	// DefaultSubsystemHandlers is copied into Server.SubsystemHandlers when
	// that map is nil. It is empty by default.
	DefaultSubsystemHandlers = map[string]SubsystemHandler{}

	// DefaultRequestHandlers is copied into Server.RequestHandlers when that
	// map is nil. It is empty by default.
	DefaultRequestHandlers = map[string]RequestHandler{}

	// DefaultChannelHandlers is copied into Server.ChannelHandlers when that
	// map is nil. By default only "session" channels have a handler.
	DefaultChannelHandlers = map[string]ChannelHandler{
		"session": DefaultSessionHandler,
	}
)

// permissionsPublicKeyExt is the gossh.Permissions extension key under which
// the public-key auth callback stores the base64-encoded, verified client key,
// so HandleConn can recover it once the handshake has completed.
var permissionsPublicKeyExt = "gliderlabs/ssh.PublicKey"
// ensureNoPKInPermissions returns an error when the reserved
// permissionsPublicKeyExt extension is already present in ctx's permissions.
// Within this file only the PublicKeyCallback installed by config writes that
// entry (after verifying the key), so finding it anywhere else indicates a
// misconfiguration that HandleConn would otherwise trust as an authenticated key.
func ensureNoPKInPermissions(ctx Context) error {
	extensions := ctx.Permissions().Permissions.Extensions
	if _, present := extensions[permissionsPublicKeyExt]; !present {
		return nil
	}
	return errors.New("misconfigured server: public key incorrectly set")
}
// Server holds the configuration and runtime state of an SSH server.
//
// A zero Server can be served: Serve fills in default handler maps
// (ensureHandlers), generates a host key when none is configured
// (ensureHostSigner) and falls back to DefaultHandler when Handler is nil.
type Server struct {
	// Addr is the TCP address ListenAndServe listens on; ":22" when empty.
	Addr string

	// Handler is the server's handler; Serve sets it to DefaultHandler when
	// nil. It is not invoked directly in this file (presumably used by the
	// session channel handler).
	Handler Handler

	// HostSigners are the host keys offered to clients. When empty, one is
	// generated with generateSigner. AddHostKey replaces keys by key type.
	HostSigners []Signer

	// Version, when non-empty, is advertised as "SSH-2.0-" + Version.
	Version string

	// Banner is a static banner sent to clients; ignored when BannerHandler
	// is set.
	Banner string

	// BannerHandler computes the banner per connection and takes precedence
	// over Banner.
	BannerHandler BannerHandler

	// The three authentication handlers below each enable the matching auth
	// method. When all three are nil, client authentication is disabled
	// (gossh NoClientAuth).
	KeyboardInteractiveHandler KeyboardInteractiveHandler
	PasswordHandler            PasswordHandler
	PublicKeyHandler           PublicKeyHandler

	// PtyCallback is not referenced in this file; presumably consulted when a
	// session requests a PTY.
	PtyCallback PtyCallback

	// PtyHandler defaults to emulatePtyHandler when nil (set while building
	// each connection's config).
	PtyHandler PtyHandler

	// ConnCallback, if set, is called with every new net.Conn before the
	// handshake. It may return a wrapped conn, or nil to reject the conn
	// (which is then closed).
	ConnCallback ConnCallback

	// LocalPortForwardingCallback and ReversePortForwardingCallback are not
	// referenced in this file; presumably consulted by the forwarding
	// channel/request handlers.
	LocalPortForwardingCallback   LocalPortForwardingCallback
	ReversePortForwardingCallback ReversePortForwardingCallback

	// ServerConfigCallback, if set, supplies the base gossh.ServerConfig for
	// each connection; otherwise an empty config is used.
	ServerConfigCallback ServerConfigCallback

	// SessionRequestCallback is not referenced in this file; presumably
	// consulted by the session handler.
	SessionRequestCallback SessionRequestCallback

	// ConnectionFailedCallback, if set, is called when the SSH handshake
	// fails or the authenticated public key cannot be decoded.
	ConnectionFailedCallback ConnectionFailedCallback

	// IdleTimeout is handed to each connection wrapper as its idle timeout.
	// NOTE(review): presumably 0 disables it — confirm in serverConn.
	IdleTimeout time.Duration

	// MaxTimeout, when > 0, gives every connection an absolute deadline of
	// accept time + MaxTimeout.
	MaxTimeout time.Duration

	// ChannelHandlers maps channel types to handlers, with the "default" key
	// as fallback; copied from DefaultChannelHandlers when nil.
	ChannelHandlers map[string]ChannelHandler

	// RequestHandlers maps global request types to handlers, with the
	// "default" key as fallback; copied from DefaultRequestHandlers when nil.
	RequestHandlers map[string]RequestHandler

	// SubsystemHandlers is copied from DefaultSubsystemHandlers when nil; it
	// is not dispatched in this file.
	SubsystemHandlers map[string]SubsystemHandler

	// listenerWg counts active Serve loops (see trackListener); Shutdown
	// waits on it.
	listenerWg sync.WaitGroup

	// mu guards the tracked listeners/conns, doneChan, and the lazy setup of
	// handlers and host keys.
	mu sync.RWMutex

	// listeners is the set of listeners currently being served.
	listeners map[net.Listener]struct{}

	// conns is the set of established SSH connections.
	conns map[*gossh.ServerConn]struct{}

	// connWg counts tracked connections (see trackConn); Shutdown waits on it.
	connWg sync.WaitGroup

	// doneChan is created lazily and closed by Close/Shutdown. It is reset to
	// nil when the first listener is added to an idle server.
	doneChan chan struct{}
}
// ensureHostSigner guarantees the server has at least one host key. When
// HostSigners is empty it generates one with generateSigner and returns any
// generation error.
func (srv *Server) ensureHostSigner() error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if len(srv.HostSigners) > 0 {
		return nil
	}
	generated, err := generateSigner()
	if err != nil {
		return err
	}
	srv.HostSigners = append(srv.HostSigners, generated)
	return nil
}
// ensureHandlers fills every nil handler map on srv with a private copy of the
// matching package-level default map. Copying means later edits to srv's maps
// do not modify the shared defaults.
func (srv *Server) ensureHandlers() {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.RequestHandlers == nil {
		requests := make(map[string]RequestHandler, len(DefaultRequestHandlers))
		for name, fn := range DefaultRequestHandlers {
			requests[name] = fn
		}
		srv.RequestHandlers = requests
	}

	if srv.ChannelHandlers == nil {
		channels := make(map[string]ChannelHandler, len(DefaultChannelHandlers))
		for name, fn := range DefaultChannelHandlers {
			channels[name] = fn
		}
		srv.ChannelHandlers = channels
	}

	if srv.SubsystemHandlers == nil {
		subsystems := make(map[string]SubsystemHandler, len(DefaultSubsystemHandlers))
		for name, fn := range DefaultSubsystemHandlers {
			subsystems[name] = fn
		}
		srv.SubsystemHandlers = subsystems
	}
}
// config builds the gossh.ServerConfig used for one incoming connection.
//
// The base config comes from ServerConfigCallback when set. If there is no
// callback, or it returns nil, an empty config is used. The config is then
// populated with:
//   - srv's host keys,
//   - the version string,
//   - the banner,
//   - an auth callback for each configured auth handler.
//
// When no auth handler is configured, client authentication is disabled. As a
// side effect, a nil PtyHandler is replaced with emulatePtyHandler.
//
// Every auth callback first calls resetPermissions and records the connection
// metadata in ctx before consulting its handler.
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	var config *gossh.ServerConfig
	if srv.ServerConfigCallback != nil {
		config = srv.ServerConfigCallback(ctx)
	}
	if config == nil {
		// No callback, or the callback returned nil: start from an empty
		// config rather than panicking in AddHostKey below.
		config = &gossh.ServerConfig{}
	}
	for _, signer := range srv.HostSigners {
		config.AddHostKey(signer)
	}
	if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil {
		config.NoClientAuth = true
	}
	if srv.PtyHandler == nil {
		srv.PtyHandler = emulatePtyHandler
	}
	if srv.Version != "" {
		config.ServerVersion = "SSH-2.0-" + srv.Version
	}
	if srv.Banner != "" {
		config.BannerCallback = func(_ gossh.ConnMetadata) string {
			return srv.Banner
		}
	}
	// A BannerHandler takes precedence over a static Banner.
	if srv.BannerHandler != nil {
		config.BannerCallback = func(conn gossh.ConnMetadata) string {
			applyConnMetadata(ctx, conn)
			return srv.BannerHandler(ctx)
		}
	}

	// verdict turns an auth handler's decision into the result gossh expects.
	// It runs AFTER the handler. Only the public-key callback may store a key
	// under permissionsPublicKeyExt, because HandleConn trusts that entry as
	// the authenticated key. Any other handler that leaves it behind therefore
	// fails the attempt.
	verdict := func(ok bool) (*gossh.Permissions, error) {
		if err := ensureNoPKInPermissions(ctx); err != nil {
			return ctx.Permissions().Permissions, err
		}
		if !ok {
			return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
		}
		return ctx.Permissions().Permissions, nil
	}

	if srv.PasswordHandler != nil {
		config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
			resetPermissions(ctx)
			applyConnMetadata(ctx, conn)
			// Also checked before the handler in case resetPermissions does
			// not clear extensions. NOTE(review): confirm resetPermissions
			// semantics.
			if err := ensureNoPKInPermissions(ctx); err != nil {
				return ctx.Permissions().Permissions, err
			}
			return verdict(srv.PasswordHandler(ctx, string(password)))
		}
	}
	if srv.PublicKeyHandler != nil {
		config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
			resetPermissions(ctx)
			applyConnMetadata(ctx, conn)
			if err := ensureNoPKInPermissions(ctx); err != nil {
				return ctx.Permissions().Permissions, err
			}
			if ok := srv.PublicKeyHandler(ctx, key); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			// Record the accepted key so HandleConn can restore it after the
			// handshake. This overwrites anything the handler stored under
			// the same extension.
			perms := ctx.Permissions().Permissions
			if perms.Extensions == nil {
				perms.Extensions = map[string]string{}
			}
			perms.Extensions[permissionsPublicKeyExt] = base64.StdEncoding.EncodeToString(key.Marshal())
			return perms, nil
		}
	}
	if srv.KeyboardInteractiveHandler != nil {
		config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
			resetPermissions(ctx)
			applyConnMetadata(ctx, conn)
			return verdict(srv.KeyboardInteractiveHandler(ctx, challenger))
		}
	}
	return config
}
// Handle sets srv.Handler to fn, holding srv.mu while doing so.
func (srv *Server) Handle(fn Handler) {
	srv.mu.Lock()
	srv.Handler = fn
	srv.mu.Unlock()
}
// Close stops the server immediately. It:
//   - marks the server as done, so Serve returns ErrServerClosed once its
//     Accept fails;
//   - closes every tracked listener;
//   - closes every tracked SSH connection.
//
// It returns the first listener close error; connection close errors are
// ignored.
func (srv *Server) Close() error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	srv.closeDoneChanLocked()
	listenerErr := srv.closeListenersLocked()
	for sc := range srv.conns {
		// Best effort: the connection is being torn down regardless.
		_ = sc.Close()
		delete(srv.conns, sc)
	}
	return listenerErr
}
// Shutdown stops the server gracefully. It closes all listeners and marks the
// server done, then waits until every Serve loop has returned and every
// tracked connection has been released. Unlike Close, active connections are
// not closed.
//
// If ctx finishes first, ctx.Err() is returned and the waiting goroutine keeps
// running in the background. Otherwise the first listener close error (or nil)
// is returned.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.mu.Lock()
	closeErr := srv.closeListenersLocked()
	srv.closeDoneChanLocked()
	srv.mu.Unlock()

	// Buffered so the waiter can always deliver its signal and exit, even when
	// nobody receives because ctx ended first.
	drained := make(chan struct{}, 1)
	go func() {
		srv.listenerWg.Wait()
		srv.connWg.Wait()
		drained <- struct{}{}
	}()

	select {
	case <-drained:
		return closeErr
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Serve accepts connections on l and handles each one in its own goroutine
// via HandleConn. Before accepting it:
//   - fills in default handler maps,
//   - ensures a host key exists,
//   - falls back to DefaultHandler when no Handler is set.
//
// l is closed when Serve returns.
//
// Temporary Accept errors are retried with exponential backoff, starting at
// 5ms and capped at 1s. The backoff resets after every successful Accept.
// Once Close or Shutdown has been called, Serve returns ErrServerClosed. Any
// other Accept error is returned unchanged.
func (srv *Server) Serve(l net.Listener) error {
	srv.ensureHandlers()
	defer l.Close()
	if err := srv.ensureHostSigner(); err != nil {
		return err
	}
	if srv.Handler == nil {
		srv.Handler = DefaultHandler
	}

	var tempDelay time.Duration
	srv.trackListener(l, true)
	defer srv.trackListener(l, false)
	for {
		conn, e := l.Accept()
		if e != nil {
			select {
			case <-srv.getDoneChan():
				return ErrServerClosed
			default:
			}
			// Same retry policy as net/http.Server.Serve.
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if limit := 1 * time.Second; tempDelay > limit {
					tempDelay = limit
				}
				time.Sleep(tempDelay)
				continue
			}
			return e
		}
		// Reset the backoff. Otherwise a past error burst would leave every
		// later temporary error sleeping up to 1s.
		tempDelay = 0
		go srv.HandleConn(conn)
	}
}
// HandleConn serves one already-accepted network connection until its channel
// stream ends. It:
//  1. runs the optional ConnCallback, which may wrap or reject the conn;
//  2. performs the SSH handshake;
//  3. restores the authenticated public key (if any) into the context;
//  4. registers the connection;
//  5. dispatches global requests and new channels to the registered handlers.
//
// The connection is closed when HandleConn returns. Handshake and key-decoding
// failures are reported to ConnectionFailedCallback when it is set.
func (srv *Server) HandleConn(newConn net.Conn) {
	ctx, cancel := newContext(srv)
	if srv.ConnCallback != nil {
		cbConn := srv.ConnCallback(ctx, newConn)
		if cbConn == nil {
			newConn.Close()
			// Rejected before serverConn takes ownership of cancel (as its
			// closeCanceler), so release the context here; otherwise it is
			// never cancelled.
			cancel()
			return
		}
		newConn = cbConn
	}
	conn := &serverConn{
		Conn:          newConn,
		idleTimeout:   srv.IdleTimeout,
		closeCanceler: cancel,
	}
	if srv.MaxTimeout > 0 {
		conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
	}
	defer conn.Close()

	// fail reports a connection-level failure to the optional callback.
	fail := func(err error) {
		if srv.ConnectionFailedCallback != nil {
			srv.ConnectionFailedCallback(conn, err)
		}
	}

	sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
	if err != nil {
		fail(err)
		return
	}

	// The public-key auth callback stashed the verified key, base64-encoded,
	// in the permissions extensions; decode it back into the context.
	if sshConn.Permissions != nil {
		if keyData, ok := sshConn.Permissions.Extensions[permissionsPublicKeyExt]; ok {
			decoded, err := base64.StdEncoding.DecodeString(keyData)
			if err != nil {
				fail(err)
				return
			}
			key, err := gossh.ParsePublicKey(decoded)
			if err != nil {
				fail(err)
				return
			}
			ctx.SetValue(ContextKeyPublicKey, key)
		}
	}

	ctx.Permissions().Permissions = sshConn.Permissions
	srv.trackConn(sshConn, true)
	defer srv.trackConn(sshConn, false)
	ctx.SetValue(ContextKeyConn, sshConn)
	applyConnMetadata(ctx, sshConn)

	go srv.handleRequests(ctx, reqs)
	for ch := range chans {
		handler := srv.ChannelHandlers[ch.ChannelType()]
		if handler == nil {
			handler = srv.ChannelHandlers["default"]
		}
		if handler == nil {
			ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
			continue
		}
		go handler(srv, sshConn, ch, ctx)
	}
}
// handleRequests replies to every request arriving on in (the global request
// channel from gossh.NewServerConn) until in is closed. The handler is looked
// up by request type, falling back to the "default" entry. Requests with no
// handler are refused with Reply(false, nil).
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
	for req := range in {
		handler := srv.RequestHandlers[req.Type]
		if handler == nil {
			handler = srv.RequestHandlers["default"]
		}

		accepted, payload := false, []byte(nil)
		if handler != nil {
			accepted, payload = handler(ctx, srv, req)
		}
		req.Reply(accepted, payload)
	}
}
// ListenAndServe listens on the TCP address srv.Addr (":22" when empty) and
// then serves connections with Serve. It returns the listen error, if any,
// otherwise whatever Serve returns.
func (srv *Server) ListenAndServe() error {
	address := ":22"
	if srv.Addr != "" {
		address = srv.Addr
	}
	listener, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}
	return srv.Serve(listener)
}
// AddHostKey registers key as a host key. If an existing host signer has the
// same public-key type, it is replaced in place; otherwise key is appended.
func (srv *Server) AddHostKey(key Signer) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	for idx := range srv.HostSigners {
		if srv.HostSigners[idx].PublicKey().Type() != key.PublicKey().Type() {
			continue
		}
		srv.HostSigners[idx] = key
		return
	}
	srv.HostSigners = append(srv.HostSigners, key)
}
// SetOption applies a single functional Option to srv and returns the error
// the option reports.
func (srv *Server) SetOption(option Option) error {
	err := option(srv)
	return err
}
// getDoneChan returns, as receive-only, the channel that Close and Shutdown
// close. It is created on first use. Acquires srv.mu.
func (srv *Server) getDoneChan() <-chan struct{} {
	srv.mu.Lock()
	done := srv.getDoneChanLocked()
	srv.mu.Unlock()
	return done
}
// getDoneChanLocked returns srv.doneChan, allocating it first if it is nil.
// Callers must hold srv.mu.
func (srv *Server) getDoneChanLocked() chan struct{} {
	if srv.doneChan != nil {
		return srv.doneChan
	}
	srv.doneChan = make(chan struct{})
	return srv.doneChan
}
// closeDoneChanLocked closes the done channel unless it is already closed, so
// repeated calls are safe. Callers must hold srv.mu.
func (srv *Server) closeDoneChanLocked() {
	done := srv.getDoneChanLocked()
	select {
	case <-done:
		// Nothing ever sends on done, so a successful receive means it is
		// already closed; closing again would panic.
		return
	default:
	}
	close(done)
}
// closeListenersLocked closes and forgets every tracked listener. It returns
// the first Close error; later errors are dropped. Callers must hold srv.mu.
func (srv *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range srv.listeners {
		closeErr := ln.Close()
		if firstErr == nil {
			firstErr = closeErr
		}
		delete(srv.listeners, ln)
	}
	return firstErr
}
// trackListener registers (add == true) or unregisters (add == false) ln as
// an active listener and adjusts listenerWg to match. When a listener is added
// to an idle server (no listeners and no connections), the done channel is
// dropped. The next getDoneChan then yields a fresh open channel, so a
// previously closed server can serve again.
func (srv *Server) trackListener(ln net.Listener, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listeners == nil {
		srv.listeners = make(map[net.Listener]struct{})
	}
	if !add {
		delete(srv.listeners, ln)
		srv.listenerWg.Done()
		return
	}
	if len(srv.listeners) == 0 && len(srv.conns) == 0 {
		srv.doneChan = nil
	}
	srv.listeners[ln] = struct{}{}
	srv.listenerWg.Add(1)
}
// trackConn registers (add == true) or unregisters (add == false) an
// established SSH connection and adjusts connWg, which Shutdown waits on.
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.conns == nil {
		srv.conns = make(map[*gossh.ServerConn]struct{})
	}
	if !add {
		delete(srv.conns, c)
		srv.connWg.Done()
		return
	}
	srv.conns[c] = struct{}{}
	srv.connWg.Add(1)
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.