fix SASL stalling due to duplicated callbacks

Each reconnection would add redundant callbacks for CAP and the
SASL numerics. This would result in `AUTHENTICATE PLAIN` being sent
multiple times (once for each callback), which would confuse the server.
This commit is contained in:
Shivaram Lingamneni 2021-02-11 10:40:29 -05:00
parent 8e7ce4b5a1
commit ca8c401467
4 changed files with 43 additions and 15 deletions

17
irc.go
View File

@ -510,10 +510,20 @@ func (irc *Connection) Connect(server string) error {
// Negotiate IRCv3 capabilities // Negotiate IRCv3 capabilities
func (irc *Connection) negotiateCaps() error { func (irc *Connection) negotiateCaps() error {
irc.RequestCaps = nil
irc.AcknowledgedCaps = nil
var negotiationCallbacks []CallbackID
defer func() {
for _, callback := range negotiationCallbacks {
irc.RemoveCallback(callback.EventCode, callback.ID)
}
}()
saslResChan := make(chan *SASLResult) saslResChan := make(chan *SASLResult)
if irc.UseSASL { if irc.UseSASL {
irc.RequestCaps = append(irc.RequestCaps, "sasl") irc.RequestCaps = append(irc.RequestCaps, "sasl")
irc.setupSASLCallbacks(saslResChan) negotiationCallbacks = irc.setupSASLCallbacks(saslResChan)
} }
if len(irc.RequestCaps) == 0 { if len(irc.RequestCaps) == 0 {
@ -521,7 +531,7 @@ func (irc *Connection) negotiateCaps() error {
} }
cap_chan := make(chan bool, len(irc.RequestCaps)) cap_chan := make(chan bool, len(irc.RequestCaps))
irc.AddCallback("CAP", func(e *Event) { id := irc.AddCallback("CAP", func(e *Event) {
if len(e.Arguments) != 3 { if len(e.Arguments) != 3 {
return return
} }
@ -554,6 +564,7 @@ func (irc *Connection) negotiateCaps() error {
} }
} }
}) })
negotiationCallbacks = append(negotiationCallbacks, CallbackID{"CAP", id})
irc.pwrite <- "CAP LS\r\n" irc.pwrite <- "CAP LS\r\n"
@ -561,11 +572,9 @@ func (irc *Connection) negotiateCaps() error {
select { select {
case res := <-saslResChan: case res := <-saslResChan:
if res.Failed { if res.Failed {
close(saslResChan)
return res.Err return res.Err
} }
case <-time.After(CAP_TIMEOUT): case <-time.After(CAP_TIMEOUT):
close(saslResChan)
// Raise an error if we can't authenticate with SASL. // Raise an error if we can't authenticate with SASL.
return errors.New("SASL setup timed out. Does the server support SASL?") return errors.New("SASL setup timed out. Does the server support SASL?")
} }

View File

@ -9,6 +9,12 @@ import (
"time" "time"
) )
// Tuple type for uniquely identifying callbacks
type CallbackID struct {
EventCode string
ID int
}
// Register a callback to a connection and event code. A callback is a function // Register a callback to a connection and event code. A callback is a function
// which takes only an Event pointer as parameter. Valid event codes are all // which takes only an Event pointer as parameter. Valid event codes are all
// IRC/CTCP commands and error/response codes. To register a callback for all // IRC/CTCP commands and error/response codes. To register a callback for all
@ -16,16 +22,14 @@ import (
// registered callback for later management. // registered callback for later management.
func (irc *Connection) AddCallback(eventcode string, callback func(*Event)) int { func (irc *Connection) AddCallback(eventcode string, callback func(*Event)) int {
eventcode = strings.ToUpper(eventcode) eventcode = strings.ToUpper(eventcode)
id := 0
irc.eventsMutex.Lock() irc.eventsMutex.Lock()
_, ok := irc.events[eventcode] _, ok := irc.events[eventcode]
if !ok { if !ok {
irc.events[eventcode] = make(map[int]func(*Event)) irc.events[eventcode] = make(map[int]func(*Event))
id = 0
} else {
id = len(irc.events[eventcode])
} }
id := irc.idCounter
irc.idCounter++
irc.events[eventcode][id] = callback irc.events[eventcode][id] = callback
irc.eventsMutex.Unlock() irc.eventsMutex.Unlock()
return id return id

View File

@ -22,8 +22,8 @@ func listContains(list string, value string) bool {
return false return false
} }
func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) (callbacks []CallbackID) {
irc.AddCallback("CAP", func(e *Event) { id := irc.AddCallback("CAP", func(e *Event) {
if len(e.Arguments) == 3 { if len(e.Arguments) == 3 {
if e.Arguments[1] == "LS" { if e.Arguments[1] == "LS" {
if !listContains(e.Arguments[2], "sasl") { if !listContains(e.Arguments[2], "sasl") {
@ -38,26 +38,39 @@ func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) {
} }
} }
}) })
irc.AddCallback("AUTHENTICATE", func(e *Event) { callbacks = append(callbacks, CallbackID{"CAP", id})
id = irc.AddCallback("AUTHENTICATE", func(e *Event) {
str := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s\x00%s\x00%s", irc.SASLLogin, irc.SASLLogin, irc.SASLPassword))) str := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s\x00%s\x00%s", irc.SASLLogin, irc.SASLLogin, irc.SASLPassword)))
irc.SendRaw("AUTHENTICATE " + str) irc.SendRaw("AUTHENTICATE " + str)
}) })
irc.AddCallback("901", func(e *Event) { callbacks = append(callbacks, CallbackID{"AUTHENTICATE", id})
id = irc.AddCallback("901", func(e *Event) {
irc.SendRaw("CAP END") irc.SendRaw("CAP END")
irc.SendRaw("QUIT") irc.SendRaw("QUIT")
result <- &SASLResult{true, errors.New(e.Arguments[1])} result <- &SASLResult{true, errors.New(e.Arguments[1])}
}) })
irc.AddCallback("902", func(e *Event) { callbacks = append(callbacks, CallbackID{"901", id})
id = irc.AddCallback("902", func(e *Event) {
irc.SendRaw("CAP END") irc.SendRaw("CAP END")
irc.SendRaw("QUIT") irc.SendRaw("QUIT")
result <- &SASLResult{true, errors.New(e.Arguments[1])} result <- &SASLResult{true, errors.New(e.Arguments[1])}
}) })
irc.AddCallback("903", func(e *Event) { callbacks = append(callbacks, CallbackID{"902", id})
id = irc.AddCallback("903", func(e *Event) {
result <- &SASLResult{false, nil} result <- &SASLResult{false, nil}
}) })
irc.AddCallback("904", func(e *Event) { callbacks = append(callbacks, CallbackID{"903", id})
id = irc.AddCallback("904", func(e *Event) {
irc.SendRaw("CAP END") irc.SendRaw("CAP END")
irc.SendRaw("QUIT") irc.SendRaw("QUIT")
result <- &SASLResult{true, errors.New(e.Arguments[1])} result <- &SASLResult{true, errors.New(e.Arguments[1])}
}) })
callbacks = append(callbacks, CallbackID{"904", id})
return
} }

View File

@ -62,6 +62,8 @@ type Connection struct {
stopped bool stopped bool
quit bool //User called Quit, do not reconnect. quit bool //User called Quit, do not reconnect.
idCounter int // assign unique IDs to callbacks
} }
// A struct to represent an event. // A struct to represent an event.