diff --git a/irc.go b/irc.go index 1b52f34..4634390 100644 --- a/irc.go +++ b/irc.go @@ -87,6 +87,17 @@ func (irc *Connection) readLoop() { } } +// Unescape tag values as defined in the IRCv3.2 message tags spec +// http://ircv3.net/specs/core/message-tags-3.2.html +func unescapeTagValue(value string) string { + value = strings.Replace(value, "\\:", ";", -1) + value = strings.Replace(value, "\\s", " ", -1) + value = strings.Replace(value, "\\\\", "\\", -1) + value = strings.Replace(value, "\\r", "\r", -1) + value = strings.Replace(value, "\\n", "\n", -1) + return value +} + //Parse raw irc messages func parseToEvent(msg string) (*Event, error) { msg = strings.TrimSuffix(msg, "\n") //Remove \r\n @@ -95,6 +106,26 @@ func parseToEvent(msg string) (*Event, error) { if len(msg) < 5 { return nil, errors.New("Malformed msg from server") } + + if msg[0] == '@' { + // IRCv3 Message Tags + if i := strings.Index(msg, " "); i > -1 { + event.Tags = make(map[string]string) + tags := strings.Split(msg[1:i], ";") + for _, data := range tags { + parts := strings.SplitN(data, "=", 2) + if len(parts) == 1 { + event.Tags[parts[0]] = "" + } else { + event.Tags[parts[0]] = unescapeTagValue(parts[1]) + } + } + msg = msg[i+1 : len(msg)] + } else { + return nil, errors.New("Malformed msg from server") + } + } + if msg[0] == ':' { if i := strings.Index(msg, " "); i > -1 { event.Source = msg[1:i] @@ -432,26 +463,84 @@ func (irc *Connection) Connect(server string) error { irc.pwrite <- fmt.Sprintf("PASS %s\r\n", irc.Password) } - resChan := make(chan *SASLResult) + err = irc.negotiateCaps() + if err != nil { + return err + } + + irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) + irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, irc.user) + return nil +} + +// Negotiate IRCv3 capabilities +func (irc *Connection) negotiateCaps() error { + saslResChan := make(chan *SASLResult) + if irc.UseSASL { + irc.RequestCaps = append(irc.RequestCaps, "sasl") + irc.setupSASLCallbacks(saslResChan) + } + + if len(irc.RequestCaps) == 0 { + return nil + } + + cap_chan := make(chan bool, len(irc.RequestCaps)) + irc.AddCallback("CAP", func(e *Event) { + if len(e.Arguments) != 3 { + return + } + command := e.Arguments[1] + + if command == "LS" { + missing_caps := len(irc.RequestCaps) + for _, cap_name := range strings.Split(e.Arguments[2], " ") { + for _, req_cap := range irc.RequestCaps { + if cap_name == req_cap { + irc.pwrite <- fmt.Sprintf("CAP REQ :%s\r\n", cap_name) + missing_caps-- + } + } + } + + for i := 0; i < missing_caps; i++ { + cap_chan <- true + } + } else if command == "ACK" || command == "NAK" { + for _, cap_name := range strings.Split(strings.TrimSpace(e.Arguments[2]), " ") { + if cap_name == "" { + continue + } + + if command == "ACK" { + irc.AcknowledgedCaps = append(irc.AcknowledgedCaps, cap_name) + } + cap_chan <- true + } + } + }) + + irc.pwrite <- "CAP LS\r\n" + if irc.UseSASL { - irc.setupSASLCallbacks(resChan) - irc.pwrite <- fmt.Sprintf("CAP LS\r\n") - // request SASL - irc.pwrite <- fmt.Sprintf("CAP REQ :sasl\r\n") - // if sasl request doesn't complete in 15 seconds, close chan and timeout select { - case res := <-resChan: + case res := <-saslResChan: if res.Failed { - close(resChan) + close(saslResChan) return res.Err } case <-time.After(time.Second * 15): - close(resChan) + close(saslResChan) return errors.New("SASL setup timed out. This shouldn't happen.") } } - irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) - irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, irc.user) + + // Wait for all capabilities to be ACKed or NAKed before ending negotiation + for i := 0; i < len(irc.RequestCaps); i++ { + <-cap_chan + } + irc.pwrite <- fmt.Sprintf("CAP END\r\n") + return nil } diff --git a/irc_parse_test.go b/irc_parse_test.go new file mode 100644 index 0000000..447a590 --- /dev/null +++ b/irc_parse_test.go @@ -0,0 +1,47 @@ +package irc + +import ( + "fmt" + "testing" +) + +func checkResult(t *testing.T, event *Event) { + if event.Nick != "nick" { + t.Fatal("Parse failed: nick") + } + if event.User != "~user" { + t.Fatal("Parse failed: user") + } + if event.Code != "PRIVMSG" { + t.Fatal("Parse failed: code") + } + if event.Arguments[0] != "#channel" { + t.Fatal("Parse failed: channel") + } + if event.Arguments[1] != "message text" { + t.Fatal("Parse failed: message") + } +} + +func TestParse(t *testing.T) { + event, err := parseToEvent(":nick!~user@host PRIVMSG #channel :message text") + if err != nil { + t.Fatal("Parse PRIVMSG failed") + } + checkResult(t, event) +} + +func TestParseTags(t *testing.T) { + event, err := parseToEvent("@tag;+tag2=raw+:=,escaped\\:\\s\\\\ :nick!~user@host PRIVMSG #channel :message text") + if err != nil { + t.Fatal("Parse PRIVMSG with tags failed") + } + checkResult(t, event) + fmt.Printf("%s", event.Tags) + if _, ok := event.Tags["tag"]; !ok { + t.Fatal("Parsing value-less tag failed") + } + if event.Tags["+tag2"] != "raw+:=,escaped; \\" { + t.Fatal("Parsing tag failed") + } +} diff --git a/irc_sasl.go b/irc_sasl.go index e5ff9e3..fd9df7b 100644 --- a/irc_sasl.go +++ b/irc_sasl.go @@ -43,7 +43,6 @@ func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { result <- &SASLResult{true, errors.New(e.Arguments[1])} }) irc.AddCallback("903", func(e *Event) { - irc.SendRaw("CAP END") result <- &SASLResult{false, nil} }) irc.AddCallback("904", func(e *Event) { diff --git a/irc_sasl_test.go b/irc_sasl_test.go index 6d22736..3c2fe47 100644 --- a/irc_sasl_test.go +++ b/irc_sasl_test.go @@ -34,7 +34,7 @@ func TestConnectionSASL(t *testing.T) { err := irccon.Connect(SASLServer) if err != nil { - t.Fatal("SASL failed") + t.Fatalf("SASL failed: %s", err) } irccon.Loop() } diff --git a/irc_struct.go b/irc_struct.go index 22d08d8..1161900 100644 --- a/irc_struct.go +++ b/irc_struct.go @@ -15,20 +15,22 @@ import ( type Connection struct { sync.Mutex sync.WaitGroup - Debug bool - Error chan error - Password string - UseTLS bool - UseSASL bool - SASLLogin string - SASLPassword string - SASLMech string - TLSConfig *tls.Config - Version string - Timeout time.Duration - PingFreq time.Duration - KeepAlive time.Duration - Server string + Debug bool + Error chan error + Password string + UseTLS bool + UseSASL bool + RequestCaps []string + AcknowledgedCaps []string + SASLLogin string + SASLPassword string + SASLMech string + TLSConfig *tls.Config + Version string + Timeout time.Duration + PingFreq time.Duration + KeepAlive time.Duration + Server string socket net.Conn pwrite chan string @@ -61,6 +63,7 @@ type Event struct { Source string // User string // Arguments []string + Tags map[string]string Connection *Connection }