package dns
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"io"
"net"
"strings"
"sync"
"time"
)
// maxTCPQueries is the per-connection query cap applied when
// Server.MaxTCPQueries is left at zero.
const (
	maxTCPQueries = 128
)

// aLongTimeAgo is a fixed instant far in the past. ShutdownContext installs it
// as a read deadline so that reads blocked on a connection return immediately.
var (
	aLongTimeAgo = time.Unix(1, 0)
)
// Handler is implemented by anything that answers DNS requests. The server
// calls ServeDNS with the unpacked request r and a ResponseWriter w through
// which the reply is sent.
type Handler interface {
	ServeDNS(w ResponseWriter, r *Msg)
}
// HandlerFunc lets an ordinary function act as a Handler: the function value
// itself is the ServeDNS implementation.
type HandlerFunc func(ResponseWriter, *Msg)

// ServeDNS calls f with w and r.
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
	call := (func(ResponseWriter, *Msg))(f)
	call(w, r)
}
// ResponseWriter is what a Handler uses to inspect the connection a request
// arrived on and to send the reply. The method notes below describe the
// server's own implementation.
type ResponseWriter interface {
	// LocalAddr returns the local address of the connection.
	LocalAddr() net.Addr
	// RemoteAddr returns the address of the peer that sent the request.
	RemoteAddr() net.Addr
	// WriteMsg packs the message (TSIG-signing it when the request was
	// TSIG-signed and a provider is configured) and writes it.
	WriteMsg(*Msg) error
	// Write sends already-packed wire bytes; over TCP the server's
	// implementation adds the two-byte length prefix itself.
	Write([]byte) (int, error)
	// Close closes the connection; over UDP the server's implementation
	// only marks the writer as closed.
	Close() error
	// TsigStatus reports the outcome of verifying the request's TSIG.
	TsigStatus() error
	// TsigTimersOnly sets the timers-only flag used when signing replies.
	TsigTimersOnly(bool)
	// Hijack tells the server to stop managing the connection: it will
	// neither read further queries from it nor close it.
	Hijack()
}
// ConnectionStater is implemented by ResponseWriters that can expose the TLS
// state of their connection. A nil result means the connection is not TLS.
type ConnectionStater interface {
	ConnectionState() *tls.ConnectionState
}
// response is the server's ResponseWriter. The server sets exactly one of udp
// or tcp, depending on the transport the request arrived on.
type response struct {
	// Set by Close; later Write/WriteMsg calls are refused.
	closed bool
	// Set by Hijack; for TCP the server stops reading and will not close the conn.
	hijacked bool
	// Passed to TSIG generation when a signed reply is written.
	tsigTimersOnly bool
	// Result of verifying the current request's TSIG; reset for every request.
	tsigStatus error
	// MAC fed into TSIG generation: first the request's MAC, then the value
	// returned by each TsigGenerateWithProvider call.
	tsigRequestMAC string
	// Provider used for TSIG verification and signing; nil disables both.
	tsigProvider TsigProvider
	// Packet conn the request was read from (UDP only).
	udp net.PacketConn
	// Stream conn the request was read from (TCP only).
	tcp net.Conn
	// Reply session when the request was read from a *net.UDPConn.
	udpSession *SessionUDP
	// Peer address when the request was read from another net.PacketConn.
	pcSession net.Addr
	// Destination for packed replies: the response itself or a DecorateWriter wrapper.
	writer Writer
}
func handleRefused(w ResponseWriter , r *Msg ) {
m := new (Msg )
m .SetRcode (r , RcodeRefused )
w .WriteMsg (m )
}
// HandleFailed is a handler function that answers r with an
// RcodeServerFailure reply. A write error is deliberately dropped: the
// handler signature has no way to report it.
func HandleFailed(w ResponseWriter, r *Msg) {
	var reply Msg
	reply.SetRcode(r, RcodeServerFailure)
	_ = w.WriteMsg(&reply)
}
// ListenAndServe creates a Server for addr, network and handler and runs its
// ListenAndServe method, returning whatever that method returns.
func ListenAndServe(addr string, network string, handler Handler) error {
	return (&Server{Addr: addr, Net: network, Handler: handler}).ListenAndServe()
}
// ListenAndServeTLS loads the certificate/key pair from certFile and keyFile
// and serves DNS over TLS ("tcp-tls") on addr with handler. A failure to load
// the key pair is returned before anything is started.
func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return err
	}

	srv := &Server{
		Addr:    addr,
		Net:     "tcp-tls",
		Handler: handler,
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
		},
	}
	return srv.ListenAndServe()
}
// ActivateAndServe serves handler on an already-open listener l or packet
// conn p by delegating to Server.ActivateAndServe.
func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
	srv := new(Server)
	srv.Listener = l
	srv.PacketConn = p
	srv.Handler = handler
	return srv.ActivateAndServe()
}
// Writer is the sink that a response writes packed replies to. It is a plain
// io.Writer; Server.DecorateWriter can wrap it.
type Writer interface {
	io.Writer
}

