package ssh
import (
"errors"
"fmt"
"io"
"log"
"net"
"slices"
"strings"
"sync"
)
const (
	// debugHandshake turns on logging of every packet sent and received and
	// of key-exchange entry/exit (see printPacket and readOnePacket).
	debugHandshake = false

	// chanSize is the buffer size of the incoming packet channel that
	// readLoop fills and readPacket drains.
	chanSize = 16

	// maxPendingPackets caps how many outgoing packets writePacket queues
	// while a key exchange is in progress; further writers block until the
	// exchange completes.
	maxPendingPackets = 64
)
// keyingTransport is the packet connection underneath handshakeTransport,
// extended with the hooks the handshake needs to change keys.
type keyingTransport interface {
	packetConn

	// prepareKeyChange is handed the negotiated algorithms and the key
	// exchange result just before this side sends msgNewKeys. Presumably it
	// arranges for the new keys to take effect at the msgNewKeys boundary;
	// that is not visible from the handshake code.
	prepareKeyChange(algs *NegotiatedAlgorithms, result *kexResult) error

	// setStrictMode is called once, during the first key exchange, when both
	// sides opted into strict key exchange.
	setStrictMode() error

	// setInitialKEXDone is called after the first key exchange completes.
	setInitialKEXDone()
}
// handshakeTransport wraps a keyingTransport and runs SSH key exchanges on
// it: the initial one and any later rekeys. Two goroutines drive it:
// readLoop, which reads packets and hands peer KEXINITs to kexLoop, and
// kexLoop, which performs the exchanges.
type handshakeTransport struct {
// conn is the underlying packet transport whose keys are changed.
conn keyingTransport
// config holds algorithm preferences, the random source and rekey limits.
config *Config
// serverVersion and clientVersion are the version strings fed into the
// exchange hash (handshakeMagics).
serverVersion []byte
clientVersion []byte
// hostKeys is non-empty only on the server; its emptiness decides the role.
hostKeys []Signer
// publicKeyAuthAlgorithms (server only) is advertised to clients in the
// server-sig-algs extension.
publicKeyAuthAlgorithms []string
// hostKeyAlgorithms (client only) is advertised as ServerHostKeyAlgos.
hostKeyAlgorithms []string
// incoming carries packets from readLoop to readPacket; it is closed after
// readError is set.
incoming chan []byte
// readError is the error that stopped readLoop; read after incoming closes.
readError error
// mu guards the write-side state below: writeError, sentInitPacket,
// sentInitMsg, pendingPackets, the write budgets and userAuthComplete.
mu sync .Mutex
// writeCond (on mu) is broadcast when writeError is set or a key exchange
// finishes, waking writers blocked in writePacket.
writeCond *sync .Cond
// writeError is the first write-side failure; once set, writes fail.
writeError error
// sentInitPacket/sentInitMsg are our outstanding KEXINIT (marshaled and
// parsed). They are non-nil from sendKexInit until kexLoop finishes the
// exchange.
sentInitPacket []byte
sentInitMsg *kexInitMsg
// pendingPackets are copies of writes queued while a key exchange is in
// progress, at most maxPendingPackets of them.
pendingPackets [][]byte
// writePacketsLeft/writeBytesLeft are the outgoing rekey budgets.
writePacketsLeft uint32
writeBytesLeft int64
// userAuthComplete is set once msgUserAuthSuccess is written; after that
// auth banners are refused.
userAuthComplete bool
// requestKex (capacity 1) asks kexLoop to start a key exchange; sends are
// non-blocking.
requestKex chan struct {}
// startKex delivers a peer KEXINIT from readLoop to kexLoop; readLoop
// closes it when it exits.
startKex chan *pendingKex
// kexLoopDone is closed when kexLoop exits; Close waits on it.
kexLoopDone chan struct {}
// hostKeyCallback, dialAddress and remoteAddr are client-only; used to
// verify the server's host key.
hostKeyCallback HostKeyCallback
dialAddress string
remoteAddr net .Addr
// bannerCallback is copied from ClientConfig; it is not referenced in this
// code — presumably consumed elsewhere in the package (TODO confirm).
bannerCallback BannerCallback
// algorithms is the result of the most recent algorithm negotiation.
algorithms *NegotiatedAlgorithms
// readPacketsLeft/readBytesLeft are the incoming rekey budgets, touched
// only by the readLoop goroutine.
readPacketsLeft uint32
readBytesLeft int64
// sessionID is the exchange hash H of the first key exchange; nil until
// that exchange completes.
sessionID []byte
// strictMode records that strict key exchange was negotiated on the first
// exchange.
strictMode bool
}
// pendingKex hands a peer KEXINIT from readLoop to kexLoop and carries the
// outcome of the resulting key exchange back.
type pendingKex struct {
// otherInit is the peer's raw KEXINIT packet.
otherInit []byte
// done receives exactly one result (nil on success). readOnePacket creates
// it with capacity 1, so kexLoop's send does not block.
done chan error
}
// newHandshakeTransport builds the state shared by client and server
// handshake transports. It queues one request on requestKex so that kexLoop
// starts the initial key exchange as soon as it runs. It does not start any
// goroutines; the role-specific constructors do that.
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
	t := &handshakeTransport{
		conn:          conn,
		config:        config,
		clientVersion: clientVersion,
		serverVersion: serverVersion,
		incoming:      make(chan []byte, chanSize),
		requestKex:    make(chan struct{}, 1),
		startKex:      make(chan *pendingKex),
		kexLoopDone:   make(chan struct{}),
	}
	t.writeCond = sync.NewCond(&t.mu)

	t.resetReadThresholds()
	t.resetWriteThresholds()

	// requestKex has capacity 1 and is empty here, so this never blocks.
	t.requestKex <- struct{}{}

	return t
}
// newClientTransport creates the client side of a handshake transport and
// starts its read and key-exchange goroutines.
//
// dialAddr and addr are passed to the host key callback. When
// config.HostKeyAlgorithms is nil, defaultHostKeyAlgos is advertised instead.
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)

	t.dialAddress, t.remoteAddr = dialAddr, addr
	t.hostKeyCallback, t.bannerCallback = config.HostKeyCallback, config.BannerCallback

	t.hostKeyAlgorithms = defaultHostKeyAlgos
	if config.HostKeyAlgorithms != nil {
		t.hostKeyAlgorithms = config.HostKeyAlgorithms
	}

	go t.readLoop()
	go t.kexLoop()
	return t
}
// newServerTransport creates the server side of a handshake transport and
// starts its read and key-exchange goroutines. The configured host keys are
// what put the transport in the server role: id() reports "server" and
// enterKeyExchange runs the server half of the exchange.
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
	t.hostKeys, t.publicKeyAuthAlgorithms = config.hostKeys, config.PublicKeyAuthAlgorithms

	go t.readLoop()
	go t.kexLoop()
	return t
}
// getSessionID returns the exchange hash of the first key exchange, or nil
// if that exchange has not completed. It takes no lock.
// NOTE(review): presumably only called after waitSession; confirm callers.
func (t *handshakeTransport) getSessionID() []byte {
	id := t.sessionID
	return id
}
// getAlgorithms returns a copy of the most recently negotiated algorithms.
// It dereferences t.algorithms unconditionally, so it panics if no
// negotiation has produced a result yet.
func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms {
	negotiated := t.algorithms
	return *negotiated
}
// waitSession blocks until the first packet is available and checks that it
// is msgNewKeys. readOnePacket reports a successful first key exchange as a
// synthetic msgNewKeys packet, so anything else means the session was not
// established.
func (t *handshakeTransport) waitSession() error {
	p, err := t.readPacket()
	switch {
	case err != nil:
		return err
	case p[0] != msgNewKeys:
		return errors.New("ssh: first packet should be msgNewKeys")
	}
	return nil
}
// id names this side of the connection for debug logs: "server" when host
// keys are configured, "client" otherwise.
func (t *handshakeTransport) id() string {
	if len(t.hostKeys) == 0 {
		return "client"
	}
	return "server"
}
// printPacket logs one packet for debugging. write selects "sent" vs "got".
// Channel data packets are logged by size only; everything else is decoded
// and printed along with any decode error.
func (t *handshakeTransport) printPacket(p []byte, write bool) {
	direction := "got"
	if write {
		direction = "sent"
	}

	switch p[0] {
	case msgChannelData, msgChannelExtendedData:
		log.Printf("%s %s data (packet %d bytes)", t.id(), direction, len(p))
	default:
		decoded, decodeErr := decode(p)
		log.Printf("%s %s %T %v (%v)", t.id(), direction, decoded, decoded, decodeErr)
	}
}
// readPacket returns the next packet delivered by readLoop. Once readLoop has
// stopped and closed t.incoming, it returns the error that stopped it.
func (t *handshakeTransport) readPacket() ([]byte, error) {
	if packet, open := <-t.incoming; open {
		return packet, nil
	}
	return nil, t.readError
}
// readLoop runs in its own goroutine. It reads packets with readOnePacket
// and delivers them on t.incoming. msgIgnore and msgDebug are dropped,
// except during the initial key exchange when strict mode is active: there
// every message is passed up.
//
// On the first read error it records the error in t.readError and closes
// t.incoming. It then propagates the error to writers and closes
// t.startKex so that kexLoop stops.
func (t *handshakeTransport) readLoop() {
	for first := true; ; first = false {
		packet, err := t.readOnePacket(first)
		if err != nil {
			t.readError = err
			close(t.incoming)
			break
		}

		strictInitialKex := t.sessionID == nil && t.strictMode
		isNoise := packet[0] == msgIgnore || packet[0] == msgDebug
		if isNoise && !strictInitialKex {
			continue
		}
		t.incoming <- packet
	}

	// Stop writers, then release kexLoop.
	t.recordWriteError(t.readError)
	close(t.startKex)
}
// pushPacket writes p straight to the underlying transport, bypassing the
// key-exchange bookkeeping in writePacket. It logs p first when debugging is on.
func (t *handshakeTransport) pushPacket(p []byte) error {
	if !debugHandshake {
		return t.conn.writePacket(p)
	}
	t.printPacket(p, true)
	return t.conn.writePacket(p)
}
// getWriteError returns the recorded write error (nil if none), read under t.mu.
func (t *handshakeTransport) getWriteError() error {
	t.mu.Lock()
	err := t.writeError
	t.mu.Unlock()
	return err
}
// recordWriteError stores err as the write error unless err is nil or an
// error is already recorded (the first error wins). When it stores one, it
// wakes every writer waiting on writeCond.
func (t *handshakeTransport) recordWriteError(err error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	if err == nil || t.writeError != nil {
		return
	}
	t.writeError = err
	t.writeCond.Broadcast()
}
// requestKeyExchange asks kexLoop to start a key exchange without blocking.
// requestKex holds at most one request, so repeated calls coalesce.
func (t *handshakeTransport) requestKeyExchange() {
	select {
	case t.requestKex <- struct{}{}:
		// Queued for kexLoop.
	default:
		// A request is already pending; one is enough.
	}
}
// resetWriteThresholds refills the outgoing rekey budgets. The packet budget
// is always packetRekeyThreshold. The byte budget comes from
// config.RekeyThreshold if set, else from the negotiated write algorithms,
// else 1 GiB (1<<30) before any negotiation.
func (t *handshakeTransport) resetWriteThresholds() {
	t.writePacketsLeft = packetRekeyThreshold

	switch {
	case t.config.RekeyThreshold > 0:
		t.writeBytesLeft = int64(t.config.RekeyThreshold)
	case t.algorithms != nil:
		t.writeBytesLeft = t.algorithms.Write.rekeyBytes()
	default:
		t.writeBytesLeft = 1 << 30
	}
}
// kexLoop runs in its own goroutine and performs every key exchange.
//
// Each round waits until two things have happened:
//   - this side's KEXINIT has been sent, triggered by the first wakeup
//     (either a requestKex signal or a peer KEXINIT);
//   - the peer's KEXINIT has arrived from readLoop on startKex.
//
// It then runs enterKeyExchange. Holding t.mu, it:
//   - records the result as t.writeError and clears the outstanding-KEXINIT
//     state;
//   - resets the write budgets and drops stale rekey requests;
//   - hands the result back to readLoop;
//   - flushes the packets writePacket queued during the exchange and wakes
//     blocked writers.
//
// The loop ends on the first write error or when readLoop closes startKex.
// It then closes the connection (presumably unblocking a reader stuck in
// conn.readPacket) and answers any further startKex requests with the
// current write error until readLoop closes the channel. Finally it closes
// kexLoopDone so Close can return.
func (t *handshakeTransport) kexLoop() {
write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// readLoop has exited.
					break write
				}
			case <-t.requestKex:
				// Local request: send our KEXINIT and wait for the peer's.
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					break // leaves the inner for; handled just below
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			if request != nil {
				request.done <- err
			}
			break
		}

		// startKex is not serviced during the exchange. That is fine: the
		// peer cannot send another KEXINIT until this one completes, and
		// readLoop is blocked on request.done.
		err := t.enterKeyExchange(request.otherInit)

		t.mu.Lock()
		t.writeError = err
		t.sentInitPacket = nil
		t.sentInitMsg = nil
		t.resetWriteThresholds()

		// Requests queued during the exchange are satisfied by it. No new
		// ones can arrive while we drain: readLoop is blocked on
		// request.done and writePacket needs t.mu, which we hold.
	drainRequests:
		for {
			select {
			case <-t.requestKex:
			default:
				break drainRequests
			}
		}

		// done has capacity 1, so this does not block while holding t.mu.
		request.done <- t.writeError

		// Flush writes queued during the exchange. They are not charged to
		// the new write budgets, to avoid triggering a rekey from inside
		// this one.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		// Drop references to the flushed buffers so they can be garbage
		// collected. Truncating alone would keep them reachable through the
		// backing array. Capacity is kept for the next exchange.
		clear(t.pendingPackets)
		t.pendingPackets = t.pendingPackets[:0]

		// Wake writers blocked in writePacket waiting for the exchange.
		t.writeCond.Broadcast()
		t.mu.Unlock()
	}

	// Unblock the reader.
	t.conn.Close()

	// Answer remaining KEXINITs until readLoop closes startKex. requestKex is
	// not drained: nobody blocks sending to it.
	for request := range t.startKex {
		request.done <- t.getWriteError()
	}

	close(t.kexLoopDone)
}
// packetRekeyThreshold is the number of packets, in each direction, after
// which a new key exchange is requested. It fits in the uint32 counters
// that track it.
const packetRekeyThreshold = 1 << 31
// resetReadThresholds refills the incoming rekey budgets. The packet budget
// is always packetRekeyThreshold. The byte budget comes from
// config.RekeyThreshold if set, else from the negotiated read algorithms,
// else 1 GiB (1<<30) before any negotiation.
func (t *handshakeTransport) resetReadThresholds() {
	t.readPacketsLeft = packetRekeyThreshold

	switch {
	case t.config.RekeyThreshold > 0:
		t.readBytesLeft = int64(t.config.RekeyThreshold)
	case t.algorithms != nil:
		t.readBytesLeft = t.algorithms.Read.rekeyBytes()
	default:
		t.readBytesLeft = 1 << 30
	}
}
// readOnePacket reads one packet from the transport and charges it to the
// read budgets, requesting a rekey once either budget is exhausted.
//
// When first is true the packet must be a KEXINIT. A non-KEXINIT packet is
// returned unchanged. A KEXINIT is handed to kexLoop, and the call blocks
// until that key exchange finishes. On success the read budgets are reset
// and a synthetic packet is returned in place of the exchange: msgNewKeys
// after the first exchange, msgIgnore after later ones.
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
	packet, err := t.conn.readPacket()
	if err != nil {
		return nil, err
	}

	if t.readPacketsLeft == 0 {
		t.requestKeyExchange()
	} else {
		t.readPacketsLeft--
	}
	if t.readBytesLeft <= 0 {
		t.requestKeyExchange()
	} else {
		t.readBytesLeft -= int64(len(packet))
	}

	if debugHandshake {
		t.printPacket(packet, false)
	}

	if packet[0] != msgKexInit {
		if first {
			return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
		}
		return packet, nil
	}

	firstKex := t.sessionID == nil

	req := &pendingKex{
		otherInit: packet,
		done:      make(chan error, 1),
	}
	t.startKex <- req
	kexErr := <-req.done

	if debugHandshake {
		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, kexErr)
	}
	if kexErr != nil {
		return nil, kexErr
	}

	t.resetReadThresholds()

	// Hide the exchange from higher layers, except that the first one is
	// surfaced as msgNewKeys so waitSession can see the session come up.
	if firstKex {
		return []byte{msgNewKeys}, nil
	}
	return []byte{msgIgnore}, nil
}
// kexStrictClient is the pseudo key-exchange algorithm a client appends to
// its first KEXINIT to opt into strict key exchange.
const kexStrictClient = "kex-strict-c-v00@openssh.com"

