package madns

import (
	"context"
	"net"
	"strings"

	"github.com/miekg/dns"
	ma "github.com/multiformats/go-multiaddr"
)

// Protocol descriptors for the dnsaddr, dns4, dns6 and dns multiaddr
// protocols, each obtained once from ma.ProtocolWithCode so that their codes
// can be compared without repeating the lookup.
var (
	dnsaddrProtocol = ma.ProtocolWithCode(ma.P_DNSADDR)
	dns4Protocol    = ma.ProtocolWithCode(ma.P_DNS4)
	dns6Protocol    = ma.ProtocolWithCode(ma.P_DNS6)
	dnsProtocol     = ma.ProtocolWithCode(ma.P_DNS)
)

var (
	// ResolvableProtocols lists the four DNS-style multiaddr protocols known
	// to this package: dnsaddr, dns4, dns6 and dns.
	ResolvableProtocols = []ma.Protocol{dnsaddrProtocol, dns4Protocol, dns6Protocol, dnsProtocol}
	// DefaultResolver is a package-level Resolver whose default BasicResolver
	// is net.DefaultResolver and which has no custom per-domain resolvers.
	DefaultResolver     = &Resolver{def: net.DefaultResolver}
)

// maxResolvedAddrs is the upper bound on the number of addresses kept from a
// single resolution; longer result lists are truncated to this length.
const maxResolvedAddrs = 100

// dnsaddrTXTPrefix is the prefix that marks a TXT record as a dnsaddr entry.
// The text following the prefix is parsed as a multiaddr.
const dnsaddrTXTPrefix = "dnsaddr="

// BasicResolver is a low level interface for DNS resolution.
// Both methods have the same signatures as the corresponding methods of
// *net.Resolver, which is why net.DefaultResolver can be used as one.
type BasicResolver interface {
	// LookupIPAddr returns the IP addresses found for the given name.
	LookupIPAddr(context.Context, string) ([]net.IPAddr, error)
	// LookupTXT returns the TXT records found for the given name.
	LookupTXT(context.Context, string) ([]string, error)
}

// Resolver is an object capable of resolving dns multiaddrs by using one or more BasicResolvers;
// it supports custom per domain/TLD resolvers.
// It also implements the BasicResolver interface so that it can act as a custom per domain/TLD
// resolver.
type Resolver struct {
	// def is the fallback resolver, used when no entry in custom matches.
	def    BasicResolver
	// custom maps a fully qualified domain name (as produced by dns.Fqdn) to
	// the resolver responsible for it. It is nil until a domain resolver is
	// registered.
	custom map[string]BasicResolver
}

// Compile-time check that *Resolver satisfies BasicResolver.
var _ BasicResolver = (*Resolver)(nil)

// NewResolver creates a new Resolver instance with the specified options
func ( ...Option) (*Resolver, error) {
	 := &Resolver{def: net.DefaultResolver}
	for ,  := range  {
		 := ()
		if  != nil {
			return nil, 
		}
	}

	return , nil
}

// Option is a function that configures a Resolver. A non-nil error signals
// that the option could not be applied.
type Option func(*Resolver) error

// WithDefaultResolver is an option that specifies the default basic resolver,
// which resolves any TLD that doesn't have a custom resolver.
// Defaults to net.DefaultResolver
func ( BasicResolver) Option {
	return func( *Resolver) error {
		.def = 
		return nil
	}
}

// WithDomainResolver specifies a custom resolver for a domain/TLD.
// Custom resolver selection matches domains left to right, with more specific resolvers
// superseding generic ones.
func ( string,  BasicResolver) Option {
	return func( *Resolver) error {
		if .custom == nil {
			.custom = make(map[string]BasicResolver)
		}
		 := dns.Fqdn()
		.custom[] = 
		return nil
	}
}

func ( *Resolver) ( string) BasicResolver {
	 := dns.Fqdn()

	// we match left-to-right, with more specific resolvers superseding generic ones.
	// So for a domain a.b.c, we will try a.b,c, b.c, c, and fallback to the default if
	// there is no match
	,  := .custom[]
	if  {
		return 
	}

	for  := strings.Index(, ".");  != -1;  = strings.Index(, ".") {
		 = [+1:]
		if  == "" {
			// the . is the default resolver
			break
		}

		,  = .custom[]
		if  {
			return 
		}
	}

	return .def
}

