package redis
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
)
// PubSub holds the state of a Pub/Sub session: the connection replies are
// read from, the subscription sets that are replayed onto every new
// connection, and the optional Go channels used to deliver messages.
type PubSub struct {
	opt *Options

	// newConn dials a connection and is handed the channel names involved
	// (how it uses them is up to whoever constructs the PubSub — verify
	// against the callers).
	newConn   func(ctx context.Context, channels []string) (*pool.Conn, error)
	closeConn func(*pool.Conn) error

	// mu guards cn, closed and the three subscription sets.
	mu        sync.Mutex
	cn        *pool.Conn
	channels  map[string]struct{} // SUBSCRIBE-d channels
	patterns  map[string]struct{} // PSUBSCRIBE-d patterns
	schannels map[string]struct{} // SSUBSCRIBE-d sharded channels
	closed    bool

	// exit is closed by Close; the health-check goroutine returns on it.
	exit chan struct{}

	// cmd is reused by ReceiveTimeout to read every reply.
	cmd *Cmd

	// chOnce ensures only one of Channel / ChannelWithSubscriptions sets up
	// its delivery channel.
	chOnce sync.Once
	msgCh  *channel
	allCh  *channel
}
// init allocates the exit channel that Close closes to signal shutdown.
func (c *PubSub) init() {
	exit := make(chan struct{})
	c.exit = exit
}
// String describes the PubSub by listing every subscribed channel, pattern
// and sharded channel (in no particular order), e.g. "PubSub(news, sport.*)".
//
// The subscription sets are mutated under c.mu by Subscribe, Unsubscribe and
// their pattern/sharded variants. They are therefore read under the same lock
// here. Reading them unlocked is a data race and can abort the program with
// "concurrent map read and map write".
func (c *PubSub) String() string {
	c.mu.Lock()
	names := make([]string, 0, len(c.channels)+len(c.patterns)+len(c.schannels))
	names = append(names, mapKeys(c.channels)...)
	names = append(names, mapKeys(c.patterns)...)
	names = append(names, mapKeys(c.schannels)...)
	c.mu.Unlock()

	return fmt.Sprintf("PubSub(%s)", strings.Join(names, ", "))
}
// connWithLock runs conn(ctx, nil) while holding c.mu.
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.conn(ctx, nil)
}
// conn returns the current connection, dialing a new one when there is none.
//
// A freshly dialed connection is re-subscribed to every tracked channel,
// pattern and sharded channel before it is stored in c.cn. If that fails, the
// connection is closed and the error returned, leaving c.cn nil. newChannels
// are the names the caller is about to (un)subscribe; they are forwarded to
// newConn together with the tracked names. Callers in this file hold c.mu.
//
// Returns pool.ErrClosed once the PubSub has been closed.
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
	if c.closed {
		return nil, pool.ErrClosed
	}
	if c.cn != nil {
		return c.cn, nil
	}

	// newConn is told which channels the connection will carry (presumably so
	// it can pick a node, e.g. by hash slot — verify against the PubSub
	// constructors). Sharded channels are appended last. That keeps the
	// previous ordering whenever it was non-empty. It also means a reconnect of
	// a PubSub that only has SSUBSCRIBE-d channels no longer dials with an empty
	// list before resubscribe re-sends SSUBSCRIBE for them.
	channels := mapKeys(c.channels)
	channels = append(channels, newChannels...)
	channels = append(channels, mapKeys(c.schannels)...)

	cn, err := c.newConn(ctx, channels)
	if err != nil {
		return nil, err
	}

	if err := c.resubscribe(ctx, cn); err != nil {
		_ = c.closeConn(cn) // the resubscribe error is the one worth reporting
		return nil, err
	}

	c.cn = cn
	return cn, nil
}
// writeCmd serializes cmd onto cn, bounded by c.opt.WriteTimeout. No reply is
// read here.
//
// NOTE(review): ctx is not used; the write always runs under
// context.Background(), so cancelling ctx does not abort it. Confirm this is
// intentional.
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
	write := func(wr *proto.Writer) error {
		return writeCmd(wr, cmd)
	}
	return cn.WithWriter(context.Background(), c.opt.WriteTimeout, write)
}
// resubscribe replays every non-empty subscription set onto cn, in the order
// SUBSCRIBE, PSUBSCRIBE, SSUBSCRIBE. All three are attempted even if one
// fails; the first error encountered is returned.
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
	groups := [...]struct {
		redisCmd string
		names    map[string]struct{}
	}{
		{"subscribe", c.channels},
		{"psubscribe", c.patterns},
		{"ssubscribe", c.schannels},
	}

	var firstErr error
	for _, g := range groups {
		if len(g.names) == 0 {
			continue
		}
		err := c._subscribe(ctx, cn, g.redisCmd, mapKeys(g.names))
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// mapKeys returns the keys of m in unspecified order. The result is never nil,
// even for a nil or empty map.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
// _subscribe writes one "<redisCmd> name1 name2 ..." command to cn without
// waiting for a reply. The server's confirmations are read later through
// Receive, which decodes them as *Subscription.
func (c *PubSub) _subscribe(
	ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
) error {
	args := make([]interface{}, 1, 1+len(channels))
	args[0] = redisCmd
	for i := range channels {
		args = append(args, channels[i])
	}
	return c.writeCmd(ctx, cn, NewSliceCmd(ctx, args...))
}
// releaseConnWithLock runs releaseConn while holding c.mu.
func (c *PubSub) releaseConnWithLock(
	ctx context.Context,
	cn *pool.Conn,
	err error,
	allowTimeout bool,
) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.releaseConn(ctx, cn, err, allowTimeout)
}
// releaseConn is called after an operation on cn finished with err. It
// replaces the connection via reconnect only when cn is still the current
// connection and isBadConn classifies err as fatal. allowTimeout is forwarded
// to isBadConn (presumably so that timeouts can be tolerated — verify). The
// caller must hold c.mu.
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
	stale := c.cn != cn
	if stale || !isBadConn(err, allowTimeout, c.opt.Addr) {
		return
	}
	c.reconnect(ctx, err)
}
// reconnect discards the current connection and immediately tries to dial and
// re-subscribe a new one. The caller must hold c.mu.
func (c *PubSub) reconnect(ctx context.Context, reason error) {
	// Both steps are best effort. A failed close does not matter because the
	// connection is being thrown away. A failed dial leaves c.cn nil, so the
	// next conn call dials again.
	_ = c.closeTheCn(reason)
	_, _ = c.conn(ctx, nil)
}
// closeTheCn closes and forgets the current connection, if there is one.
// The reason is logged unless the PubSub itself has been closed. Returns the
// error from closeConn, or nil when there was no connection.
func (c *PubSub) closeTheCn(reason error) error {
	cn := c.cn
	if cn == nil {
		return nil
	}
	if !c.closed {
		internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
	}
	closeErr := c.closeConn(cn)
	c.cn = nil
	return closeErr
}
// Close marks the PubSub closed, signals background goroutines through
// c.exit, and closes the connection without logging. A second call returns
// pool.ErrClosed.
func (c *PubSub) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.closed = true
		close(c.exit)
		return c.closeTheCn(pool.ErrClosed)
	}
	return pool.ErrClosed
}
// Subscribe sends SUBSCRIBE for the given channels. The channels are recorded
// even when sending fails, so they are re-subscribed after a reconnect.
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	sendErr := c.subscribe(ctx, "subscribe", channels...)

	set := c.channels
	if set == nil {
		set = make(map[string]struct{})
		c.channels = set
	}
	for _, name := range channels {
		set[name] = struct{}{}
	}
	return sendErr
}
// PSubscribe sends PSUBSCRIBE for the given patterns. The patterns are
// recorded even when sending fails, so they are re-subscribed after a
// reconnect.
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	sendErr := c.subscribe(ctx, "psubscribe", patterns...)

	set := c.patterns
	if set == nil {
		set = make(map[string]struct{})
		c.patterns = set
	}
	for _, pattern := range patterns {
		set[pattern] = struct{}{}
	}
	return sendErr
}
// SSubscribe sends SSUBSCRIBE for the given sharded channels. The channels
// are recorded even when sending fails, so they are re-subscribed after a
// reconnect.
func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	sendErr := c.subscribe(ctx, "ssubscribe", channels...)

	set := c.schannels
	if set == nil {
		set = make(map[string]struct{})
		c.schannels = set
	}
	for _, name := range channels {
		set[name] = struct{}{}
	}
	return sendErr
}
// Unsubscribe sends UNSUBSCRIBE for the given channels, or for all channels
// when none are given. The channels are forgotten locally before the command
// is sent, so they are not re-subscribed after a reconnect even if sending
// fails.
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(channels) == 0 {
		for name := range c.channels {
			delete(c.channels, name)
		}
	} else {
		for _, name := range channels {
			delete(c.channels, name)
		}
	}
	return c.subscribe(ctx, "unsubscribe", channels...)
}
// PUnsubscribe sends PUNSUBSCRIBE for the given patterns, or for all patterns
// when none are given. The patterns are forgotten locally before the command
// is sent.
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(patterns) == 0 {
		for pattern := range c.patterns {
			delete(c.patterns, pattern)
		}
	} else {
		for _, pattern := range patterns {
			delete(c.patterns, pattern)
		}
	}
	return c.subscribe(ctx, "punsubscribe", patterns...)
}
// SUnsubscribe sends SUNSUBSCRIBE for the given sharded channels, or for all
// of them when none are given. The channels are forgotten locally before the
// command is sent.
func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(channels) == 0 {
		for name := range c.schannels {
			delete(c.schannels, name)
		}
	} else {
		for _, name := range channels {
			delete(c.schannels, name)
		}
	}
	return c.subscribe(ctx, "sunsubscribe", channels...)
}
// subscribe sends "<redisCmd> channels..." on the current connection,
// dialing one first if necessary. The outcome is then passed to releaseConn,
// which may replace the connection. The caller must hold c.mu.
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
	cn, err := c.conn(ctx, channels)
	if err == nil {
		err = c._subscribe(ctx, cn, redisCmd, channels)
		c.releaseConn(ctx, cn, err, false)
	}
	return err
}
// Ping writes PING on the PubSub connection. payload[0] is sent as the
// argument only when exactly one payload is given; otherwise no argument is
// sent. Only the write happens here. The reply is read through Receive, where
// it decodes to *Pong.
func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
	var args []interface{}
	if len(payload) == 1 {
		args = []interface{}{"ping", payload[0]}
	} else {
		args = []interface{}{"ping"}
	}
	cmd := NewCmd(ctx, args...)

	c.mu.Lock()
	defer c.mu.Unlock()

	cn, err := c.conn(ctx, nil)
	if err != nil {
		return err
	}
	err = c.writeCmd(ctx, cn, cmd)
	c.releaseConn(ctx, cn, err, false)
	return err
}
// Subscription is a server confirmation of a (un)subscribe command.
type Subscription struct {
	// Kind is the confirmed command: "subscribe", "unsubscribe",
	// "psubscribe", "punsubscribe", "ssubscribe" or "sunsubscribe".
	Kind string
	// Channel is the channel or pattern the confirmation refers to.
	Channel string
	// Count is the subscription count reported by the server.
	Count int
}

