// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package dtls

import (
	
	
	
	
	
	
	

	
	
	
	
	
)

// [RFC6347 Section-4.2.4]
//                      +-----------+
//                +---> | PREPARING | <--------------------+
//                |     +-----------+                      |
//                |           |                            |
//                |           | Buffer next flight         |
//                |           |                            |
//                |          \|/                           |
//                |     +-----------+                      |
//                |     |  SENDING  |<------------------+  | Send
//                |     +-----------+                   |  | HelloRequest
//        Receive |           |                         |  |
//           next |           | Send flight             |  | or
//         flight |  +--------+                         |  |
//                |  |        | Set retransmit timer    |  | Receive
//                |  |       \|/                        |  | HelloRequest
//                |  |  +-----------+                   |  | Send
//                +--)--|  WAITING  |-------------------+  | ClientHello
//                |  |  +-----------+   Timer expires   |  |
//                |  |         |                        |  |
//                |  |         +------------------------+  |
//        Receive |  | Send           Read retransmit      |
//           last |  | last                                |
//         flight |  | flight                              |
//                |  |                                     |
//               \|/\|/                                    |
//            +-----------+                                |
//            | FINISHED  | -------------------------------+
//            +-----------+
//                 |  /|\
//                 |   |
//                 +---+
//              Read retransmit
//           Retransmit last flight

// handshakeState is a node of the handshake state machine. handshakePreparing
// through handshakeFinished mirror the PREPARING/SENDING/WAITING/FINISHED boxes
// of the RFC 6347 diagram above; handshakeErrored (the zero value) is what the
// step functions return alongside a non-nil error.
type handshakeState uint8

const (
	handshakeErrored handshakeState = iota
	handshakePreparing
	handshakeSending
	handshakeWaiting
	handshakeFinished
)

// String returns the human-readable state name used in trace logs. Values
// outside the declared constants yield "Unknown".
func (s handshakeState) String() string {
	names := [...]string{
		handshakeErrored:   "Errored",
		handshakePreparing: "Preparing",
		handshakeSending:   "Sending",
		handshakeWaiting:   "Waiting",
		handshakeFinished:  "Finished",
	}
	if int(s) >= len(names) {
		return "Unknown"
	}

	return names[s]
}

// handshakeFSM holds the per-handshake state of the flight-based state machine
// pictured in the RFC 6347 diagram above: it prepares the current flight, sends
// it, waits for the peer's reply and retransmits on timeout.
type handshakeFSM struct {
	// Flight currently being prepared, sent or awaited; replaced when a
	// complete peer flight has been parsed in the waiting state.
	currentFlight      flightVal
	// Packets produced for currentFlight in the preparing state and written
	// out (again, on retransmission) in the sending state.
	flights            []*packet
	// Whether currentFlight is re-sent when the retransmit timer fires; taken
	// from the flight generator lookup.
	retransmit         bool
	// Current retransmission timeout. Doubled on each expiry unless
	// cfg.disableRetransmitBackoff is set, capped at 60s, and reset to
	// cfg.initialRetransmitInterval after a received flight is parsed or the
	// context is cancelled while waiting.
	retransmitInterval time.Duration
	// Connection state shared with flight generators/parsers; this file reads
	// isClient and advances handshakeSendSequence.
	state              *State
	// Handshake message cache passed to every flight generator and parser.
	cache              *handshakeCache
	cfg                *handshakeConfig
	// Closed when the state machine's run loop exits.
	closed             chan struct{}
}