// Reader reads raw requests for the server. The server uses its own default
// implementation unless Server.DecorateReader wraps it.
type Reader interface {
	// ReadTCP reads one request from conn, applying timeout to the read.
	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
	// ReadUDP reads one datagram from conn, applying timeout to the read, and
	// returns the session the reply is written back through.
	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}

// PacketConnReader extends Reader with ReadPacketConn. serveUDP requires it
// when the packet conn is not a *net.UDPConn.
type PacketConnReader interface {
	Reader
	// ReadPacketConn reads one packet from conn, applying timeout to the
	// read, and returns the sender's address.
	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}
// defaultReader is the Reader the server uses when no DecorateReader is set.
// Every method forwards to the corresponding read helper of the embedded
// Server.
type defaultReader struct {
	*Server
}

// Compile-time check that the value type provides the full PacketConnReader.
var _ PacketConnReader = defaultReader{}

// ReadTCP forwards to Server.readTCP.
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
	srv := dr.Server
	return srv.readTCP(conn, timeout)
}

// ReadUDP forwards to Server.readUDP.
func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
	srv := dr.Server
	return srv.readUDP(conn, timeout)
}

// ReadPacketConn forwards to Server.readPacketConn.
func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
	srv := dr.Server
	return srv.readPacketConn(conn, timeout)
}
// DecorateReader wraps the Reader the server reads requests with; see
// Server.DecorateReader.
type DecorateReader func(Reader) Reader

// DecorateWriter wraps the Writer a response writes packed replies to; see
// Server.DecorateWriter.
type DecorateWriter func(Writer) Writer

// MsgInvalidFunc is called with the raw bytes of a request that could not be
// processed, together with the error that describes why.
type MsgInvalidFunc func(m []byte, err error)

