diff --git a/irc.go b/irc.go index da90edb..03eb318 100644 --- a/irc.go +++ b/irc.go @@ -510,10 +510,20 @@ func (irc *Connection) Connect(server string) error { // Negotiate IRCv3 capabilities func (irc *Connection) negotiateCaps() error { + irc.RequestCaps = nil + irc.AcknowledgedCaps = nil + + var negotiationCallbacks []CallbackID + defer func() { + for _, callback := range negotiationCallbacks { + irc.RemoveCallback(callback.EventCode, callback.ID) + } + }() + saslResChan := make(chan *SASLResult) if irc.UseSASL { irc.RequestCaps = append(irc.RequestCaps, "sasl") - irc.setupSASLCallbacks(saslResChan) + negotiationCallbacks = irc.setupSASLCallbacks(saslResChan) } if len(irc.RequestCaps) == 0 { @@ -521,7 +531,7 @@ func (irc *Connection) negotiateCaps() error { } cap_chan := make(chan bool, len(irc.RequestCaps)) - irc.AddCallback("CAP", func(e *Event) { + id := irc.AddCallback("CAP", func(e *Event) { if len(e.Arguments) != 3 { return } @@ -554,6 +564,7 @@ func (irc *Connection) negotiateCaps() error { } } }) + negotiationCallbacks = append(negotiationCallbacks, CallbackID{"CAP", id}) irc.pwrite <- "CAP LS\r\n" @@ -561,11 +572,9 @@ func (irc *Connection) negotiateCaps() error { select { case res := <-saslResChan: if res.Failed { - close(saslResChan) return res.Err } case <-time.After(CAP_TIMEOUT): - close(saslResChan) // Raise an error if we can't authenticate with SASL. return errors.New("SASL setup timed out. Does the server support SASL?") } diff --git a/irc_callback.go b/irc_callback.go index 627e639..c19c211 100644 --- a/irc_callback.go +++ b/irc_callback.go @@ -9,6 +9,12 @@ import ( "time" ) +// Tuple type for uniquely identifying callbacks +type CallbackID struct { + EventCode string + ID int +} + // Register a callback to a connection and event code. A callback is a function // which takes only an Event pointer as parameter. Valid event codes are all // IRC/CTCP commands and error/response codes. To register a callback for all @@ -16,16 +22,14 @@ import ( // registered callback for later management. func (irc *Connection) AddCallback(eventcode string, callback func(*Event)) int { eventcode = strings.ToUpper(eventcode) - id := 0 irc.eventsMutex.Lock() _, ok := irc.events[eventcode] if !ok { irc.events[eventcode] = make(map[int]func(*Event)) - id = 0 - } else { - id = len(irc.events[eventcode]) } + id := irc.idCounter + irc.idCounter++ irc.events[eventcode][id] = callback irc.eventsMutex.Unlock() return id diff --git a/irc_sasl.go b/irc_sasl.go index fb8c292..be02380 100644 --- a/irc_sasl.go +++ b/irc_sasl.go @@ -22,8 +22,8 @@ func listContains(list string, value string) bool { return false } -func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { - irc.AddCallback("CAP", func(e *Event) { +func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) (callbacks []CallbackID) { + id := irc.AddCallback("CAP", func(e *Event) { if len(e.Arguments) == 3 { if e.Arguments[1] == "LS" { if !listContains(e.Arguments[2], "sasl") { @@ -38,26 +38,39 @@ func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { } } }) - irc.AddCallback("AUTHENTICATE", func(e *Event) { + callbacks = append(callbacks, CallbackID{"CAP", id}) + + id = irc.AddCallback("AUTHENTICATE", func(e *Event) { str := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s\x00%s\x00%s", irc.SASLLogin, irc.SASLLogin, irc.SASLPassword))) irc.SendRaw("AUTHENTICATE " + str) }) - irc.AddCallback("901", func(e *Event) { + callbacks = append(callbacks, CallbackID{"AUTHENTICATE", id}) + + id = irc.AddCallback("901", func(e *Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") result <- &SASLResult{true, errors.New(e.Arguments[1])} }) - irc.AddCallback("902", func(e *Event) { + callbacks = append(callbacks, CallbackID{"901", id}) + + id = irc.AddCallback("902", func(e *Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") result <- &SASLResult{true, errors.New(e.Arguments[1])} }) - irc.AddCallback("903", func(e *Event) { + callbacks = append(callbacks, CallbackID{"902", id}) + + id = irc.AddCallback("903", func(e *Event) { result <- &SASLResult{false, nil} }) - irc.AddCallback("904", func(e *Event) { + callbacks = append(callbacks, CallbackID{"903", id}) + + id = irc.AddCallback("904", func(e *Event) { irc.SendRaw("CAP END") irc.SendRaw("QUIT") result <- &SASLResult{true, errors.New(e.Arguments[1])} }) + callbacks = append(callbacks, CallbackID{"904", id}) + + return } diff --git a/irc_struct.go b/irc_struct.go index 948f49e..9175537 100644 --- a/irc_struct.go +++ b/irc_struct.go @@ -62,6 +62,8 @@ type Connection struct { stopped bool quit bool //User called Quit, do not reconnect. + + idCounter int // assign unique IDs to callbacks } // A struct to represent an event.