// kexStrictServer is the server's counterpart of kexStrictClient.
const kexStrictServer = "kex-strict-s-v00@openssh.com"
// sendKexInit sends this side's SSH_MSG_KEXINIT, unless one is already
// outstanding (t.sentInitMsg != nil), in which case it does nothing.
//
// The message advertises:
//   - the configured ciphers and MACs (both directions) and the supported
//     compressions;
//   - a fresh random cookie;
//   - the configured key exchanges, plus the strict-kex marker on the first
//     exchange (and, for clients, "ext-info-c").
//
// Servers advertise host key algorithms derived from their host keys;
// clients advertise t.hostKeyAlgorithms.
//
// On success the message and its marshaled bytes are stored in
// t.sentInitMsg and t.sentInitPacket for enterKeyExchange. An error from the
// random source or from the transport is returned and nothing is recorded.
// t.mu is held for the whole call.
func (t *handshakeTransport) sendKexInit() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.sentInitMsg != nil {
		// A KEXINIT goes out either in response to the peer's or because
		// this side asked for a rekey; never send two for one exchange.
		return nil
	}

	msg := &kexInitMsg{
		CiphersClientServer:     t.config.Ciphers,
		CiphersServerClient:     t.config.Ciphers,
		MACsClientServer:        t.config.MACs,
		MACsServerClient:        t.config.MACs,
		CompressionClientServer: supportedCompressions,
		CompressionServerClient: supportedCompressions,
	}
	// RFC 4253 section 7.1 requires the cookie to be random. Ignoring a
	// failing or short Rand would send a partly or fully zero cookie, so
	// fail the handshake instead.
	if _, err := io.ReadFull(t.config.Rand, msg.Cookie[:]); err != nil {
		return fmt.Errorf("ssh: generating kexinit cookie: %w", err)
	}

	// Build a private KexAlgos slice: markers are appended below, and
	// appending to the caller-owned config.KeyExchanges could write into
	// its backing array. Two extra slots cover the markers.
	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2)
	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)

	isServer := len(t.hostKeys) > 0
	if isServer {
		for _, k := range t.hostKeys {
			keyFormat := k.PublicKey().Type()

			switch s := k.(type) {
			case MultiAlgorithmSigner:
				// Only the algorithms this signer explicitly supports.
				for _, algo := range algorithmsForKeyFormat(keyFormat) {
					if slices.Contains(s.Algorithms(), underlyingAlgo(algo)) {
						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
					}
				}
			case AlgorithmSigner:
				// Assumed to support every algorithm for its key format.
				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
			default:
				// A plain Signer can only sign with its key format's
				// default algorithm.
				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
			}
		}

		if t.sessionID == nil {
			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
		}
	} else {
		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms

		// On the first exchange, opt into SSH_MSG_EXT_INFO (RFC 8308) and
		// into strict key exchange.
		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
		}
	}

	packet := Marshal(msg)

	// Push a copy and keep the original. The untouched bytes are an input
	// to the exchange hash (handshakeMagics in enterKeyExchange). Presumably
	// the transport may alter the buffer it writes.
	packetCopy := make([]byte, len(packet))
	copy(packetCopy, packet)

	if err := t.pushPacket(packetCopy); err != nil {
		return err
	}

	t.sentInitMsg = msg
	t.sentInitPacket = packet

	return nil
}
var (
	// errSendBannerPhase is returned by writePacket for a
	// msgUserAuthBanner written after msgUserAuthSuccess.
	errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
)
// writePacket sends p to the peer, coordinating with any key exchange in
// progress.
//
// Callers may not send msgKexInit or msgNewKeys; those belong to the
// handshake. A msgUserAuthBanner is refused with errSendBannerPhase once a
// msgUserAuthSuccess has been written.
//
// While our KEXINIT is outstanding, up to maxPendingPackets packets are
// copied (the caller keeps ownership of p) and queued for kexLoop to flush
// after the exchange. Once the queue is full, the caller blocks until the
// exchange finishes or a write error is recorded.
//
// Each direct write is charged to the write budgets, and an exhausted budget
// requests a new key exchange. The first transport write failure is recorded
// in t.writeError and wakes blocked writers. It is returned both to the
// caller that hit it and to every later caller.
func (t *handshakeTransport) writePacket(p []byte) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	switch p[0] {
	case msgKexInit:
		return errors.New("ssh: only handshakeTransport can send kexInit")
	case msgNewKeys:
		return errors.New("ssh: only handshakeTransport can send newKeys")
	case msgUserAuthBanner:
		if t.userAuthComplete {
			return errSendBannerPhase
		}
	case msgUserAuthSuccess:
		t.userAuthComplete = true
	}

	if t.writeError != nil {
		return t.writeError
	}

	if t.sentInitMsg != nil {
		if len(t.pendingPackets) < maxPendingPackets {
			cp := make([]byte, len(p))
			copy(cp, p)
			t.pendingPackets = append(t.pendingPackets, cp)
			return nil
		}
		// Queue full: wait for the exchange to finish or fail.
		for t.sentInitMsg != nil {
			t.writeCond.Wait()
			if t.writeError != nil {
				return t.writeError
			}
		}
	}

	if t.writeBytesLeft > 0 {
		t.writeBytesLeft -= int64(len(p))
	} else {
		t.requestKeyExchange()
	}

	if t.writePacketsLeft > 0 {
		t.writePacketsLeft--
	} else {
		t.requestKeyExchange()
	}

	if err := t.pushPacket(p); err != nil {
		t.writeError = err
		t.writeCond.Broadcast()
		// Report the failure to the caller that hit it, instead of
		// claiming success and only failing the next write.
		return err
	}

	return nil
}
// Close closes the underlying connection, waits for kexLoop to finish
// shutting down, and returns the connection's Close error.
func (t *handshakeTransport) Close() error {
	closeErr := t.conn.Close()
	<-t.kexLoopDone
	return closeErr
}
// enterKeyExchange performs one key exchange, given the peer's raw KEXINIT
// and our outstanding KEXINIT (t.sentInitMsg / t.sentInitPacket). In order,
// it:
//   - negotiates algorithms (stored in t.algorithms);
//   - on the first exchange, enables strict mode if the peer opted in;
//   - discards a wrongly guessed first kex packet from the peer;
//   - runs the server or client half of the exchange;
//   - records the session ID on the first exchange and installs the new keys
//     via prepareKeyChange;
//   - sends msgNewKeys (plus, as a server on the first exchange,
//     SSH_MSG_EXT_INFO);
//   - waits for the peer's msgNewKeys.
// It is called from kexLoop; any error ends the exchange and is returned.
func (t *handshakeTransport ) enterKeyExchange (otherInitPacket []byte ) error {
if debugHandshake {
log .Printf ("%s entered key exchange" , t .id ())
}
otherInit := &kexInitMsg {}
if err := Unmarshal (otherInitPacket , otherInit ); err != nil {
return err
}
// Start from the server's point of view (peer = client); swapped below
// when we are the client. Both raw KEXINITs feed the exchange hash.
magics := handshakeMagics {
clientVersion : t .clientVersion ,
serverVersion : t .serverVersion ,
clientKexInit : otherInitPacket ,
serverKexInit : t .sentInitPacket ,
}
clientInit := otherInit
serverInit := t .sentInitMsg
isClient := len (t .hostKeys ) == 0
if isClient {
clientInit , serverInit = serverInit , clientInit
magics .clientKexInit = t .sentInitPacket
magics .serverKexInit = otherInitPacket
}
var err error
t .algorithms , err = findAgreedAlgorithms (isClient , clientInit , serverInit )
if err != nil {
return err
}
// Strict kex is decided on the first exchange only. sendKexInit always
// adds our own marker then, so only the peer's list needs checking.
if t .sessionID == nil && ((isClient && slices .Contains (serverInit .KexAlgos , kexStrictServer )) || (!isClient && slices .Contains (clientInit .KexAlgos , kexStrictClient ))) {
t .strictMode = true
if err := t .conn .setStrictMode (); err != nil {
return err
}
}
// RFC 4253 section 7: if the peer set first_kex_packet_follows but guessed
// the kex or host key algorithm wrong (the first entries differ), its
// guessed packet must be read and ignored. Indexing [0] relies on
// findAgreedAlgorithms having succeeded above.
if otherInit .FirstKexFollows && (clientInit .KexAlgos [0 ] != serverInit .KexAlgos [0 ] || clientInit .ServerHostKeyAlgos [0 ] != serverInit .ServerHostKeyAlgos [0 ]) {
if _ , err := t .conn .readPacket (); err != nil {
return err
}
}
kex , ok := kexAlgoMap [t .algorithms .KeyExchange ]
if !ok {
return fmt .Errorf ("ssh: unexpected key exchange algorithm %v" , t .algorithms .KeyExchange )
}
var result *kexResult
if len (t .hostKeys ) > 0 {
result , err = t .server (kex , &magics )
} else {
result , err = t .client (kex , &magics )
}
if err != nil {
return err
}
// The session ID is the exchange hash of the first exchange and is reused
// by every later one.
firstKeyExchange := t .sessionID == nil
if firstKeyExchange {
t .sessionID = result .H
}
result .SessionID = t .sessionID
// Install the new keys, then announce the switch to the peer.
if err := t .conn .prepareKeyChange (t .algorithms , result ); err != nil {
return err
}
if err = t .conn .writePacket ([]byte {msgNewKeys }); err != nil {
return err
}
// Server, first exchange, client advertised "ext-info-c": send
// SSH_MSG_EXT_INFO (RFC 8308) with two extensions. "server-sig-algs" lists
// t.publicKeyAuthAlgorithms; "ping@openssh.com" has value "0". Each name
// and value is length-prefixed via appendInt; the capacity adds up those
// lengths (15 = len("server-sig-algs"), 16 = len("ping@openssh.com")).
if !isClient && firstKeyExchange && slices .Contains (clientInit .KexAlgos , "ext-info-c" ) {
supportedPubKeyAuthAlgosList := strings .Join (t .publicKeyAuthAlgorithms , "," )
extInfo := &extInfoMsg {
NumExtensions : 2 ,
Payload : make ([]byte , 0 , 4 +15 +4 +len (supportedPubKeyAuthAlgosList )+4 +16 +4 +1 ),
}
extInfo .Payload = appendInt (extInfo .Payload , len ("server-sig-algs" ))
extInfo .Payload = append (extInfo .Payload , "server-sig-algs" ...)
extInfo .Payload = appendInt (extInfo .Payload , len (supportedPubKeyAuthAlgosList ))
extInfo .Payload = append (extInfo .Payload , supportedPubKeyAuthAlgosList ...)
extInfo .Payload = appendInt (extInfo .Payload , len ("ping@openssh.com" ))
extInfo .Payload = append (extInfo .Payload , "ping@openssh.com" ...)
extInfo .Payload = appendInt (extInfo .Payload , 1 )
extInfo .Payload = append (extInfo .Payload , "0" ...)
if err := t .conn .writePacket (Marshal (extInfo )); err != nil {
return err
}
}
// The peer must answer with its own msgNewKeys; anything else is a
// protocol error.
if packet , err := t .conn .readPacket (); err != nil {
return err
} else if packet [0 ] != msgNewKeys {
return unexpectedMessageError (msgNewKeys , packet [0 ])
}
if firstKeyExchange {
t .conn .setInitialKEXDone ()
}
return nil
}
// algorithmSignerWrapper lets a plain Signer stand in where an
// AlgorithmSigner is required (see pickHostKey). It accepts only the default
// algorithm of the key's format and delegates to Sign.
type algorithmSignerWrapper struct {
	Signer
}