// handshakeConfig carries the local settings, callbacks and hooks consulted by
// the handshake state machine and its flight generators/parsers.
type handshakeConfig struct {
	localPSKCallback             PSKCallback
	localPSKIdentityHint         []byte
	localCipherSuites            []CipherSuite             // Available CipherSuites
	localSignatureSchemes        []signaturehash.Algorithm // Available signature schemes
	extendedMasterSecret         ExtendedMasterSecretType  // Policy for the Extended Master Secret extension
	localSRTPProtectionProfiles  []SRTPProtectionProfile   // Available SRTPProtectionProfiles, if empty no SRTP support
	localSRTPMasterKeyIdentifier []byte
	serverName                   string
	supportedProtocols           []string
	clientAuth                   ClientAuthType // If we are a client should we request a client certificate
	localCertificates            []tls.Certificate
	nameToCertificate            map[string]*tls.Certificate
	insecureSkipVerify           bool
	verifyPeerCertificate        func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
	verifyConnection             func(*State) error
	sessionStore                 SessionStore
	rootCAs                      *x509.CertPool
	clientCAs                    *x509.CertPool
	// Starting retransmission timeout for a new state machine; the waiting
	// state also resets to it after each parsed peer flight.
	initialRetransmitInterval    time.Duration
	// When true the retransmission timeout is not doubled on expiry (the 60s
	// cap still applies).
	disableRetransmitBackoff     bool
	customCipherSuites           func() []CipherSuite
	ellipticCurves               []elliptic.Curve
	insecureSkipHelloVerify      bool
	connectionIDGenerator        func() []byte
	helloRandomBytesGenerator    func() [handshake.RandomBytesLength]byte

	// Optional observer invoked with the current flight and state on every
	// iteration of the state machine's run loop.
	onFlightState func(flightVal, handshakeState)
	// Logger for the state machine's trace/debug output.
	log           logging.LeveledLogger
	// Optional key log destination; key logging is skipped when nil.
	keyLogWriter  io.Writer

	localGetCertificate       func(*ClientHelloInfo) (*tls.Certificate, error)
	localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)

	// Added to the epoch of every record in a prepared flight.
	initialEpoch uint16

	// Serializes writes to keyLogWriter. NOTE(review): may also guard other
	// fields elsewhere in the package — confirm before relying on it.
	mu sync.Mutex

	// Optional hooks; presumably they let callers replace the outgoing
	// ClientHello/ServerHello/CertificateRequest message — not used in this
	// file, confirm against the flight generators.
	clientHelloMessageHook        func(handshake.MessageClientHello) handshake.Message
	serverHelloMessageHook        func(handshake.MessageServerHello) handshake.Message
	certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message

	// NOTE(review): presumably the state to resume from on session resumption;
	// not used in this file — confirm.
	resumeState *State
}

// flightConn is the connection surface the handshake state machine and its
// flight generators/parsers use to exchange flights with the peer.
type flightConn interface {
	// notify emits an alert with the given level and description (presumably
	// to the peer). Used when a flight cannot be generated or parsed, or when
	// a generator/parser asks for an alert.
	notify(ctx context.Context, level alert.Level, desc alert.Description) error
	// writePackets transmits a prepared flight; called in the sending state.
	writePackets(context.Context, []*packet) error
	// recvHandshake delivers incoming handshake events. The state machine
	// closes each event's done channel once handled; isRetransmit marks a
	// peer retransmission.
	recvHandshake() <-chan recvHandshakeState
	// setLocalEpoch is called when a prepared flight contains a record with a
	// higher epoch than cfg.initialEpoch (logged as changeCipherSpec).
	setLocalEpoch(epoch uint16)
	// NOTE(review): not used in this file; presumably processes packets
	// buffered before they could be handled — confirm.
	handleQueuedPackets(context.Context) error
	// NOTE(review): not used in this file — confirm semantics.
	sessionKey() []byte
}

func ( *handshakeConfig) ( string, ,  []byte) {
	if .keyLogWriter == nil {
		return
	}
	.mu.Lock()
	defer .mu.Unlock()
	,  := .keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", , , )))
	if  != nil {
		.log.Debugf("failed to write key log file: %s", )
	}
}

// srvCliStr names the local role for log prefixes: "client" when isClient is
// true, otherwise "server".
func srvCliStr(isClient bool) string {
	role := "server"
	if isClient {
		role = "client"
	}

	return role
}

func newHandshakeFSM(
	 *State,  *handshakeCache,  *handshakeConfig,
	 flightVal,
) *handshakeFSM {
	return &handshakeFSM{
		currentFlight:      ,
		state:              ,
		cache:              ,
		cfg:                ,
		retransmitInterval: .initialRetransmitInterval,
		closed:             make(chan struct{}),
	}
}