// Resolve resolves a DNS multiaddr. It will only resolve the first DNS component in the multiaddr.
// If you need to resolve multiple DNS components, you may call this function again with each returned address.
func ( *Resolver) ( context.Context,  ma.Multiaddr) ([]ma.Multiaddr, error) {
	if  == nil {
		return nil, nil
	}

	// Find the next dns component.
	,  := ma.SplitFunc(, func( ma.Component) bool {
		switch .Protocol().Code {
		case dnsProtocol.Code, dns4Protocol.Code, dns6Protocol.Code, dnsaddrProtocol.Code:
			return true
		default:
			return false
		}
	})

	// If the rest is empty, we've hit the end (there _was_ no dns component).
	if  == nil {
		return []ma.Multiaddr{}, nil
	}

	// split off the dns component.
	,  := ma.SplitFirst()

	 := .Protocol()
	 := .Value()
	 := .getResolver()

	// resolve the dns component
	var  []ma.Multiaddr
	switch .Code {
	case dns4Protocol.Code, dns6Protocol.Code, dnsProtocol.Code:
		// The dns, dns4, and dns6 resolver simply resolves each
		// dns* component into an ipv4/ipv6 address.

		 := .Code == dns4Protocol.Code
		 := .Code == dns6Protocol.Code

		// XXX: Unfortunately, go does a pretty terrible job of
		// differentiating between IPv6 and IPv4. A v4-in-v6
		// AAAA record will _look_ like an A record to us and
		// there's nothing we can do about that.
		,  := .LookupIPAddr(, )
		if  != nil {
			return nil, 
		}

		// Convert each DNS record into a multiaddr. If the
		// protocol is dns4, throw away any IPv6 addresses. If
		// the protocol is dns6, throw away any IPv4 addresses.

		for ,  := range  {
			var (
				 ma.Multiaddr
				    error
			)
			 := .IP.To4()
			if  == nil {
				if  {
					continue
				}
				,  = ma.NewMultiaddr("/ip6/" + .IP.String())
			} else {
				if  {
					continue
				}
				,  = ma.NewMultiaddr("/ip4/" + .String())
			}
			if  != nil {
				return nil, 
			}
			 = append(, )
		}
	case dnsaddrProtocol.Code:
		// The dnsaddr resolver is a bit more complicated. We:
		//
		// 1. Lookup the dnsaddr txt record on _dnsaddr.DOMAIN.TLD
		// 2. Take everything _after_ the `/dnsaddr/DOMAIN.TLD`
		//    part of the multiaddr.
		// 3. Find the dnsaddr records (if any) with suffixes
		//    matching the result of step 2.

		// First, lookup the TXT record
		,  := .LookupTXT(, "_dnsaddr."+)
		if  != nil {
			return nil, 
		}

		// Then, calculate the length of the suffix we're
		// looking for.
		 := 0
		if  != nil {
			 = addrLen()
		}

		for ,  := range  {
			// Ignore non dnsaddr TXT records.
			if !strings.HasPrefix(, dnsaddrTXTPrefix) {
				continue
			}

			// Extract and decode the multiaddr.
			,  := ma.NewMultiaddr([len(dnsaddrTXTPrefix):])
			if  != nil {
				// discard multiaddrs we don't understand.
				// XXX: Is this right? It's the best we
				// can do for now, really.
				continue
			}

			// If we have a suffix to match on.
			if  != nil {
				// Make sure the new address is at least
				// as long as the suffix we're looking
				// for.
				 := addrLen()
				if  <  {
					// not long enough.
					continue
				}

				// Matches everything after the /dnsaddr/... with the end of the
				// dnsaddr record:
				//
				// v----------rmlen-----------------v
				// /ip4/1.2.3.4/tcp/1234/p2p/QmFoobar
				//                      /p2p/QmFoobar
				// ^--(rmlen - length)--^---length--^
				if !.Equal(offset(, -)) {
					continue
				}
			}

			// remove the suffix from the multiaddr, we'll add it back at the end.
			if  != nil {
				 = .Decapsulate()
			}
			if  == nil {
				continue
			}
			 = append(, )
		}
	default:
		panic("unreachable")
	}

	if len() == 0 {
		return nil, nil
	}

	if len() > maxResolvedAddrs {
		 = [:maxResolvedAddrs]
	}

	if  != nil {
		for ,  := range  {
			[] = .Encapsulate()
		}
	}
	if  != nil {
		for ,  := range  {
			[] = .Encapsulate()
		}
	}

	return , nil
}

func ( *Resolver) ( context.Context,  string) ([]net.IPAddr, error) {
	return .getResolver().LookupIPAddr(, )
}

func ( *Resolver) ( context.Context,  string) ([]string, error) {
	return .getResolver().LookupTXT(, )
}