package autonatv2

import (
	
	
	
	
	
	
	
	

	pool 
	
	
	
	
	
	

	

	ma 
	manet 
)

// Sentinel errors reported through EventDialRequestCompleted.Error.
var (
	// errBadRequest: the client's first message was not a DialRequest.
	errBadRequest = errors.New("bad request")
	// errDialDataRefused: the client did not deliver the requested dial data.
	errDialDataRefused = errors.New("dial data refused")
	// errResourceLimitExceeded: stream memory could not be reserved.
	errResourceLimitExceeded = errors.New("resource limit exceeded")
)

// dataRequestPolicyFunc decides, from the address the client was observed
// connecting from and the address it asked to be dialed on, whether the server
// must collect dial data from the client before dialing back.
type dataRequestPolicyFunc = func(observed, requested ma.Multiaddr) bool

// EventDialRequestCompleted describes how one dial request handled by the
// server ended; serveDialRequest returns one for every request.
//
// Fields:
//   - Error: why the exchange failed, or nil when it completed or when a
//     rejection/refusal response was delivered to the client.
//   - ResponseStatus: the DialResponse status that was sent, or that was
//     being sent when the write failed; the zero value when no DialResponse
//     was attempted.
//   - DialStatus: result of the dial-back; only set when a dial-back was
//     attempted.
//   - DialDataRequired: whether dialDataRequestPolicy required dial data for
//     the selected address.
//   - DialedAddr: the address selected for the dial-back; left nil on several
//     early-exit paths.
type EventDialRequestCompleted struct {
	Error            error
	ResponseStatus   pb.DialResponse_ResponseStatus
	DialStatus       pb.DialStatus
	DialDataRequired bool
	DialedAddr       ma.Multiaddr
}

// server implements the AutoNATv2 server.
// It can ask the client to provide dial data before attempting the requested dial.
// It rate limits requests on a global level, per peer level and on whether the request requires dial data.
type server struct {
	// host is the host serving DialProtocol; it is assigned when the stream
	// handler is attached, not by newServer. dialerHost is the separate host
	// used to check dialability and to dial clients back. limiter enforces the
	// request rate and concurrency limits.
	host       host.Host
	dialerHost host.Host
	limiter    *rateLimiter

	// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
	// dial data. NOTE(review): the default is presumably amplificationAttackPrevention, supplied
	// via autoNATSettings — confirm. amplificatonAttackPreventionDialWait bounds the random delay
	// between receiving dial data and dialing back. metricsTracer, when non-nil, is notified of
	// every completed request.
	dialDataRequestPolicy                dataRequestPolicyFunc
	amplificatonAttackPreventionDialWait time.Duration
	metricsTracer                        MetricsTracer

	// now is the clock used for stream deadlines; allowPrivateAddrs permits dial-backs to
	// non-public addresses. Both come from autoNATSettings and are mainly overridden in tests.
	now               func() time.Time
	allowPrivateAddrs bool
}

func newServer( host.Host,  *autoNATSettings) *server {
	return &server{
		dialerHost:                           ,
		dialDataRequestPolicy:                .dataRequestPolicy,
		amplificatonAttackPreventionDialWait: .amplificatonAttackPreventionDialWait,
		allowPrivateAddrs:                    .allowPrivateAddrs,
		limiter: &rateLimiter{
			RPM:                          .serverRPM,
			PerPeerRPM:                   .serverPerPeerRPM,
			DialDataRPM:                  .serverDialDataRPM,
			MaxConcurrentRequestsPerPeer: .maxConcurrentRequestsPerPeer,
			now:                          .now,
		},
		now:           .now,
		metricsTracer: .metricsTracer,
	}
}

// Enable attaches the stream handler to the host.
//
// It stores the given host as the server's host and registers
// handleDialRequest on it as the stream handler for DialProtocol.
func ( *server) ( host.Host) {
	.host = 
	.host.SetStreamHandler(DialProtocol, .handleDialRequest)
}