// DefaultMsgInvalidFunc is the MsgInvalidFunc the server installs when none is
// set. It ignores the invalid message.
func DefaultMsgInvalidFunc(m []byte, err error) {}
// Server holds the configuration and runtime state of a DNS server. Several
// zero-valued exported fields are replaced with defaults when serving starts
// (see the individual fields).
type Server struct {
	// Address to listen on; ListenAndServe uses ":domain" when empty.
	Addr string
	// Network for ListenAndServe: "tcp", "tcp4", "tcp6", "tcp-tls",
	// "tcp4-tls", "tcp6-tls", "udp", "udp4" or "udp6".
	Net string
	// TCP listener; set by ListenAndServe, or supplied for ActivateAndServe.
	Listener net.Listener
	// TLS configuration for the *-tls networks; it must set Certificates or
	// GetCertificate.
	TLSConfig *tls.Config
	// Packet conn; set by ListenAndServe for UDP networks, or supplied for
	// ActivateAndServe, where it takes precedence over Listener.
	PacketConn net.PacketConn
	// Handler invoked for accepted requests; DefaultServeMux when nil.
	Handler Handler
	// Size of the pooled UDP read buffers; MinMsgSize when zero.
	UDPSize int
	// Timeout for each UDP read and for the first query on a TCP connection;
	// dnsTimeout when zero.
	ReadTimeout time.Duration
	// NOTE(review): not referenced in this part of the file; confirm where it
	// is applied.
	WriteTimeout time.Duration
	// Returns the read timeout for follow-up queries on a TCP connection;
	// tcpIdleTimeout is used when nil.
	IdleTimeout func() time.Duration
	// TSIG provider for verifying requests and signing replies; takes
	// precedence over TsigSecret.
	TsigProvider TsigProvider
	// TSIG secrets, used via tsigSecretProvider when TsigProvider is nil.
	// NOTE(review): presumably keyed by TSIG key name — confirm.
	TsigSecret map[string]string
	// Called once serving has started, before the first Accept or read.
	NotifyStartedFunc func()
	// Optional wrapper around the Reader used to read requests.
	DecorateReader DecorateReader
	// Optional wrapper around the Writer used to send replies.
	DecorateWriter DecorateWriter
	// Maximum queries per TCP connection; maxTCPQueries when zero, unlimited
	// when -1.
	MaxTCPQueries int
	// Passed through to listenTCP/listenUDP by ListenAndServe.
	ReusePort bool
	// Passed through to listenTCP/listenUDP by ListenAndServe.
	ReuseAddr bool
	// Decides from the request header whether to accept, reject or ignore a
	// request; DefaultMsgAcceptFunc when nil.
	MsgAcceptFunc MsgAcceptFunc
	// Told about requests that cannot be processed; DefaultMsgInvalidFunc
	// when nil.
	MsgInvalidFunc MsgInvalidFunc

	// Guards started and conns.
	lock sync.RWMutex
	// Whether the server is currently serving.
	started bool
	// Closed by the serve loop once it and all request goroutines are done.
	shutdown chan struct{}
	// Live TCP connections, tracked so shutdown can unblock their reads.
	conns map[net.Conn]struct{}
	// Pool of UDP read buffers of UDPSize bytes.
	udpPool sync.Pool
}
// tsigProvider returns the TSIG provider for new responses. An explicit
// TsigProvider wins; otherwise TsigSecret is adapted into a provider. It
// returns nil when neither is configured, which disables TSIG handling.
func (srv *Server) tsigProvider() TsigProvider {
	switch {
	case srv.TsigProvider != nil:
		return srv.TsigProvider
	case srv.TsigSecret != nil:
		return tsigSecretProvider(srv.TsigSecret)
	default:
		return nil
	}
}
// isStarted reports, under the read lock, whether the server is serving.
func (srv *Server) isStarted() bool {
	srv.lock.RLock()
	defer srv.lock.RUnlock()
	return srv.started
}
// makeUDPBuffer returns a constructor suitable for sync.Pool.New. Each call
// of the constructor allocates a fresh, zeroed []byte of exactly size bytes.
func makeUDPBuffer(size int) func() interface{} {
	alloc := func() interface{} {
		buf := make([]byte, size)
		return buf
	}
	return alloc
}
// init resets the per-run state (shutdown channel, tracked connections) and
// fills in defaults for Handler, MsgAcceptFunc, MsgInvalidFunc and UDPSize.
// It then points the UDP buffer pool at buffers of UDPSize bytes.
func (srv *Server) init() {
	srv.shutdown = make(chan struct{})
	srv.conns = make(map[net.Conn]struct{})

	if srv.Handler == nil {
		srv.Handler = DefaultServeMux
	}
	if srv.MsgAcceptFunc == nil {
		srv.MsgAcceptFunc = DefaultMsgAcceptFunc
	}
	if srv.MsgInvalidFunc == nil {
		srv.MsgInvalidFunc = DefaultMsgInvalidFunc
	}

	// UDPSize must be settled before the pool constructor captures it.
	if srv.UDPSize == 0 {
		srv.UDPSize = MinMsgSize
	}
	srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
}
// unlockOnce returns a function that calls l.Unlock the first time it is
// invoked and does nothing on later invocations. This lets a caller release a
// lock early and still keep a deferred release for the error paths.
func unlockOnce(l sync.Locker) func() {
	once := new(sync.Once)
	return func() {
		once.Do(func() { l.Unlock() })
	}
}
// ListenAndServe opens a listener of kind srv.Net on srv.Addr (":domain" when
// empty) and serves requests on it. It blocks until the serve loop exits and
// returns that loop's error. It fails without serving if the server is
// already started, if srv.Net is not a supported network, or, for the *-tls
// networks, if TLSConfig has no certificate source.
func (srv *Server) ListenAndServe() error {
	// unlock releases srv.lock at most once. Each successful branch calls it
	// explicitly before entering the long-running serve loop; the deferred
	// call covers every early-return path.
	unlock := unlockOnce(&srv.lock)
	srv.lock.Lock()
	defer unlock()

	if srv.started {
		return &Error{err: "server already started"}
	}

	addr := srv.Addr
	if addr == "" {
		addr = ":domain"
	}

	srv.init()

	switch srv.Net {
	case "tcp", "tcp4", "tcp6":
		l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
		if err != nil {
			return err
		}
		srv.Listener = l
		srv.started = true
		unlock()
		return srv.serveTCP(l)
	case "tcp-tls", "tcp4-tls", "tcp6-tls":
		if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
			return errors.New("dns: neither Certificates nor GetCertificate set in Config")
		}
		// Listen on the plain TCP network ("tcp", "tcp4" or "tcp6") and wrap
		// the listener in TLS.
		network := strings.TrimSuffix(srv.Net, "-tls")
		l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
		if err != nil {
			return err
		}
		l = tls.NewListener(l, srv.TLSConfig)
		srv.Listener = l
		srv.started = true
		unlock()
		return srv.serveTCP(l)
	case "udp", "udp4", "udp6":
		l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
		if err != nil {
			return err
		}
		// NOTE(review): unchecked type assertion; assumes listenUDP always
		// returns a *net.UDPConn for these networks — confirm.
		u := l.(*net.UDPConn)
		if e := setUDPSocketOptions(u); e != nil {
			u.Close()
			return e
		}
		srv.PacketConn = l
		srv.started = true
		unlock()
		return srv.serveUDP(u)
	}
	return &Error{err: "bad network"}
}
// ActivateAndServe serves on a connection the caller has already opened
// instead of opening one itself. srv.PacketConn takes precedence over
// srv.Listener. Like ListenAndServe, it blocks until the serve loop exits. It
// fails if the server is already started or if neither field is set.
func (srv *Server) ActivateAndServe() error {
	// Released explicitly before serving; the defer covers early returns.
	unlock := unlockOnce(&srv.lock)
	srv.lock.Lock()
	defer unlock()

	if srv.started {
		return &Error{err: "server already started"}
	}

	srv.init()

	if srv.PacketConn != nil {
		// Socket options are only applied to a *net.UDPConn; other
		// PacketConn implementations are served as they are.
		if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
			if e := setUDPSocketOptions(t); e != nil {
				return e
			}
		}

		srv.started = true
		unlock()
		return srv.serveUDP(srv.PacketConn)
	}
	if srv.Listener != nil {
		srv.started = true
		unlock()
		return srv.serveTCP(srv.Listener)
	}
	return &Error{err: "bad listeners"}
}
// Shutdown stops the server with no time limit. It is ShutdownContext with a
// background context, so it waits until the serve loop has finished.
func (srv *Server) Shutdown() error {
	ctx := context.Background()
	return srv.ShutdownContext(ctx)
}
// ShutdownContext gracefully stops a running server.
//
// It marks the server as stopped and moves the read deadline of the
// PacketConn and of every tracked TCP connection into the past so that
// blocked reads return. It closes the Listener. It then waits until the serve
// loop has finished all in-flight requests or ctx is done, whichever comes
// first, and finally closes the PacketConn. It returns ctx.Err() if ctx ended
// the wait, nil after a clean shutdown, or an error if the server was not
// started.
func (srv *Server) ShutdownContext(ctx context.Context) error {
	srv.lock.Lock()
	if !srv.started {
		srv.lock.Unlock()
		return &Error{err: "server not started"}
	}

	srv.started = false

	if srv.PacketConn != nil {
		srv.PacketConn.SetReadDeadline(aLongTimeAgo) // unblock pending reads
	}

	if srv.Listener != nil {
		srv.Listener.Close()
	}

	for rw := range srv.conns {
		rw.SetReadDeadline(aLongTimeAgo) // unblock pending reads
	}

	srv.lock.Unlock()

	if testShutdownNotify != nil {
		testShutdownNotify.Broadcast()
	}

	// srv.shutdown is closed by serveTCP/serveUDP after their request
	// goroutines have all finished.
	var ctxErr error
	select {
	case <-srv.shutdown:
	case <-ctx.Done():
		ctxErr = ctx.Err()
	}

	// serveUDP also closes its conn on exit; presumably this Close covers the
	// case where ctx expired first.
	// NOTE(review): PacketConn is read here without holding srv.lock.
	if srv.PacketConn != nil {
		srv.PacketConn.Close()
	}

	return ctxErr
}
// testShutdownNotify, when non-nil, is broadcast by ShutdownContext right after
// the server has been marked stopped and pending reads have been unblocked.
// It is a hook, presumably set only by tests.
var testShutdownNotify *sync.Cond
// getReadTimeout returns ReadTimeout, or dnsTimeout when ReadTimeout is zero.
func (srv *Server) getReadTimeout() time.Duration {
	if rt := srv.ReadTimeout; rt != 0 {
		return rt
	}
	return dnsTimeout
}
// serveTCP accepts connections on l and serves each one in its own goroutine
// until the server is shut down or Accept fails with a non-temporary error.
// It closes l on return. Once every connection goroutine has finished, it
// closes srv.shutdown so that ShutdownContext can return.
//
// Temporary Accept errors that are not timeouts (for example running out of
// file descriptors) are retried after an exponential backoff. The backoff
// starts at 5ms and doubles up to 1s, which avoids spinning a CPU core while
// the condition persists. It resets after every successful Accept. A shutdown
// that arrives during a backoff sleep is noticed once the sleep ends.
func (srv *Server) serveTCP(l net.Listener) error {
	defer l.Close()

	if srv.NotifyStartedFunc != nil {
		srv.NotifyStartedFunc()
	}

	var wg sync.WaitGroup
	defer func() {
		wg.Wait()
		close(srv.shutdown)
	}()

	var backoff time.Duration
	for srv.isStarted() {
		rw, err := l.Accept()
		if err != nil {
			if !srv.isStarted() {
				return nil
			}
			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
				// Timeouts are retried immediately, as before. Other temporary
				// errors tend to repeat until resources free up, so back off.
				if !neterr.Timeout() {
					if backoff == 0 {
						backoff = 5 * time.Millisecond
					} else {
						backoff *= 2
					}
					if backoff > time.Second {
						backoff = time.Second
					}
					time.Sleep(backoff)
				}
				continue
			}
			return err
		}
		backoff = 0

		// Track the connection so ShutdownContext can unblock its reads.
		srv.lock.Lock()
		srv.conns[rw] = struct{}{}
		srv.lock.Unlock()

		wg.Add(1)
		go srv.serveTCPConn(&wg, rw)
	}

	return nil
}
// serveUDP reads packets from l and serves each request in its own goroutine
// until the server is shut down or a read fails with a non-temporary error.
// It closes l on return. Once every request goroutine has finished, it closes
// srv.shutdown so that ShutdownContext can return.
//
// A *net.UDPConn is read through Reader.ReadUDP. Any other net.PacketConn
// needs a Reader that also implements PacketConnReader; otherwise an error is
// returned before serving starts.
func (srv *Server) serveUDP(l net.PacketConn) error {
	defer l.Close()

	reader := Reader(defaultReader{srv})
	if srv.DecorateReader != nil {
		reader = srv.DecorateReader(reader)
	}

	lUDP, isUDP := l.(*net.UDPConn)
	readerPC, canPacketConn := reader.(PacketConnReader)
	if !isUDP && !canPacketConn {
		return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
	}

	if srv.NotifyStartedFunc != nil {
		srv.NotifyStartedFunc()
	}

	var wg sync.WaitGroup
	defer func() {
		wg.Wait()
		close(srv.shutdown)
	}()

	rtimeout := srv.getReadTimeout()
	// deadline is not used here, since we only read one packet at a time
	for srv.isStarted() {
		var (
			m    []byte
			sPC  net.Addr
			sUDP *SessionUDP
			err  error
		)
		if isUDP {
			m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
		} else {
			m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
		}
		if err != nil {
			if !srv.isStarted() {
				return nil
			}
			// Temporary errors include read-deadline timeouts when no packet
			// arrived within rtimeout; those simply restart the read.
			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
				continue
			}
			return err
		}
		if len(m) < headerSize {
			// Too short for a DNS header. Recycle the buffer if it has the
			// pool's capacity, then report the packet.
			if cap(m) == srv.UDPSize {
				srv.udpPool.Put(m[:srv.UDPSize])
			}
			srv.MsgInvalidFunc(m, ErrShortRead)
			continue
		}
		wg.Add(1)
		go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
	}

	return nil
}
// serveTCPConn serves the queries arriving on one TCP connection, one after
// another. It stops when the per-connection query limit is reached, the server
// is no longer started, a read fails, or the handler closes or hijacks the
// connection. It then closes the connection (unless hijacked), removes it from
// srv.conns and marks wg done.
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
	defer wg.Done()

	resp := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
	if decorate := srv.DecorateWriter; decorate != nil {
		resp.writer = decorate(resp)
	} else {
		resp.writer = resp
	}

	var rd Reader = defaultReader{srv}
	if decorate := srv.DecorateReader; decorate != nil {
		rd = decorate(rd)
	}

	idle := tcpIdleTimeout
	if srv.IdleTimeout != nil {
		idle = srv.IdleTimeout()
	}

	readTimeout := srv.getReadTimeout()

	maxQueries := srv.MaxTCPQueries
	if maxQueries == 0 {
		maxQueries = maxTCPQueries
	}
	unlimited := maxQueries == -1

	for served := 0; unlimited || served < maxQueries; served++ {
		if !srv.isStarted() {
			break
		}
		msg, err := rd.ReadTCP(resp.tcp, readTimeout)
		if err != nil {
			break
		}
		srv.serveDNS(msg, resp)
		if resp.closed || resp.hijacked {
			break
		}
		// The first query waits up to the read timeout; later ones wait up to
		// the idle timeout.
		readTimeout = idle
	}

	if !resp.hijacked {
		resp.Close()
	}

	srv.lock.Lock()
	delete(srv.conns, resp.tcp)
	srv.lock.Unlock()
}
// serveUDPPacket serves the single UDP request m. The reply goes back out on
// u: through udpSession when the request came from a *net.UDPConn, otherwise
// to pcSession. wg is marked done once the request has been handled.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
	defer wg.Done()

	resp := &response{
		tsigProvider: srv.tsigProvider(),
		udp:          u,
		udpSession:   udpSession,
		pcSession:    pcSession,
	}
	if decorate := srv.DecorateWriter; decorate != nil {
		resp.writer = decorate(resp)
	} else {
		resp.writer = resp
	}

	srv.serveDNS(m, resp)
}
// serveDNS processes one raw request m and, if it is accepted, hands it to
// the Handler.
//
// The header is decoded first and passed to MsgAcceptFunc, whose verdict
// decides what happens next:
//   - MsgAccept: the whole message is unpacked. If that fails,
//     MsgInvalidFunc is told and the request is answered as a rejection.
//   - MsgReject / MsgRejectNotImplemented: a reply built with
//     SetRcodeFormatError is sent with empty sections. For
//     MsgRejectNotImplemented the original opcode is kept and the rcode is
//     RcodeNotImplemented.
//   - MsgIgnore: nothing is sent.
//
// For UDP requests the pooled buffer is recycled as soon as m is no longer
// needed, which is before the Handler runs. When a TSIG provider is
// configured and the request carries TSIG, the verification result is stored
// in w.tsigStatus.
func (srv *Server) serveDNS(m []byte, w *response) {
	dh, off, err := unpackMsgHdr(m, 0)
	if err != nil {
		// No reply is sent for an unparsable header.
		// NOTE(review): unlike the other exits, this path does not return a
		// UDP buffer to udpPool — confirm whether that is intentional.
		srv.MsgInvalidFunc(m, err)
		return
	}

	req := new(Msg)
	req.setHdr(dh)

	switch action := srv.MsgAcceptFunc(dh); action {
	case MsgAccept:
		err := req.unpack(dh, m, off)
		if err == nil {
			// break leaves the switch; processing continues below it.
			break
		}

		srv.MsgInvalidFunc(m, err)
		// An unpackable accepted message is answered like a rejection.
		fallthrough
	case MsgReject, MsgRejectNotImplemented:
		opcode := req.Opcode
		req.SetRcodeFormatError(req)
		req.Zero = false
		if action == MsgRejectNotImplemented {
			req.Opcode = opcode
			req.Rcode = RcodeNotImplemented
		}

		// The rejection carries only the header and question.
		req.Ns, req.Answer, req.Extra = nil, nil, nil

		w.WriteMsg(req)
		// Fall through to recycle the UDP buffer.
		fallthrough
	case MsgIgnore:
		if w.udp != nil && cap(m) == srv.UDPSize {
			srv.udpPool.Put(m[:srv.UDPSize])
		}

		return
	}

	// w is reused across queries on a TCP connection, so reset the TSIG
	// status for this request.
	w.tsigStatus = nil
	if w.tsigProvider != nil {
		if t := req.IsTsig(); t != nil {
			w.tsigStatus = TsigVerifyWithProvider(m, w.tsigProvider, "", false)
			w.tsigTimersOnly = false
			w.tsigRequestMAC = t.MAC
		}
	}

	// m has been unpacked and TSIG-checked, so the UDP buffer can be recycled
	// before the handler runs (assumes req does not alias m — confirm).
	if w.udp != nil && cap(m) == srv.UDPSize {
		srv.udpPool.Put(m[:srv.UDPSize])
	}

	srv.Handler.ServeDNS(w, req)
}
// readTCP reads one length-prefixed DNS message from conn: a two-byte
// big-endian length followed by that many bytes. The read deadline is
// refreshed to now+timeout only while the server is running.
//
// The prefix is read with io.ReadFull into a local array and decoded with
// binary.BigEndian. This avoids binary.Read, which boxes its destination and
// allocates a scratch slice on every call. Errors are the same as
// before: io.EOF if the connection ends before the prefix, and
// io.ErrUnexpectedEOF if it ends partway through the prefix or the body.
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
	// ShutdownContext may already have moved the deadline into the past to
	// unblock this read. Overriding it would keep the read blocked, and with
	// it the shutdown, so only refresh it while the server is still started.
	srv.lock.RLock()
	if srv.started {
		conn.SetReadDeadline(time.Now().Add(timeout))
	}
	srv.lock.RUnlock()

	var prefix [2]byte
	if _, err := io.ReadFull(conn, prefix[:]); err != nil {
		return nil, err
	}

	m := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	if _, err := io.ReadFull(conn, m); err != nil {
		return nil, err
	}
	return m, nil
}
// readUDP reads one datagram from conn into a buffer taken from udpPool. It
// returns the filled part of the buffer and the session needed to reply. On
// error the buffer goes back to the pool. The read deadline is refreshed only
// while the server is running, so a shutdown-imposed past deadline is kept.
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
	srv.lock.RLock()
	if srv.started {
		conn.SetReadDeadline(time.Now().Add(timeout))
	}
	srv.lock.RUnlock()

	buf := srv.udpPool.Get().([]byte)
	n, session, err := ReadFromSessionUDP(conn, buf)
	if err != nil {
		srv.udpPool.Put(buf)
		return nil, nil, err
	}
	return buf[:n], session, nil
}
// readPacketConn reads one packet from a generic net.PacketConn into a buffer
// taken from udpPool. It returns the filled part of the buffer and the
// sender's address. On error the buffer goes back to the pool. The read
// deadline is refreshed only while the server is running.
func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
	srv.lock.RLock()
	if srv.started {
		conn.SetReadDeadline(time.Now().Add(timeout))
	}
	srv.lock.RUnlock()

	buf := srv.udpPool.Get().([]byte)
	n, from, err := conn.ReadFrom(buf)
	if err != nil {
		srv.udpPool.Put(buf)
		return nil, nil, err
	}
	return buf[:n], from, nil
}
// WriteMsg serializes m and sends it through w.writer. When a TSIG provider is
// configured and m carries a TSIG record, m is signed with
// TsigGenerateWithProvider, and the returned MAC is kept for chaining into the
// next signature. Otherwise m is simply packed. It fails if the response is
// already closed.
func (w *response) WriteMsg(m *Msg) (err error) {
	if w.closed {
		return &Error{err: "WriteMsg called after Close"}
	}

	var data []byte
	if w.tsigProvider != nil && m.IsTsig() != nil {
		data, w.tsigRequestMAC, err = TsigGenerateWithProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
	} else {
		data, err = m.Pack()
	}
	if err != nil {
		return err
	}

	_, err = w.writer.Write(data)
	return err
}
// Write sends the packed DNS message m to the requester.
//
// Over UDP, m is written as one datagram: through the UDP session for a
// *net.UDPConn, or to pcSession for other packet conns. Over TCP, m is
// written with the two-byte big-endian length prefix that DNS over TCP
// requires; messages longer than MaxMsgSize are rejected.
//
// The returned count is the number of bytes of m that were written. It never
// includes the length prefix, so 0 <= n <= len(m) as io.Writer requires.
// Wrappers such as bufio.Writer and io.Copy rely on that bound. Write fails
// if the response is already closed.
func (w *response) Write(m []byte) (int, error) {
	if w.closed {
		return 0, &Error{err: "Write called after Close"}
	}

	switch {
	case w.udp != nil:
		if u, ok := w.udp.(*net.UDPConn); ok {
			return WriteToSessionUDP(u, m, w.udpSession)
		}
		return w.udp.WriteTo(m, w.pcSession)
	case w.tcp != nil:
		if len(m) > MaxMsgSize {
			return 0, &Error{err: "message too large"}
		}

		msg := make([]byte, 2+len(m))
		binary.BigEndian.PutUint16(msg, uint16(len(m)))
		copy(msg[2:], m)

		n, err := w.tcp.Write(msg)
		// Report only payload bytes; a write that stopped inside the prefix
		// wrote nothing of m.
		n -= 2
		if n < 0 {
			n = 0
		}
		return n, err
	default:
		panic("dns: internal error: udp and tcp both nil")
	}
}
// LocalAddr returns the local address of the connection the request arrived
// on. Having neither a UDP nor a TCP connection is an internal error.
func (w *response) LocalAddr() net.Addr {
	if w.udp != nil {
		return w.udp.LocalAddr()
	}
	if w.tcp != nil {
		return w.tcp.LocalAddr()
	}
	panic("dns: internal error: udp and tcp both nil")
}
// RemoteAddr returns the requester's address. It comes from the UDP session,
// the packet-conn peer address, or the TCP connection, checked in that order.
func (w *response) RemoteAddr() net.Addr {
	if s := w.udpSession; s != nil {
		return s.RemoteAddr()
	}
	if w.pcSession != nil {
		return w.pcSession
	}
	if w.tcp != nil {
		return w.tcp.RemoteAddr()
	}
	panic("dns: internal error: udpSession, pcSession and tcp are all nil")
}
// TsigStatus returns the result of verifying the request's TSIG record. It is
// nil when verification succeeded or when no verification took place.
func (w *response) TsigStatus() error {
	return w.tsigStatus
}

