diff --git a/irc.go b/irc.go index af50c74..bcfe9ac 100644 --- a/irc.go +++ b/irc.go @@ -461,26 +461,84 @@ func (irc *Connection) Connect(server string) error { irc.pwrite <- fmt.Sprintf("PASS %s\r\n", irc.Password) } - resChan := make(chan *SASLResult) + err = irc.negotiateCaps() + if err != nil { + return err + } + + irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) + irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, irc.user) + return nil +} + +// Negotiate IRCv3 capabilities +func (irc *Connection) negotiateCaps() error { + saslResChan := make(chan *SASLResult) + if irc.UseSASL { + irc.RequestCaps = append(irc.RequestCaps, "sasl") + irc.setupSASLCallbacks(saslResChan) + } + + if len(irc.RequestCaps) == 0 { + return nil + } + + cap_chan := make(chan bool, len(irc.RequestCaps)) + irc.AddCallback("CAP", func(e *Event) { + if len(e.Arguments) != 3 { + return + } + command := e.Arguments[1] + + if command == "LS" { + missing_caps := len(irc.RequestCaps) + for _, cap_name := range strings.Split(e.Arguments[2], " ") { + for _, req_cap := range irc.RequestCaps { + if cap_name == req_cap { + irc.pwrite <- fmt.Sprintf("CAP REQ :%s\r\n", cap_name) + missing_caps-- + } + } + } + + for i := 0; i < missing_caps; i++ { + cap_chan <- true + } + } else if command == "ACK" || command == "NAK" { + for _, cap_name := range strings.Split(strings.TrimSpace(e.Arguments[2]), " ") { + if cap_name == "" { + continue + } + + if command == "ACK" { + irc.AcknowledgedCaps = append(irc.AcknowledgedCaps, cap_name) + } + cap_chan <- true + } + } + }) + + irc.pwrite <- "CAP LS\r\n" + if irc.UseSASL { - irc.setupSASLCallbacks(resChan) - irc.pwrite <- fmt.Sprintf("CAP LS\r\n") - // request SASL - irc.pwrite <- fmt.Sprintf("CAP REQ :sasl\r\n") - // if sasl request doesn't complete in 15 seconds, close chan and timeout select { - case res := <-resChan: + case res := <-saslResChan: if res.Failed { - close(resChan) + close(saslResChan) return res.Err } case <-time.After(time.Second * 15): - close(resChan) + close(saslResChan) return errors.New("SASL setup timed out. This shouldn't happen.") } } - irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) - irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, irc.user) + + // Wait for all capabilities to be ACKed or NAKed before ending negotiation + for i := 0; i < len(irc.RequestCaps); i++ { + <-cap_chan + } + irc.pwrite <- fmt.Sprintf("CAP END\r\n") + return nil } diff --git a/irc_sasl.go b/irc_sasl.go index e5ff9e3..fd9df7b 100644 --- a/irc_sasl.go +++ b/irc_sasl.go @@ -43,7 +43,6 @@ func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { result <- &SASLResult{true, errors.New(e.Arguments[1])} }) irc.AddCallback("903", func(e *Event) { - irc.SendRaw("CAP END") result <- &SASLResult{false, nil} }) irc.AddCallback("904", func(e *Event) { diff --git a/irc_sasl_test.go b/irc_sasl_test.go index 6d22736..3c2fe47 100644 --- a/irc_sasl_test.go +++ b/irc_sasl_test.go @@ -34,7 +34,7 @@ func TestConnectionSASL(t *testing.T) { err := irccon.Connect(SASLServer) if err != nil { - t.Fatal("SASL failed") + t.Fatalf("SASL failed: %s", err) } irccon.Loop() } diff --git a/irc_struct.go b/irc_struct.go index e509dbb..c064cb8 100644 --- a/irc_struct.go +++ b/irc_struct.go @@ -15,20 +15,22 @@ import ( type Connection struct { sync.Mutex sync.WaitGroup - Debug bool - Error chan error - Password string - UseTLS bool - UseSASL bool - SASLLogin string - SASLPassword string - SASLMech string - TLSConfig *tls.Config - Version string - Timeout time.Duration - PingFreq time.Duration - KeepAlive time.Duration - Server string + Debug bool + Error chan error + Password string + UseTLS bool + UseSASL bool + RequestCaps []string + AcknowledgedCaps []string + SASLLogin string + SASLPassword string + SASLMech string + TLSConfig *tls.Config + Version string + Timeout time.Duration + PingFreq time.Duration + KeepAlive time.Duration + Server string socket net.Conn pwrite chan string