// Shuts the server down: removes the DialProtocol stream handler from the
// server's host, closes the dialer host (its error is discarded), and closes
// the rate limiter, after which the limiter rejects new requests.
//
// NOTE(review): host is assigned when the stream handler is attached, not in
// newServer; calling this on a server whose handler was never attached would
// call a method on a nil host.Host — confirm callers always attach first.
func ( *server) () {
	.host.RemoveStreamHandler(DialProtocol)
	.dialerHost.Close()
	.limiter.Close()
}

// handleDialRequest is the dial-request protocol stream handler
func ( *server) ( network.Stream) {
	defer func() {
		if  := recover();  != nil {
			fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", , debug.Stack())
			.Reset()
		}
	}()

	log.Debugf("received dial-request from: %s, addr: %s", .Conn().RemotePeer(), .Conn().RemoteMultiaddr())
	 := .serveDialRequest()
	log.Debugf("completed dial-request from %s, response status: %s, dial status: %s, err: %s",
		.Conn().RemotePeer(), .ResponseStatus, .DialStatus, .Error)
	if .metricsTracer != nil {
		.metricsTracer.CompletedRequest()
	}
}

func ( *server) ( network.Stream) EventDialRequestCompleted {
	if  := .Scope().SetService(ServiceName);  != nil {
		.Reset()
		log.Debugf("failed to attach stream to %s service: %w", ServiceName, )
		return EventDialRequestCompleted{
			Error: errors.New("failed to attach stream to autonat-v2"),
		}
	}

	if  := .Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways);  != nil {
		.Reset()
		log.Debugf("failed to reserve memory for stream %s: %w", DialProtocol, )
		return EventDialRequestCompleted{Error: errResourceLimitExceeded}
	}
	defer .Scope().ReleaseMemory(maxMsgSize)

	 := .now().Add(streamTimeout)
	,  := context.WithDeadline(context.Background(), )
	defer ()
	.SetDeadline(.now().Add(streamTimeout))
	defer .Close()

	 := .Conn().RemotePeer()

	var  pb.Message
	 := pbio.NewDelimitedWriter()
	// Check for rate limit before parsing the request
	if !.limiter.Accept() {
		 = pb.Message{
			Msg: &pb.Message_DialResponse{
				DialResponse: &pb.DialResponse{
					Status: pb.DialResponse_E_REQUEST_REJECTED,
				},
			},
		}
		if  := .WriteMsg(&);  != nil {
			.Reset()
			log.Debugf("failed to write request rejected response to %s: %s", , )
			return EventDialRequestCompleted{
				ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
				Error:          fmt.Errorf("write failed: %w", ),
			}
		}
		log.Debugf("rejected request from %s: rate limit exceeded", )
		return EventDialRequestCompleted{ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED}
	}
	defer .limiter.CompleteRequest()

	 := pbio.NewDelimitedReader(, maxMsgSize)
	if  := .ReadMsg(&);  != nil {
		.Reset()
		log.Debugf("failed to read request from %s: %s", , )
		return EventDialRequestCompleted{Error: fmt.Errorf("read failed: %w", )}
	}
	if .GetDialRequest() == nil {
		.Reset()
		log.Debugf("invalid message type from %s: %T expected: DialRequest", , .Msg)
		return EventDialRequestCompleted{Error: errBadRequest}
	}

	// parse peer's addresses
	var  ma.Multiaddr
	var  int
	for ,  := range .GetDialRequest().GetAddrs() {
		if  >= maxPeerAddresses {
			break
		}
		,  := ma.NewMultiaddrBytes()
		if  != nil {
			continue
		}
		if !.allowPrivateAddrs && !manet.IsPublicAddr() {
			continue
		}
		if !.dialerHost.Network().CanDial(, ) {
			continue
		}
		 = 
		 = 
		break
	}
	// No dialable address
	if  == nil {
		 = pb.Message{
			Msg: &pb.Message_DialResponse{
				DialResponse: &pb.DialResponse{
					Status: pb.DialResponse_E_DIAL_REFUSED,
				},
			},
		}
		if  := .WriteMsg(&);  != nil {
			.Reset()
			log.Debugf("failed to write dial refused response to %s: %s", , )
			return EventDialRequestCompleted{
				ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
				Error:          fmt.Errorf("write failed: %w", ),
			}
		}
		return EventDialRequestCompleted{
			ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
		}
	}

	 := .GetDialRequest().Nonce

	 := .dialDataRequestPolicy(.Conn().RemoteMultiaddr(), )
	if  && !.limiter.AcceptDialDataRequest() {
		 = pb.Message{
			Msg: &pb.Message_DialResponse{
				DialResponse: &pb.DialResponse{
					Status: pb.DialResponse_E_REQUEST_REJECTED,
				},
			},
		}
		if  := .WriteMsg(&);  != nil {
			.Reset()
			log.Debugf("failed to write request rejected response to %s: %s", , )
			return EventDialRequestCompleted{
				ResponseStatus:   pb.DialResponse_E_REQUEST_REJECTED,
				Error:            fmt.Errorf("write failed: %w", ),
				DialDataRequired: true,
			}
		}
		log.Debugf("rejected request from %s: rate limit exceeded", )
		return EventDialRequestCompleted{
			ResponseStatus:   pb.DialResponse_E_REQUEST_REJECTED,
			DialDataRequired: true,
		}
	}

	if  {
		if  := getDialData(, , &, );  != nil {
			.Reset()
			log.Debugf("%s refused dial data request: %s", , )
			return EventDialRequestCompleted{
				Error:            errDialDataRefused,
				DialDataRequired: true,
				DialedAddr:       ,
			}
		}
		// wait for a bit to prevent thundering herd style attacks on a victim
		 := time.Duration(rand.Intn(int(.amplificatonAttackPreventionDialWait) + 1)) // the range is [0, n)
		 := time.NewTimer()
		defer .Stop()
		select {
		case <-.Done():
			.Reset()
			log.Debugf("rejecting request without dialing: %s %p ", , .Err())
			return EventDialRequestCompleted{Error: .Err(), DialDataRequired: true, DialedAddr: }
		case <-.C:
		}
	}

	 := .dialBack(, .Conn().RemotePeer(), , )
	 = pb.Message{
		Msg: &pb.Message_DialResponse{
			DialResponse: &pb.DialResponse{
				Status:     pb.DialResponse_OK,
				DialStatus: ,
				AddrIdx:    uint32(),
			},
		},
	}
	if  := .WriteMsg(&);  != nil {
		.Reset()
		log.Debugf("failed to write response to %s: %s", , )
		return EventDialRequestCompleted{
			ResponseStatus:   pb.DialResponse_OK,
			DialStatus:       ,
			Error:            fmt.Errorf("write failed: %w", ),
			DialDataRequired: ,
			DialedAddr:       ,
		}
	}
	return EventDialRequestCompleted{
		ResponseStatus:   pb.DialResponse_OK,
		DialStatus:       ,
		Error:            nil,
		DialDataRequired: ,
		DialedAddr:       ,
	}
}