// TsigTimersOnly sets the timers-only flag that is passed to TSIG generation
// when a signed reply is written.
func (w *response) TsigTimersOnly(b bool) {
	w.tsigTimersOnly = b
}

// Hijack takes the connection away from the server: for TCP, the server
// stops reading further queries and leaves the connection open.
func (w *response) Hijack() {
	w.hijacked = true
}
// Close marks the response closed. For TCP it also closes the connection. For
// UDP there is nothing to close: the packet conn is shared by all requests
// and is closed by the serve loop. Closing twice returns an error.
func (w *response) Close() error {
	if w.closed {
		return &Error{err: "connection already closed"}
	}
	w.closed = true

	if w.udp != nil {
		return nil
	}
	if w.tcp != nil {
		return w.tcp.Close()
	}
	panic("dns: internal error: udp and tcp both nil")
}
// ConnectionState returns a copy of the TLS state when the TCP connection
// exposes one (as *tls.Conn does). For UDP or plain TCP it returns nil.
func (w *response) ConnectionState() *tls.ConnectionState {
	type tlsConnectionStater interface {
		ConnectionState() tls.ConnectionState
	}

	v, ok := w.tcp.(tlsConnectionStater)
	if !ok {
		return nil
	}
	state := v.ConnectionState()
	return &state
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.