// SignWithAlgorithm signs data with the wrapped Signer. It returns an error
// if algorithm is not underlyingAlgo of the key's type.
func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	if defaultAlgo := underlyingAlgo(a.PublicKey().Type()); algorithm != defaultAlgo {
		return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
	}
	return a.Sign(rand, data)
}
// pickHostKey returns the first host key that can sign with algo, or nil.
//
// For each key, in order:
//   - a MultiAlgorithmSigner that does not list underlyingAlgo(algo) is
//     skipped;
//   - a key whose type equals algo is returned wrapped in
//     algorithmSignerWrapper;
//   - otherwise an AlgorithmSigner is returned if algo is one of the
//     algorithms for its key format.
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
	for _, signer := range hostKeys {
		if multi, isMulti := signer.(MultiAlgorithmSigner); isMulti && !slices.Contains(multi.Algorithms(), underlyingAlgo(algo)) {
			continue
		}

		if signer.PublicKey().Type() == algo {
			return algorithmSignerWrapper{signer}
		}

		algSigner, isAlgSigner := signer.(AlgorithmSigner)
		if !isAlgSigner {
			continue
		}
		if slices.Contains(algorithmsForKeyFormat(algSigner.PublicKey().Type()), algo) {
			return algSigner
		}
	}
	return nil
}
// server runs the server half of key exchange kex. It signs with the host
// key that pickHostKey selects for the negotiated host key algorithm. If no
// configured key matches, it reports an internal error, since negotiation
// only offered algorithms our keys support.
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
	hostKeyAlgo := t.algorithms.HostKey
	signer := pickHostKey(t.hostKeys, hostKeyAlgo)
	if signer == nil {
		return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
	}
	return kex.Server(t.conn, t.config.Rand, magics, signer, hostKeyAlgo)
}
// client runs the client half of key exchange kex. It then checks the
// server's host key in three steps: the key must parse, its signature over
// the exchange must verify with the negotiated algorithm, and the
// hostKeyCallback must accept it. The first failure is returned.
func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
	res, err := kex.Client(t.conn, t.config.Rand, magics)
	if err != nil {
		return nil, err
	}

	serverKey, err := ParsePublicKey(res.HostKey)
	if err == nil {
		err = verifyHostKeySignature(serverKey, t.algorithms.HostKey, res)
	}
	if err == nil {
		err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, serverKey)
	}
	if err != nil {
		return nil, err
	}
	return res, nil
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.