// getDialData gets data from the client for dialing the address
func getDialData( pbio.Writer,  network.Stream,  *pb.Message,  int) error {
	 := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes)
	* = pb.Message{
		Msg: &pb.Message_DialDataRequest{
			DialDataRequest: &pb.DialDataRequest{
				AddrIdx:  uint32(),
				NumBytes: uint64(),
			},
		},
	}
	if  := .WriteMsg();  != nil {
		return fmt.Errorf("dial data write: %w", )
	}
	// pbio.Reader that we used so far on this stream is buffered. But at this point
	// there is nothing unread on the stream. So it is safe to use the raw stream to
	// read, reducing allocations.
	return readDialData(, )
}

// readDialData reads dial data sent by the client on the raw stream (the
// io.Reader parameter) until at least the requested number of bytes (the int
// parameter) has been received.
//
// Messages are read with a msgReader over a pooled maxMsgSize buffer that is
// returned to the pool on exit. Instead of unmarshalling each message, the
// payload size is computed from the message length by subtracting the
// protobuf framing: two field tags and two length varints, where each varint
// is counted as one byte, or two when the remaining length exceeds 127.
// NOTE(review): that framing arithmetic is only exact while lengths fit in a
// two-byte varint (< 16384); presumably maxMsgSize guarantees this — confirm.
//
// It returns a wrapped read error, an error when a message carries fewer than
// 100 payload bytes while more data is still owed (so a client cannot force
// the server to process many tiny messages), or nil once enough has arrived.
func readDialData( int,  io.Reader) error {
	 := &msgReader{R: , Buf: pool.Get(maxMsgSize)}
	defer pool.Put(.Buf)
	// Loop until the outstanding byte count reaches zero or below.
	for  := ;  > 0; {
		,  := .ReadMsg()
		if  != nil {
			return fmt.Errorf("dial data read: %w", )
		}
		// protobuf format is:
		// (oneof dialDataResponse:<fieldTag><len varint>)(dial data:<fieldTag><len varint><bytes>)
		 := len()
		 -= 2 // fieldTag + varint first byte
		if  > 127 {
			 -= 1 // varint second byte
		}
		 -= 2 // second fieldTag + varint first byte
		if  > 127 {
			 -= 1 // varint second byte
		}
		// Only a positive payload reduces the outstanding count; a message with
		// no payload leaves it unchanged and is rejected by the check below.
		if  > 0 {
			 -= 
		}
		// Check if the peer is not sending too little data forcing us to just do a lot of compute
		if  < 100 &&  > 0 {
			return fmt.Errorf("dial data msg too small: %d", )
		}
	}
	return nil
}