// String formats the confirmation as "<Kind>: <Channel>".
func (m *Subscription) String() string {
	return m.Kind + ": " + m.Channel
}
// Message is a message published to a channel the client is subscribed to.
type Message struct {
	Channel string
	// Pattern is the matching pattern for messages received via PSUBSCRIBE.
	Pattern string
	// Payload holds a string payload.
	Payload string
	// PayloadSlice holds the payload when the server sent it as an array.
	PayloadSlice []string
}

// String formats the message as "Message<channel: payload>". Pattern and
// PayloadSlice are not included.
func (m *Message) String() string {
	return "Message<" + m.Channel + ": " + m.Payload + ">"
}
// Pong is the reply to a PING sent on a PubSub connection.
type Pong struct {
	Payload string
}

// String returns "Pong" or, when a payload was echoed, "Pong<payload>".
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return "Pong<" + p.Payload + ">"
}
// newMessage decodes a raw Pub/Sub reply into *Pong, *Subscription or
// *Message.
//
// A bare string becomes *Pong. An array is dispatched on its first element
// ("subscribe", "message", "pmessage", "pong", ...).
//
// Replies that are too short for their kind, or whose elements have
// unexpected types, are reported as errors rather than panicking on an
// out-of-range index or failed type assertion. A panic here would take down
// the goroutine feeding Channel/ChannelWithSubscriptions, and with it the
// whole program.
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
	switch reply := reply.(type) {
	case string:
		return &Pong{
			Payload: reply,
		}, nil
	case []interface{}:
		// str returns reply[i] as a string; ok is false when the element is
		// missing or is not a string.
		str := func(i int) (s string, ok bool) {
			if i < len(reply) {
				s, ok = reply[i].(string)
			}
			return s, ok
		}
		malformed := func() error {
			return fmt.Errorf("redis: malformed pubsub message: %#v", reply)
		}

		kind, ok := str(0)
		if !ok {
			return nil, malformed()
		}
		switch kind {
		case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
			if len(reply) < 3 {
				return nil, malformed()
			}
			count, ok := reply[2].(int64)
			if !ok {
				return nil, malformed()
			}
			// The channel is deliberately not required to be a string: a
			// non-string element decodes as "" (presumably a nil channel,
			// e.g. unsubscribing with nothing subscribed — verify).
			channel, _ := reply[1].(string)
			return &Subscription{
				Kind:    kind,
				Channel: channel,
				Count:   int(count),
			}, nil
		case "message", "smessage":
			channel, ok := str(1)
			if !ok || len(reply) < 3 {
				return nil, malformed()
			}
			switch payload := reply[2].(type) {
			case string:
				return &Message{
					Channel: channel,
					Payload: payload,
				}, nil
			case []interface{}:
				ss := make([]string, len(payload))
				for i, elem := range payload {
					s, ok := elem.(string)
					if !ok {
						return nil, fmt.Errorf("redis: unsupported pubsub message payload element: %T", elem)
					}
					ss[i] = s
				}
				return &Message{
					Channel:      channel,
					PayloadSlice: ss,
				}, nil
			default:
				return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
			}
		case "pmessage":
			pattern, okPattern := str(1)
			channel, okChannel := str(2)
			payload, okPayload := str(3)
			if !okPattern || !okChannel || !okPayload {
				return nil, malformed()
			}
			return &Message{
				Pattern: pattern,
				Channel: channel,
				Payload: payload,
			}, nil
		case "pong":
			payload, ok := str(1)
			if !ok {
				return nil, malformed()
			}
			return &Pong{
				Payload: payload,
			}, nil
		default:
			return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
		}
	default:
		return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
	}
}
// ReceiveTimeout reads the next reply from the PubSub connection, dialing and
// re-subscribing first if needed. It decodes the reply into *Subscription,
// *Message or *Pong.
//
// timeout is passed to the reader as its read timeout (Receive passes 0).
// Read errors go to releaseConn, which may replace the connection. It is told
// to allow timeouts when timeout > 0 (presumably so an expired read deadline
// alone does not discard the connection — verify against isBadConn).
//
// Every reply is read into the same reused Cmd, so concurrent receiving is
// not safe.
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
	if c.cmd == nil {
		c.cmd = NewCmd(ctx)
	}

	cn, err := c.connWithLock(ctx)
	if err != nil {
		return nil, err
	}

	readErr := cn.WithReader(context.Background(), timeout, func(rd *proto.Reader) error {
		return c.cmd.readReply(rd)
	})
	c.releaseConnWithLock(ctx, cn, readErr, timeout > 0)
	if readErr != nil {
		return nil, readErr
	}
	return c.newMessage(c.cmd.Val())
}
// Receive is ReceiveTimeout with a zero timeout.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
	const noTimeout time.Duration = 0
	return c.ReceiveTimeout(ctx, noTimeout)
}
// ReceiveMessage blocks until the next published message arrives. It skips
// subscription confirmations and pongs, returns receive errors as is, and
// reports any other reply type as an error.
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
	for {
		reply, err := c.Receive(ctx)
		if err != nil {
			return nil, err
		}
		switch m := reply.(type) {
		case *Message:
			return m, nil
		case *Subscription, *Pong:
			// Control replies: keep waiting for a real message.
		default:
			return nil, fmt.Errorf("redis: unknown message: %T", reply)
		}
	}
}
// getContext returns the context captured by the first ReceiveTimeout call
// (via c.cmd), or context.Background() if nothing has been received yet.
func (c *PubSub) getContext() context.Context {
	if cmd := c.cmd; cmd != nil {
		return cmd.ctx
	}
	return context.Background()
}
// Channel returns a Go channel delivering published messages. Subscription
// confirmations and pongs are not delivered. The channel is created on the
// first call, and opts only take effect then. It is closed once the PubSub is
// closed. Calling Channel after ChannelWithSubscriptions panics.
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
	c.chOnce.Do(func() {
		c.msgCh = newChannel(c, opts...)
		c.msgCh.initMsgChan()
	})

	if c.msgCh != nil {
		return c.msgCh.msgCh
	}
	panic(fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions"))
}
// ChannelSize is Channel with the buffer size set to size.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
	sizeOpt := WithChannelSize(size)
	return c.Channel(sizeOpt)
}
// ChannelWithSubscriptions is like Channel, but also delivers *Subscription
// confirmations; pongs are still filtered out. The channel is created on the
// first call, and opts only take effect then. Calling it after Channel panics.
func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interface{} {
	c.chOnce.Do(func() {
		c.allCh = newChannel(c, opts...)
		c.allCh.initAllChan()
	})

	if c.allCh != nil {
		return c.allCh.allCh
	}
	panic(fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel"))
}
// ChannelOption configures the Go channel created by PubSub.Channel or
// PubSub.ChannelWithSubscriptions.
type ChannelOption func(c *channel)

// WithChannelSize sets the buffer size of the Go channel (newChannel's
// default is 100).
func WithChannelSize(size int) ChannelOption {
	return func(ch *channel) { ch.chanSize = size }
}

// WithChannelHealthCheckInterval sets how long the connection may stay idle
// before a health-check PING is sent (default 3s). A non-positive value
// disables the health check.
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
	return func(ch *channel) { ch.checkInterval = d }
}

