package pubsub
import (
"context"
"fmt"
"runtime"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
)
// Default sizing for the validation pipeline; each can be overridden through
// the corresponding option in this file.
const (
// defaultValidateQueueSize is the buffer size of validateQ, the channel that
// feeds the validation worker goroutines (overridable with WithValidateQueueSize).
defaultValidateQueueSize = 32
// defaultValidateConcurrency is the per-validator limit on concurrent calls,
// used by makeValidator unless a positive throttle was requested
// (WithValidatorConcurrency).
defaultValidateConcurrency = 1024
// defaultValidateThrottle is the global limit on in-flight asynchronous
// validation goroutines (overridable with WithValidateThrottle).
defaultValidateThrottle = 8192
)
// ValidationError is the error returned by the validation pipeline when a
// message does not pass. In this file Reason is always one of the tracer's
// Reject* constants (RejectInvalidSignature, RejectValidationFailed or
// RejectValidationIgnored).
type ValidationError struct {
Reason string
}
// Error implements the error interface by returning the rejection reason.
func (e ValidationError ) Error () string {
return e .Reason
}
// Validator is a boolean message validator: makeValidator adapts it so that
// true means ValidationAccept and false means ValidationReject.
type Validator func (context .Context , peer .ID , *Message ) bool
// ValidatorEx is an extended validator returning a ValidationResult, which lets
// it ignore a message instead of only accepting or rejecting it.
type ValidatorEx func (context .Context , peer .ID , *Message ) ValidationResult
// ValidationResult is the verdict produced by a ValidatorEx.
type ValidationResult int
// The numeric values are explicit; validateMsg treats any other value returned
// by a user validator as ValidationIgnore.
const (
// ValidationAccept lets the message through.
ValidationAccept = ValidationResult (0 )
// ValidationReject drops the message; it is traced as RejectValidationFailed.
ValidationReject = ValidationResult (1 )
// ValidationIgnore drops the message; it is traced as RejectValidationIgnored
// rather than RejectValidationFailed.
ValidationIgnore = ValidationResult (2 )
// validationThrottled is internal only: a validator could not run because its
// concurrency limit was reached.
validationThrottled = ValidationResult (-1 )
)
// ValidatorOpt customizes a validator registration request (see
// WithValidatorTimeout, WithValidatorConcurrency and WithValidatorInline).
type ValidatorOpt func (addVal *addValReq ) error
// validation holds the registered validators and the machinery (queue,
// workers, throttles) used to validate incoming and locally published messages.
type validation struct {
// p is the owning PubSub; set by Start.
p *PubSub
// tracer is copied from p.tracer by Start.
tracer *pubsubTracer
// mx guards topicVals; getValidators also reads defaultVals under it.
// NOTE(review): WithDefaultValidator appends to defaultVals without mx,
// presumably because options are applied before Start — confirm.
mx sync .Mutex
// topicVals maps a topic to its single registered validator.
topicVals map [string ]*validatorImpl
// defaultVals apply to every message regardless of topic.
defaultVals []*validatorImpl
// validateQ carries requests from Push to the validateWorker goroutines.
validateQ chan *validateReq
// validateThrottle is a counting semaphore (send acquires, receive releases)
// bounding concurrent asynchronous validation goroutines.
validateThrottle chan struct {}
// validateWorkers is the number of validateWorker goroutines started by Start.
validateWorkers int
}
// validateReq is one unit of work queued by Push for a validateWorker.
// Push builds it with a positional composite literal, so keep the field order.
type validateReq struct {
// vals are the validators that apply to msg (defaults plus topic validator).
vals []*validatorImpl
// src is the peer the message was received from.
src peer .ID
msg *Message
}
// validatorImpl is a registered validator normalized by makeValidator.
type validatorImpl struct {
// topic is the topic the validator is registered for; "" for default validators.
topic string
// validate is the user validator, adapted to the ValidatorEx signature.
validate ValidatorEx
// validateTimeout is applied to each call by validateMsg; 0 means no timeout.
validateTimeout time .Duration
// validateThrottle is a semaphore limiting concurrent calls of this validator.
validateThrottle chan struct {}
// validateInline makes the validator run directly inside validate instead of
// in the asynchronous goroutine.
validateInline bool
}
// addValReq is a request to register a validator; it is also the value that
// ValidatorOpt functions mutate.
type addValReq struct {
// topic is the topic to validate; empty for default validators.
topic string
// validate must be a Validator, a ValidatorEx, or a func with one of their
// signatures (checked in makeValidator).
validate interface {}
// timeout is the per-call timeout; values <= 0 mean none.
timeout time .Duration
// throttle is the per-validator concurrency limit; values <= 0 mean
// defaultValidateConcurrency.
throttle int
// inline requests synchronous execution inside validate.
inline bool
// resp receives exactly one value from AddValidator: nil or the error.
resp chan error
}
// rmValReq is a request to unregister the validator of a topic.
type rmValReq struct {
topic string
// resp receives exactly one value from RemoveValidator: nil or the error.
resp chan error
}
// newValidation returns a validation configured with the default queue size,
// the default global throttle and one worker per CPU. The owning PubSub and
// its tracer are attached later by Start.
func newValidation() *validation {
	v := new(validation)
	v.topicVals = make(map[string]*validatorImpl)
	v.validateQ = make(chan *validateReq, defaultValidateQueueSize)
	v.validateThrottle = make(chan struct{}, defaultValidateThrottle)
	v.validateWorkers = runtime.NumCPU()
	return v
}
// Start attaches v to p and launches validateWorkers goroutines. Each one
// drains validateQ until p's context is done.
func (v *validation) Start(p *PubSub) {
	v.p, v.tracer = p, p.tracer
	for remaining := v.validateWorkers; remaining > 0; remaining-- {
		go v.validateWorker()
	}
}
// AddValidator registers the validator described by req for its topic.
// Exactly one value is sent on req.resp:
//   - the error from makeValidator if the validator type is invalid;
//   - an error if the topic already has a validator;
//   - nil on success.
func (v *validation) AddValidator(req *addValReq) {
	impl, err := v.makeValidator(req)
	if err != nil {
		req.resp <- err
		return
	}

	v.mx.Lock()
	defer v.mx.Unlock()

	if _, dup := v.topicVals[impl.topic]; dup {
		req.resp <- fmt.Errorf("duplicate validator for topic %s", impl.topic)
		return
	}
	v.topicVals[impl.topic] = impl
	req.resp <- nil
}
// makeValidator normalizes req into a validatorImpl.
//
// req.validate must be a Validator, a ValidatorEx, or an unnamed func with one
// of their signatures; any other type yields an error. Boolean validators are
// adapted so that true becomes ValidationAccept and false ValidationReject.
// A positive req.timeout becomes the per-call timeout. A positive req.throttle
// becomes the per-validator concurrency limit; otherwise the limit is
// defaultValidateConcurrency.
func (v *validation) makeValidator(req *addValReq) (*validatorImpl, error) {
	fromBool := func(fn Validator) ValidatorEx {
		return func(ctx context.Context, pid peer.ID, m *Message) ValidationResult {
			if !fn(ctx, pid, m) {
				return ValidationReject
			}
			return ValidationAccept
		}
	}

	var exFn ValidatorEx
	switch fn := req.validate.(type) {
	case func(context.Context, peer.ID, *Message) bool:
		exFn = fromBool(Validator(fn))
	case Validator:
		exFn = fromBool(fn)
	case func(context.Context, peer.ID, *Message) ValidationResult:
		exFn = ValidatorEx(fn)
	case ValidatorEx:
		exFn = fn
	default:
		label := req.topic
		if label == "" {
			label = "(default)"
		}
		return nil, fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", label)
	}

	// Decide the semaphore size up front so the channel is allocated once.
	limit := defaultValidateConcurrency
	if req.throttle > 0 {
		limit = req.throttle
	}

	impl := &validatorImpl{
		topic:            req.topic,
		validate:         exFn,
		validateThrottle: make(chan struct{}, limit),
		validateInline:   req.inline,
	}
	if req.timeout > 0 {
		impl.validateTimeout = req.timeout
	}
	return impl, nil
}
// RemoveValidator unregisters the validator for req.topic. It sends nil on
// req.resp on success, or an error if no validator is registered for the topic.
func (v *validation) RemoveValidator(req *rmValReq) {
	v.mx.Lock()
	defer v.mx.Unlock()

	if _, exists := v.topicVals[req.topic]; !exists {
		req.resp <- fmt.Errorf("no validator for topic %s", req.topic)
		return
	}
	delete(v.topicVals, req.topic)
	req.resp <- nil
}
// PushLocal validates a locally published message synchronously: every
// applicable validator runs inline, and the returned error reports the outcome.
// The publish is traced first; messages that violate the signing policy are
// refused before any validator runs.
func (v *validation) PushLocal(msg *Message) error {
	v.p.tracer.PublishMessage(msg)

	if err := v.p.checkSigningPolicy(msg); err != nil {
		return err
	}
	return v.validate(v.getValidators(msg), msg.ReceivedFrom, msg, true)
}
// Push schedules msg, received from src, for asynchronous validation.
//
// It returns true when there is nothing to validate: no validator applies and
// the message is unsigned. Otherwise it returns false after queuing the message
// for a validateWorker. If the queue is full, the message is dropped and traced
// with RejectValidationQueueFull instead.
func (v *validation) Push(src peer.ID, msg *Message) bool {
	vals := v.getValidators(msg)
	if len(vals) == 0 && msg.Signature == nil {
		return true
	}

	select {
	case v.validateQ <- &validateReq{vals: vals, src: src, msg: msg}:
	default:
		log.Debugf("message validation throttled: queue full; dropping message from %s", src)
		v.tracer.RejectMessage(msg, RejectValidationQueueFull)
	}
	return false
}
// getValidators returns a fresh slice with all default validators followed by
// the topic validator for msg, if one is registered. The result is nil when
// no validator applies.
func (v *validation) getValidators(msg *Message) []*validatorImpl {
	v.mx.Lock()
	defer v.mx.Unlock()

	// Appending to a nil slice copies defaultVals, so callers never alias it.
	applicable := append([]*validatorImpl(nil), v.defaultVals...)
	if topicVal, found := v.topicVals[msg.GetTopic()]; found {
		applicable = append(applicable, topicVal)
	}
	return applicable
}
// validateWorker takes requests off validateQ and validates them in
// non-synchronous mode. It returns once the PubSub context is done.
func (v *validation) validateWorker() {
	done := v.p.ctx.Done()
	for {
		select {
		case <-done:
			return
		case req := <-v.validateQ:
			// Rejections are already reported to the tracer inside validate,
			// and there is no caller to hand the error to.
			_ = v.validate(req.vals, req.src, req.msg, false)
		}
	}
}
// validate runs the validation pipeline for msg received from src.
//
// Steps, in order:
//   1. signature verification (if msg is signed);
//   2. duplicate detection via markSeen;
//   3. inline validators;
//   4. either a hand-off to the asynchronous validators or direct delivery
//      on p.sendMsg.
//
// When synchronous is true (PushLocal), every validator runs inline, so the
// returned error reflects the final verdict. When it is false
// (validateWorker), validators not marked inline run on a separate goroutine.
// validate then returns nil as soon as that goroutine is started (or the
// message is throttled), and the later verdict is reported only through the
// tracer.
//
// Returns:
//   - a ValidationError for a bad signature, an inline rejection, or an inline
//     ignore with no asynchronous validators left;
//   - the context error if PubSub shut down before delivery;
//   - nil otherwise, including for duplicate messages.
func (v *validation ) validate (vals []*validatorImpl , src peer .ID , msg *Message , synchronous bool ) error {
if msg .Signature != nil {
if !v .validateSignature (msg ) {
log .Debugf ("message signature validation failed; dropping message from %s" , src )
v .tracer .RejectMessage (msg , RejectInvalidSignature )
return ValidationError {Reason : RejectInvalidSignature }
}
}
// Only the first copy of a message continues past this point; duplicates are
// traced and dropped without an error.
id := v .p .idGen .ID (msg )
if !v .p .markSeen (id ) {
v .tracer .DuplicateMessage (msg )
return nil
} else {
v .tracer .ValidateMessage (msg )
}
// Partition the validators; a synchronous caller forces all of them inline.
var inline , async []*validatorImpl
for _ , val := range vals {
if val .validateInline || synchronous {
inline = append (inline , val )
} else {
async = append (async , val )
}
}
// A Reject short-circuits the inline loop; an Ignore is remembered, but the
// remaining inline validators still run.
result := ValidationAccept
loop :
for _ , val := range inline {
switch val .validateMsg (v .p .ctx , src , msg ) {
case ValidationAccept :
case ValidationReject :
result = ValidationReject
break loop
case ValidationIgnore :
result = ValidationIgnore
}
}
if result == ValidationReject {
log .Debugf ("message validation failed; dropping message from %s" , src )
v .tracer .RejectMessage (msg , RejectValidationFailed )
return ValidationError {Reason : RejectValidationFailed }
}
// Hand the remaining validators to a goroutine, bounded by the global
// validateThrottle semaphore. If the semaphore is full, the message is
// dropped. The inline result is forwarded so that an inline Ignore is not
// overridden by an asynchronous Accept.
if len (async ) > 0 {
select {
case v .validateThrottle <- struct {}{}:
go func () {
v .doValidateTopic (async , src , msg , result )
<-v .validateThrottle
}()
default :
log .Debugf ("message validation throttled; dropping message from %s" , src )
v .tracer .RejectMessage (msg , RejectValidationThrottled )
}
return nil
}
if result == ValidationIgnore {
v .tracer .RejectMessage (msg , RejectValidationIgnored )
return ValidationError {Reason : RejectValidationIgnored }
}
// Deliver, unless the PubSub context is done first.
select {
case v .p .sendMsg <- msg :
return nil
case <- v .p .ctx .Done ():
return v .p .ctx .Err ()
}
}
// validateSignature reports whether msg carries a valid signature. A
// verification failure is logged at debug level.
func (v *validation) validateSignature(msg *Message) bool {
	if err := verifyMessageSignature(msg.Message); err != nil {
		log.Debugf("signature verification error: %s", err.Error())
		return false
	}
	return true
}
// doValidateTopic runs the asynchronous validators vals for msg from src and
// acts on the combined verdict.
//
// r is the verdict already reached by the inline validators. An Accept from
// the asynchronous validators never overrides a non-Accept r. Accepted
// messages are handed to p.sendMsg. Rejected, ignored and throttled messages
// are dropped and reported to the tracer.
//
// validate runs this on its own goroutine. The hand-off to sendMsg is
// therefore abandoned once the PubSub context is done; otherwise the goroutine
// would block forever after shutdown and never release its validateThrottle
// slot.
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
	result := v.validateTopic(vals, src, msg)
	if result == ValidationAccept && r != ValidationAccept {
		result = r
	}

	switch result {
	case ValidationAccept:
		// Same shutdown-aware send that validate uses for synchronous delivery.
		select {
		case v.p.sendMsg <- msg:
		case <-v.p.ctx.Done():
		}
	case ValidationReject:
		log.Debugf("message validation failed; dropping message from %s", src)
		v.tracer.RejectMessage(msg, RejectValidationFailed)
	case ValidationIgnore:
		log.Debugf("message validation punted; ignoring message from %s", src)
		v.tracer.RejectMessage(msg, RejectValidationIgnored)
	case validationThrottled:
		log.Debugf("message validation throttled; ignoring message from %s", src)
		v.tracer.RejectMessage(msg, RejectValidationThrottled)
	default:
		// validateTopic only yields the four results above; anything else is a
		// broken internal invariant.
		panic(fmt.Errorf("unexpected validation result: %d", result))
	}
}
// validateTopic runs every validator in vals against msg and combines their
// verdicts. The precedence is Reject > throttled > Ignore > Accept.
//
// A single validator is run directly on the calling goroutine. Multiple
// validators each run on their own goroutine, subject to their per-validator
// semaphores. A validator whose semaphore is full counts as throttled. The
// shared context is canceled on return, so validators still running after an
// early Reject see cancellation.
func (v *validation) validateTopic(vals []*validatorImpl, src peer.ID, msg *Message) ValidationResult {
	if len(vals) == 1 {
		return v.validateSingleTopic(vals[0], src, msg)
	}

	ctx, cancel := context.WithCancel(v.p.ctx)
	defer cancel()

	// Buffered to len(vals) so every sender can deposit its result without
	// blocking, even after this function has stopped reading.
	results := make(chan ValidationResult, len(vals))
	for _, val := range vals {
		select {
		case val.validateThrottle <- struct{}{}:
			go func(impl *validatorImpl) {
				results <- impl.validateMsg(ctx, src, msg)
				<-impl.validateThrottle
			}(val)
		default:
			log.Debugf("validation throttled for topic %s", val.topic)
			results <- validationThrottled
		}
	}

	verdict := ValidationAccept
	for pending := len(vals); pending > 0; pending-- {
		switch <-results {
		case ValidationReject:
			return ValidationReject
		case validationThrottled:
			verdict = validationThrottled
		case ValidationIgnore:
			if verdict != validationThrottled {
				verdict = ValidationIgnore
			}
		}
	}
	return verdict
}
// validateSingleTopic runs val against msg on the calling goroutine if a slot
// in its semaphore is free. It returns validationThrottled without running
// val when the semaphore is full.
func (v *validation) validateSingleTopic(val *validatorImpl, src peer.ID, msg *Message) ValidationResult {
	select {
	case val.validateThrottle <- struct{}{}:
	default:
		log.Debugf("validation throttled for topic %s", val.topic)
		return validationThrottled
	}

	verdict := val.validateMsg(v.p.ctx, src, msg)
	<-val.validateThrottle
	return verdict
}
// validateMsg calls the user validator once. A configured per-call timeout is
// applied to ctx. Any result other than Accept, Reject or Ignore is logged
// and treated as ValidationIgnore. The elapsed time is logged at debug level
// when the call returns.
func (val *validatorImpl) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
	began := time.Now()
	defer func() {
		log.Debugf("validation done; took %s", time.Since(began))
	}()

	if val.validateTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, val.validateTimeout)
		defer cancel()
	}

	switch verdict := val.validate(ctx, src, msg); verdict {
	case ValidationAccept, ValidationReject, ValidationIgnore:
		return verdict
	default:
		log.Warnf("Unexpected result from validator: %d; ignoring message", verdict)
		return ValidationIgnore
	}
}
// WithDefaultValidator registers a validator that applies to every message,
// regardless of topic. val must be a Validator, a ValidatorEx, or a func with
// one of their signatures. opts customize the validator, and the first option
// error aborts the registration.
func WithDefaultValidator(val interface{}, opts ...ValidatorOpt) Option {
	return func(ps *PubSub) error {
		req := &addValReq{validate: val}
		for _, apply := range opts {
			if err := apply(req); err != nil {
				return err
			}
		}

		impl, err := ps.val.makeValidator(req)
		if err != nil {
			return err
		}
		ps.val.defaultVals = append(ps.val.defaultVals, impl)
		return nil
	}
}
// WithValidateQueueSize sets the capacity of the queue between Push and the
// validation workers (default 32). Messages that arrive while the queue is
// full are dropped. n must be positive.
func WithValidateQueueSize(n int) Option {
	return func(ps *PubSub) error {
		if n <= 0 {
			return fmt.Errorf("validate queue size must be > 0")
		}
		ps.val.validateQ = make(chan *validateReq, n)
		return nil
	}
}
// WithValidateThrottle sets the upper bound on concurrently running
// asynchronous validation goroutines across all topics (default 8192).
//
// n must be positive, as with WithValidateQueueSize and WithValidateWorkers:
//   - a negative capacity would make the channel allocation panic;
//   - a capacity of 0 creates a semaphore that the non-blocking acquire in the
//     validation path can never take, so every asynchronous validation would
//     be throttled.
func WithValidateThrottle(n int) Option {
	return func(ps *PubSub) error {
		if n <= 0 {
			return fmt.Errorf("validate throttle must be > 0")
		}
		ps.val.validateThrottle = make(chan struct{}, n)
		return nil
	}
}
// WithValidateWorkers sets how many validation worker goroutines Start
// launches (default runtime.NumCPU()). n must be positive.
func WithValidateWorkers(n int) Option {
	return func(ps *PubSub) error {
		if n <= 0 {
			return fmt.Errorf("number of validation workers must be > 0")
		}
		ps.val.validateWorkers = n
		return nil
	}
}
// WithValidatorTimeout bounds each call of the validator by timeout. Values
// <= 0 mean no timeout.
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
	return func(req *addValReq) error {
		req.timeout = timeout
		return nil
	}
}
// WithValidatorConcurrency limits how many calls of the validator may run at
// once. Values <= 0 keep the default limit (1024).
func WithValidatorConcurrency(n int) ValidatorOpt {
	return func(req *addValReq) error {
		req.throttle = n
		return nil
	}
}
// WithValidatorInline controls whether the validator runs synchronously within
// the validation path rather than on the asynchronous validation goroutine.
func WithValidatorInline(inline bool) ValidatorOpt {
	return func(req *addValReq) error {
		req.inline = inline
		return nil
	}
}
These pages were generated with Golds v0.8.2 (GOOS=linux GOARCH=amd64).
Golds is a Go 101 project developed by Tapir Liu.
PRs and bug reports are welcome and can be submitted to the issue list.
Please follow @zigo_101 (reachable from the left QR code) to get the latest news about Golds.