func ( *handshakeFSM) ( context.Context,  flightConn,  handshakeState) error {
	 := 
	defer func() {
		close(.closed)
	}()
	for {
		.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(.state.isClient), .currentFlight.String(), .String())
		if .cfg.onFlightState != nil {
			.cfg.onFlightState(.currentFlight, )
		}
		var  error
		switch  {
		case handshakePreparing:
			,  = .prepare(, )
		case handshakeSending:
			,  = .send(, )
		case handshakeWaiting:
			,  = .wait(, )
		case handshakeFinished:
			,  = .finish(, )
		default:
			return errInvalidFSMTransition
		}
		if  != nil {
			return 
		}
	}
}

func ( *handshakeFSM) () <-chan struct{} {
	return .closed
}

func ( *handshakeFSM) ( context.Context,  flightConn) (handshakeState, error) {
	.flights = nil
	// Prepare flights
	var (
		 *alert.Alert
		       error
		      []*packet
	)
	, ,  := .currentFlight.getFlightGenerator()
	if  != nil {
		 = 
		 = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
	} else {
		, ,  = (, .state, .cache, .cfg)
		.retransmit = 
	}
	if  != nil {
		if  := .notify(, .Level, .Description);  != nil {
			if  != nil {
				 = 
			}
		}
	}
	if  != nil {
		return handshakeErrored, 
	}

	.flights = 
	 := .cfg.initialEpoch
	 := 
	for ,  := range .flights {
		.record.Header.Epoch += 
		if .record.Header.Epoch >  {
			 = .record.Header.Epoch
		}
		if ,  := .record.Content.(*handshake.Handshake);  {
			.Header.MessageSequence = uint16(.state.handshakeSendSequence) //nolint:gosec // G115
			.state.handshakeSendSequence++
		}
	}
	if  !=  {
		.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(.state.isClient), )
		.setLocalEpoch()
	}

	return handshakeSending, nil
}

func ( *handshakeFSM) ( context.Context,  flightConn) (handshakeState, error) {
	// Send flights
	if  := .writePackets(, .flights);  != nil {
		return handshakeErrored, 
	}

	if .currentFlight.isLastSendFlight() {
		return handshakeFinished, nil
	}

	return handshakeWaiting, nil
}

func ( *handshakeFSM) ( context.Context,  flightConn) (handshakeState, error) { //nolint:gocognit,cyclop
	,  := .currentFlight.getFlightParser()
	if  != nil {
		if  := .notify(, alert.Fatal, alert.InternalError);  != nil {
			return handshakeErrored, 
		}

		return handshakeErrored, 
	}

	 := time.NewTimer(.retransmitInterval)
	for {
		select {
		case  := <-.recvHandshake():
			if .isRetransmit {
				close(.done)

				return handshakeSending, nil
			}

			, ,  := (, , .state, .cache, .cfg)
			.retransmitInterval = .cfg.initialRetransmitInterval
			close(.done)
			if  != nil {
				if  := .notify(, .Level, .Description);  != nil {
					if  != nil {
						 = 
					}
				}
			}
			if  != nil {
				return handshakeErrored, 
			}
			if  == 0 {
				break
			}
			.cfg.log.Tracef(
				"[handshake:%s] %s -> %s",
				srvCliStr(.state.isClient),
				.currentFlight.String(),
				.String(),
			)
			if .isLastRecvFlight() && .currentFlight ==  {
				return handshakeFinished, nil
			}
			.currentFlight = 

			return handshakePreparing, nil

		case <-.C:
			if !.retransmit {
				return handshakeWaiting, nil
			}

			// RFC 4347 4.2.4.1:
			// Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988])
			// and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds.
			if !.cfg.disableRetransmitBackoff {
				.retransmitInterval *= 2
			}
			if .retransmitInterval > time.Second*60 {
				.retransmitInterval = time.Second * 60
			}

			return handshakeSending, nil
		case <-.Done():
			.retransmitInterval = .cfg.initialRetransmitInterval

			return handshakeErrored, .Err()
		}
	}
}

func ( *handshakeFSM) ( context.Context,  flightConn) (handshakeState, error) {
	select {
	case  := <-.recvHandshake():
		close(.done)
		if .state.isClient {
			return handshakeFinished, nil
		} else {
			return handshakeSending, nil
		}
	case <-.Done():
		return handshakeErrored, .Err()
	}
}