// WithChannelSendTimeout sets how long to wait for a full Go channel before a
// message is dropped and logged (default one minute).
func WithChannelSendTimeout(d time.Duration) ChannelOption {
	return func(ch *channel) { ch.chanSendTimeout = d }
}
// channel pumps replies received by a PubSub into a Go channel (msgCh for
// Channel, allCh for ChannelWithSubscriptions) and may run a health check.
type channel struct {
	pubSub *PubSub

	msgCh chan *Message    // filled by initMsgChan
	allCh chan interface{} // filled by initAllChan
	ping  chan struct{}    // activity signal from the receive loop to the health check

	chanSize        int           // buffer size of msgCh/allCh
	chanSendTimeout time.Duration // wait for a full Go channel before dropping a message
	checkInterval   time.Duration // idle period before a health-check PING
}
// newChannel builds a channel for pubSub with the defaults (buffer 100, send
// timeout 1m, health check every 3s), then applies opts. It starts the health
// check when the resulting interval is positive.
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
	ch := new(channel)
	ch.pubSub = pubSub
	ch.chanSize = 100
	ch.chanSendTimeout = time.Minute
	ch.checkInterval = 3 * time.Second

	for _, apply := range opts {
		apply(ch)
	}

	if ch.checkInterval > 0 {
		ch.initHealthCheck()
	}
	return ch
}
// initHealthCheck starts a goroutine that sends PING whenever no activity
// has been signalled on c.ping for c.checkInterval. If Ping returns an error,
// it forces a reconnect under the PubSub lock. The goroutine returns when
// c.pubSub.exit is closed (PubSub.Close does that).
func (c *channel ) initHealthCheck () {
ctx := context .TODO ()
// Buffered with capacity 1 so the receive loop's non-blocking sends coalesce.
c .ping = make (chan struct {}, 1 )
go func () {
// Created stopped so the loop below can Reset it uniformly on every pass.
timer := time .NewTimer (time .Minute )
timer .Stop ()
for {
timer .Reset (c .checkInterval )
select {
case <- c .ping :
// Activity seen: stop the idle timer, draining C if it already
// fired, so the next Reset starts from a clean state.
if !timer .Stop () {
<-timer .C
}
case <- timer .C :
// Idle for checkInterval: probe the connection.
if pingErr := c .pubSub .Ping (ctx ); pingErr != nil {
c .pubSub .mu .Lock ()
c .pubSub .reconnect (ctx , pingErr )
c .pubSub .mu .Unlock ()
}
case <- c .pubSub .exit :
return
}
}
}()
}
// initMsgChan creates c.msgCh (buffer c.chanSize) and starts the goroutine
// that feeds it.
//
// The goroutine loops on Receive and forwards *Message values; *Subscription
// and *Pong replies are dropped. It closes c.msgCh and returns once Receive
// reports pool.ErrClosed. Any other error is retried: the first failure in a
// row is retried immediately, and later consecutive ones after 100ms.
//
// A message that cannot be delivered within c.chanSendTimeout is dropped and
// logged.
func (c *channel ) initMsgChan () {
ctx := context .TODO ()
c .msgCh = make (chan *Message , c .chanSize )
go func () {
// Created stopped; Reset before each send to bound the wait.
timer := time .NewTimer (time .Minute )
timer .Stop ()
var errCount int
for {
msg , err := c .pubSub .Receive (ctx )
if err != nil {
if err == pool .ErrClosed {
close (c .msgCh )
return
}
// Back off only on repeated consecutive failures.
if errCount > 0 {
time .Sleep (100 * time .Millisecond )
}
errCount ++
continue
}
errCount = 0
// Tell the health check (if any) that the connection is alive;
// non-blocking, and a nil c.ping makes this a no-op via default.
select {
case c .ping <- struct {}{}:
default :
}
switch msg := msg .(type ) {
case *Subscription :
case *Pong :
case *Message :
timer .Reset (c .chanSendTimeout )
select {
case c .msgCh <- msg :
// Delivered: stop the timer, draining C if it fired concurrently.
if !timer .Stop () {
<-timer .C
}
case <- timer .C :
// NOTE(review): c is the *channel, which has no String method in
// view, so %s renders it as a raw struct dump. c.pubSub (whose
// String lists the subscriptions) was probably intended, but
// PubSub.String reads maps guarded by c.pubSub.mu; only switch
// once that read is race-free.
internal .Logger .Printf (
ctx , "redis: %s channel is full for %s (message is dropped)" ,
c , c .chanSendTimeout )
}
default :
internal .Logger .Printf (ctx , "redis: unknown message type: %T" , msg )
}
}
}()
}
// initAllChan creates c.allCh (buffer c.chanSize) and starts the goroutine
// that feeds it.
//
// The goroutine loops on Receive and forwards both *Subscription and
// *Message values; *Pong replies are dropped. It closes c.allCh and returns
// once Receive reports pool.ErrClosed. Any other error is retried: the first
// failure in a row is retried immediately, and later consecutive ones after
// 100ms.
//
// A value that cannot be delivered within c.chanSendTimeout is dropped and
// logged.
func (c *channel ) initAllChan () {
ctx := context .TODO ()
c .allCh = make (chan interface {}, c .chanSize )
go func () {
// Created stopped; Reset before each send to bound the wait.
timer := time .NewTimer (time .Minute )
timer .Stop ()
var errCount int
for {
msg , err := c .pubSub .Receive (ctx )
if err != nil {
if err == pool .ErrClosed {
close (c .allCh )
return
}
// Back off only on repeated consecutive failures.
if errCount > 0 {
time .Sleep (100 * time .Millisecond )
}
errCount ++
continue
}
errCount = 0
// Tell the health check (if any) that the connection is alive;
// non-blocking, and a nil c.ping makes this a no-op via default.
select {
case c .ping <- struct {}{}:
default :
}
switch msg := msg .(type ) {
case *Pong :
case *Subscription , *Message :
timer .Reset (c .chanSendTimeout )
select {
case c .allCh <- msg :
// Delivered: stop the timer, draining C if it fired concurrently.
if !timer .Stop () {
<-timer .C
}
case <- timer .C :
// NOTE(review): c is the *channel, which has no String method in
// view, so %s renders it as a raw struct dump. c.pubSub was
// probably intended, but PubSub.String reads maps guarded by
// c.pubSub.mu; only switch once that read is race-free.
internal .Logger .Printf (
ctx , "redis: %s channel is full for %s (message is dropped)" ,
c , c .chanSendTimeout )
}
default :
internal .Logger .Printf (ctx , "redis: unknown message type: %T" , msg )
}
}
}()
}
The pages are generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the QR code on the left) to get the latest news about Golds.