func ( *server) ( context.Context,  peer.ID,  ma.Multiaddr,  uint64) pb.DialStatus {
	,  := context.WithTimeout(, dialBackDialTimeout)
	 = network.WithForceDirectDial(, "autonatv2")
	.dialerHost.Peerstore().AddAddr(, , peerstore.TempAddrTTL)
	defer func() {
		()
		.dialerHost.Network().ClosePeer()
		.dialerHost.Peerstore().ClearAddrs()
		.dialerHost.Peerstore().RemovePeer()
	}()

	 := .dialerHost.Connect(, peer.AddrInfo{ID: })
	if  != nil {
		return pb.DialStatus_E_DIAL_ERROR
	}

	,  := .dialerHost.NewStream(, , DialBackProtocol)
	if  != nil {
		return pb.DialStatus_E_DIAL_BACK_ERROR
	}

	defer .Close()
	.SetDeadline(.now().Add(dialBackStreamTimeout))

	 := pbio.NewDelimitedWriter()
	if  := .WriteMsg(&pb.DialBack{Nonce: });  != nil {
		.Reset()
		return pb.DialStatus_E_DIAL_BACK_ERROR
	}

	// Since the underlying connection is on a separate dialer, it'll be closed after this
	// function returns. Connection close will drop all the queued writes. To ensure message
	// delivery, do a CloseWrite and read a byte from the stream. The peer actually sends a
	// response of type DialBackResponse but we only care about the fact that the DialBack
	// message has reached the peer. So we ignore that message on the read side.
	.CloseWrite()
	.SetDeadline(.now().Add(5 * time.Second)) // 5 is a magic number
	 := make([]byte, 1)                         // Read 1 byte here because 0 len reads are free to return (0, nil) immediately
	.Read()

	return pb.DialStatus_OK
}

