// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package netctx wraps common net interfaces using context.Context.
package netctx

import (
	"context"
	"errors"
	"io"
	"net"
	"sync"
	"time"
)

// ErrClosing is returned by WriteContext on a closed connection.
// (ReadContext returns net.ErrClosed in the same situation.)
var ErrClosing = errors.New("use of closed network connection")

// Reader is an interface for context controlled reader.
type Reader interface {
	ReadContext(context.Context, []byte) (int, error)
}

// Writer is an interface for context controlled writer.
type Writer interface {
	WriteContext(context.Context, []byte) (int, error)
}

// ReadWriter is a composite of Reader and Writer.
type ReadWriter interface {
	Reader
	Writer
}

// Conn is a wrapper of net.Conn using context.Context.
type Conn interface {
	Reader
	Writer
	io.Closer
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
	Conn() net.Conn
}

// conn is the Conn implementation returned by NewConn.
type conn struct {
	nextConn  net.Conn      // wrapped connection; all I/O is delegated to it
	closed    chan struct{} // closed exactly once, by Close
	closeOnce sync.Once     // guards close(closed)
	readMu    sync.Mutex    // serializes ReadContext calls
	writeMu   sync.Mutex    // serializes WriteContext calls
}

// veryOld is a deadline in the past. Installing it as a read or write deadline
// makes blocked I/O on the wrapped connection return immediately. The zero
// time.Time cannot be used for this because it means "no deadline".
var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals

// NewConn creates a new Conn wrapping given net.Conn.
func NewConn(netConn net.Conn) Conn {
	return &conn{
		nextConn: netConn,
		closed:   make(chan struct{}),
	}
}

// ReadContext reads data from the connection.
// Unlike net.Conn.Read(), the provided context is used to control timeout.
// It returns net.ErrClosed once Close has been called.
func (c *conn) ReadContext(ctx context.Context, b []byte) (int, error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()

	select {
	case <-c.closed:
		return 0, net.ErrClosed
	default:
	}

	return ioWithContext(ctx, c.nextConn.SetReadDeadline, func() (int, error) {
		return c.nextConn.Read(b)
	})
}

// WriteContext writes data to the connection.
// Unlike net.Conn.Write(), the provided context is used to control timeout.
// It returns ErrClosing once Close has been called.
func (c *conn) WriteContext(ctx context.Context, b []byte) (int, error) {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	select {
	case <-c.closed:
		return 0, ErrClosing
	default:
	}

	return ioWithContext(ctx, c.nextConn.SetWriteDeadline, func() (int, error) {
		return c.nextConn.Write(b)
	})
}

// ioWithContext runs op, which is a single Read or Write on the wrapped
// connection, and aborts it if ctx is done before op returns.
//
// A watcher goroutine implements the cancellation. When ctx is done, the
// watcher moves the deadline (through setDeadline) into the past. It then
// waits for op to return and clears the deadline again, so that later
// operations are unaffected.
//
// If ctx has been canceled or has expired when op returns, and op transferred
// no bytes, ctx.Err() replaces op's error. A failure to set or clear the
// deadline is reported only when op itself returned no error.
func ioWithContext(ctx context.Context, setDeadline func(time.Time) error, op func() (int, error)) (int, error) {
	done := make(chan struct{})
	var wg sync.WaitGroup
	// errDeadline is written only by the watcher goroutine and read only after
	// wg.Wait, which orders that write before the read; no atomic is needed.
	var errDeadline error

	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			if err := setDeadline(veryOld); err != nil {
				errDeadline = err
				return
			}
			// Let op observe the expired deadline and return before the
			// deadline is cleared.
			<-done
			if err := setDeadline(time.Time{}); err != nil {
				errDeadline = err
			}
		case <-done:
		}
	}()

	n, err := op()

	close(done)
	wg.Wait()

	if ctxErr := ctx.Err(); ctxErr != nil && n == 0 {
		err = ctxErr
	}
	if errDeadline != nil && err == nil {
		err = errDeadline
	}
	return n, err
}

// Close closes the connection.
// Any blocked ReadContext or WriteContext operations will be unblocked and
// return errors.
func (c *conn) Close() error {
	err := c.nextConn.Close()
	c.closeOnce.Do(func() {
		// Closing nextConn above unblocks in-flight I/O. Taking both locks
		// waits for those calls to finish before the connection is marked
		// closed for later ReadContext/WriteContext calls.
		c.writeMu.Lock()
		c.readMu.Lock()
		close(c.closed)
		c.readMu.Unlock()
		c.writeMu.Unlock()
	})
	return err
}

// LocalAddr returns the local network address, if known.
func (c *conn) LocalAddr() net.Addr {
	return c.nextConn.LocalAddr()
}

// RemoteAddr returns the remote network address, if known.
func (c *conn) RemoteAddr() net.Addr {
	return c.nextConn.RemoteAddr()
}

// Conn returns the underlying net.Conn.
func (c *conn) Conn() net.Conn {
	return c.nextConn
}