// rateLimiter implements sliding-window rate limits over the last minute. It
// limits requests globally (RPM), per peer (PerPeerRPM) and, separately,
// requests that require dial data (DialDataRPM). It also caps the number of
// concurrent in-progress requests per peer at MaxConcurrentRequestsPerPeer.
//
// The exported methods take mu; after Close every Accept* call returns false.
type rateLimiter struct {
	// PerPeerRPM is the rate limit per peer
	PerPeerRPM int
	// RPM is the global rate limit
	RPM int
	// DialDataRPM is the rate limit for requests that require dial data
	DialDataRPM int
	// MaxConcurrentRequestsPerPeer is the maximum number of concurrent requests per peer
	MaxConcurrentRequestsPerPeer int

	// mu guards the state below. reqs is the global window of accepted
	// requests in accept order; peerReqs and dialDataReqs hold the accept
	// times for the per-peer and dial-data windows. cleanup prunes all three.
	mu           sync.Mutex
	closed       bool
	reqs         []entry
	peerReqs     map[peer.ID][]time.Time
	dialDataReqs []time.Time
	// inProgressReqs tracks in progress requests. This is used to limit multiple
	// concurrent requests by the same peer.
	inProgressReqs map[peer.ID]int

	now func() time.Time // for tests
}

// entry records one accepted request in the global window: which peer made it
// and when it was accepted.
type entry struct {
	PeerID peer.ID
	Time   time.Time
}

func ( *rateLimiter) () {
	if .peerReqs == nil {
		.peerReqs = make(map[peer.ID][]time.Time)
		.inProgressReqs = make(map[peer.ID]int)
	}
}

func ( *rateLimiter) ( peer.ID) bool {
	.mu.Lock()
	defer .mu.Unlock()
	if .closed {
		return false
	}
	.init()
	 := .now()
	.cleanup()

	if .inProgressReqs[] >= .MaxConcurrentRequestsPerPeer {
		return false
	}
	if len(.reqs) >= .RPM || len(.peerReqs[]) >= .PerPeerRPM {
		return false
	}

	.inProgressReqs[]++
	.reqs = append(.reqs, entry{PeerID: , Time: })
	.peerReqs[] = append(.peerReqs[], )
	return true
}

func ( *rateLimiter) () bool {
	.mu.Lock()
	defer .mu.Unlock()
	if .closed {
		return false
	}
	.init()
	 := .now()
	.cleanup()
	if len(.dialDataReqs) >= .DialDataRPM {
		return false
	}
	.dialDataReqs = append(.dialDataReqs, )
	return true
}

// cleanup removes stale requests.
//
// This is fast enough in rate limited cases and the state is small enough to
// clean up quickly when blocking requests.
func ( *rateLimiter) ( time.Time) {
	 := len(.reqs)
	for ,  := range .reqs {
		if .Sub(.Time) >= time.Minute {
			 := len(.peerReqs[.PeerID])
			for ,  := range .peerReqs[.PeerID] {
				if .Sub() < time.Minute {
					 = 
					break
				}
			}
			.peerReqs[.PeerID] = .peerReqs[.PeerID][:]
			if len(.peerReqs[.PeerID]) == 0 {
				delete(.peerReqs, .PeerID)
			}
		} else {
			 = 
			break
		}
	}
	.reqs = .reqs[:]

	 = len(.dialDataReqs)
	for ,  := range .dialDataReqs {
		if .Sub() < time.Minute {
			 = 
			break
		}
	}
	.dialDataReqs = .dialDataReqs[:]
}

func ( *rateLimiter) ( peer.ID) {
	.mu.Lock()
	defer .mu.Unlock()
	.inProgressReqs[]--
	if .inProgressReqs[] <= 0 {
		delete(.inProgressReqs, )
		if .inProgressReqs[] < 0 {
			log.Errorf("BUG: negative in progress requests for peer %s", )
		}
	}
}

func ( *rateLimiter) () {
	.mu.Lock()
	defer .mu.Unlock()
	.closed = true
	.peerReqs = nil
	.inProgressReqs = nil
	.dialDataReqs = nil
}

// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed
// IP address is different from the dial back IP address
func amplificationAttackPrevention(,  ma.Multiaddr) bool {
	,  := manet.ToIP()
	if  != nil {
		return true
	}
	,  := manet.ToIP() // can be dns addr
	if  != nil {
		return true
	}
	return !.Equal()
}