diff --git a/irc.go b/irc.go index e430413..8d11e55 100644 --- a/irc.go +++ b/irc.go @@ -74,9 +74,9 @@ func (irc *Connection) readLoop() { irc.Log.Printf("<-- %s\n", strings.TrimSpace(msg)) } - irc.Lock() + irc.lastMessageMutex.Lock() irc.lastMessage = time.Now() - irc.Unlock() + irc.lastMessageMutex.Unlock() event, err := parseToEvent(msg) event.Connection = irc if err == nil { @@ -87,6 +87,17 @@ func (irc *Connection) readLoop() { } } +// Unescape tag values as defined in the IRCv3.2 message tags spec +// http://ircv3.net/specs/core/message-tags-3.2.html +func unescapeTagValue(value string) string { + value = strings.Replace(value, "\\:", ";", -1) + value = strings.Replace(value, "\\s", " ", -1) + value = strings.Replace(value, "\\\\", "\\", -1) + value = strings.Replace(value, "\\r", "\r", -1) + value = strings.Replace(value, "\\n", "\n", -1) + return value +} + //Parse raw irc messages func parseToEvent(msg string) (*Event, error) { msg = strings.TrimSuffix(msg, "\n") //Remove \r\n @@ -95,6 +106,26 @@ func parseToEvent(msg string) (*Event, error) { if len(msg) < 5 { return nil, errors.New("Malformed msg from server") } + + if msg[0] == '@' { + // IRCv3 Message Tags + if i := strings.Index(msg, " "); i > -1 { + event.Tags = make(map[string]string) + tags := strings.Split(msg[1:i], ";") + for _, data := range tags { + parts := strings.SplitN(data, "=", 2) + if len(parts) == 1 { + event.Tags[parts[0]] = "" + } else { + event.Tags[parts[0]] = unescapeTagValue(parts[1]) + } + } + msg = msg[i+1 : len(msg)] + } else { + return nil, errors.New("Malformed msg from server") + } + } + if msg[0] == ':' { if i := strings.Index(msg, " "); i > -1 { event.Source = msg[1:i] @@ -166,9 +197,11 @@ func (irc *Connection) pingLoop() { select { case <-ticker.C: //Ping if we haven't received anything from the server within the keep alive period + irc.lastMessageMutex.Lock() if time.Since(irc.lastMessage) >= irc.KeepAlive { irc.SendRawf("PING %d", time.Now().UnixNano()) } + irc.lastMessageMutex.Unlock() case <-ticker2.C: //Ping at the ping frequency irc.SendRawf("PING %d", time.Now().UnixNano()) @@ -355,6 +388,20 @@ func (irc *Connection) Connected() bool { // A disconnect sends all buffered messages (if possible), // stops all goroutines and then closes the socket. func (irc *Connection) Disconnect() { + irc.Lock() + defer irc.Unlock() + + if irc.end != nil { + close(irc.end) + } + + irc.end = nil + + if irc.pwrite != nil { + close(irc.pwrite) + } + + irc.Wait() if irc.socket != nil { irc.socket.Close() } @@ -435,26 +482,96 @@ func (irc *Connection) Connect(server string) error { irc.pwrite <- fmt.Sprintf("PASS %s\r\n", irc.Password) } - resChan := make(chan *SASLResult) + err = irc.negotiateCaps() + if err != nil { + return err + } + + realname := irc.user + if irc.RealName != "" { + realname = irc.RealName + } + + irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) + irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, realname) + return nil +} + +// Negotiate IRCv3 capabilities +func (irc *Connection) negotiateCaps() error { + saslResChan := make(chan *SASLResult) + if irc.UseSASL { + irc.RequestCaps = append(irc.RequestCaps, "sasl") + irc.setupSASLCallbacks(saslResChan) + } + + if len(irc.RequestCaps) == 0 { + return nil + } + + cap_chan := make(chan bool, len(irc.RequestCaps)) + irc.AddCallback("CAP", func(e *Event) { + if len(e.Arguments) != 3 { + return + } + command := e.Arguments[1] + + if command == "LS" { + missing_caps := len(irc.RequestCaps) + for _, cap_name := range strings.Split(e.Arguments[2], " ") { + for _, req_cap := range irc.RequestCaps { + if cap_name == req_cap { + irc.pwrite <- fmt.Sprintf("CAP REQ :%s\r\n", cap_name) + missing_caps-- + } + } + } + + for i := 0; i < missing_caps; i++ { + cap_chan <- true + } + } else if command == "ACK" || command == "NAK" { + for _, cap_name := range strings.Split(strings.TrimSpace(e.Arguments[2]), " ") { + if cap_name == "" { + continue + } + + if command == "ACK" { + irc.AcknowledgedCaps = append(irc.AcknowledgedCaps, cap_name) + } + cap_chan <- true + } + } + }) + + irc.pwrite <- "CAP LS\r\n" + if irc.UseSASL { - irc.setupSASLCallbacks(resChan) - irc.pwrite <- fmt.Sprintf("CAP LS\r\n") - // request SASL - irc.pwrite <- fmt.Sprintf("CAP REQ :sasl\r\n") - // if sasl request doesn't complete in 15 seconds, close chan and timeout select { - case res := <-resChan: + case res := <-saslResChan: if res.Failed { - close(resChan) + close(saslResChan) return res.Err } case <-time.After(time.Second * 15): - close(resChan) + close(saslResChan) return errors.New("SASL setup timed out. This shouldn't happen.") } } + + // Wait for all capabilities to be ACKed or NAKed before ending negotiation + for i := 0; i < len(irc.RequestCaps); i++ { + <-cap_chan + } + irc.pwrite <- fmt.Sprintf("CAP END\r\n") + + realname := irc.user + if irc.RealName != "" { + realname = irc.RealName + } + irc.pwrite <- fmt.Sprintf("NICK %s\r\n", irc.nick) - irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, irc.user) + irc.pwrite <- fmt.Sprintf("USER %s 0.0.0.0 0.0.0.0 :%s\r\n", irc.user, realname) return nil } diff --git a/irc_callback.go b/irc_callback.go index d389f73..15e7191 100644 --- a/irc_callback.go +++ b/irc_callback.go @@ -13,13 +13,17 @@ import ( func (irc *Connection) AddCallback(eventcode string, callback func(*Event)) int { eventcode = strings.ToUpper(eventcode) id := 0 - if _, ok := irc.events[eventcode]; !ok { + + irc.eventsMutex.Lock() + _, ok := irc.events[eventcode] + if !ok { irc.events[eventcode] = make(map[int]func(*Event)) id = 0 } else { id = len(irc.events[eventcode]) } irc.events[eventcode][id] = callback + irc.eventsMutex.Unlock() return id } @@ -28,15 +32,20 @@ func (irc *Connection) AddCallback(eventcode string, callback func(*Event)) int func (irc *Connection) RemoveCallback(eventcode string, i int) bool { eventcode = strings.ToUpper(eventcode) - if event, ok := irc.events[eventcode]; ok { + irc.eventsMutex.Lock() + event, ok := irc.events[eventcode] + if ok { if _, ok := event[i]; ok { delete(irc.events[eventcode], i) + irc.eventsMutex.Unlock() return true } irc.Log.Printf("Event found, but no callback found at id %d\n", i) + irc.eventsMutex.Unlock() return false } + irc.eventsMutex.Unlock() irc.Log.Println("Event not found") return false } @@ -46,10 +55,14 @@ func (irc *Connection) RemoveCallback(eventcode string, i int) bool { func (irc *Connection) ClearCallback(eventcode string) bool { eventcode = strings.ToUpper(eventcode) - if _, ok := irc.events[eventcode]; ok { + irc.eventsMutex.Lock() + _, ok := irc.events[eventcode] + if ok { irc.events[eventcode] = make(map[int]func(*Event)) + irc.eventsMutex.Unlock() return true } + irc.eventsMutex.Unlock() irc.Log.Println("Event not found") return false @@ -59,7 +72,10 @@ func (irc *Connection) ClearCallback(eventcode string) bool { func (irc *Connection) ReplaceCallback(eventcode string, i int, callback func(*Event)) { eventcode = strings.ToUpper(eventcode) - if event, ok := irc.events[eventcode]; ok { + irc.eventsMutex.Lock() + event, ok := irc.events[eventcode] + irc.eventsMutex.Unlock() + if ok { if _, ok := event[i]; ok { event[i] = callback return @@ -109,7 +125,10 @@ func (irc *Connection) RunCallbacks(event *Event) { event.Arguments[len(event.Arguments)-1] = msg } - if callbacks, ok := irc.events[event.Code]; ok { + irc.eventsMutex.Lock() + callbacks, ok := irc.events[event.Code] + irc.eventsMutex.Unlock() + if ok { if irc.VerboseCallbackHandler { irc.Log.Printf("%v (%v) >> %#v\n", event.Code, len(callbacks), event) } @@ -121,12 +140,15 @@ func (irc *Connection) RunCallbacks(event *Event) { irc.Log.Printf("%v (0) >> %#v\n", event.Code, event) } - if callbacks, ok := irc.events["*"]; ok { + irc.eventsMutex.Lock() + allcallbacks, ok := irc.events["*"] + irc.eventsMutex.Unlock() + if ok { if irc.VerboseCallbackHandler { irc.Log.Printf("%v (0) >> %#v\n", event.Code, event) } - for _, callback := range callbacks { + for _, callback := range allcallbacks { callback(event) } } @@ -136,9 +158,6 @@ func (irc *Connection) RunCallbacks(event *Event) { func (irc *Connection) setupCallbacks() { irc.events = make(map[string]map[int]func(*Event)) - //Handle error events. - irc.AddCallback("ERROR", func(e *Event) { irc.Disconnect() }) - //Handle ping events irc.AddCallback("PING", func(e *Event) { irc.SendRaw("PONG :" + e.Message()) }) diff --git a/irc_parse_test.go b/irc_parse_test.go new file mode 100644 index 0000000..447a590 --- /dev/null +++ b/irc_parse_test.go @@ -0,0 +1,47 @@ +package irc + +import ( + "fmt" + "testing" +) + +func checkResult(t *testing.T, event *Event) { + if event.Nick != "nick" { + t.Fatal("Parse failed: nick") + } + if event.User != "~user" { + t.Fatal("Parse failed: user") + } + if event.Code != "PRIVMSG" { + t.Fatal("Parse failed: code") + } + if event.Arguments[0] != "#channel" { + t.Fatal("Parse failed: channel") + } + if event.Arguments[1] != "message text" { + t.Fatal("Parse failed: message") + } +} + +func TestParse(t *testing.T) { + event, err := parseToEvent(":nick!~user@host PRIVMSG #channel :message text") + if err != nil { + t.Fatal("Parse PRIVMSG failed") + } + checkResult(t, event) +} + +func TestParseTags(t *testing.T) { + event, err := parseToEvent("@tag;+tag2=raw+:=,escaped\\:\\s\\\\ :nick!~user@host PRIVMSG #channel :message text") + if err != nil { + t.Fatal("Parse PRIVMSG with tags failed") + } + checkResult(t, event) + fmt.Printf("%s", event.Tags) + if _, ok := event.Tags["tag"]; !ok { + t.Fatal("Parsing value-less tag failed") + } + if event.Tags["+tag2"] != "raw+:=,escaped; \\" { + t.Fatal("Parsing tag failed") + } +} diff --git a/irc_sasl.go b/irc_sasl.go index e5ff9e3..fd9df7b 100644 --- a/irc_sasl.go +++ b/irc_sasl.go @@ -43,7 +43,6 @@ func (irc *Connection) setupSASLCallbacks(result chan<- *SASLResult) { result <- &SASLResult{true, errors.New(e.Arguments[1])} }) irc.AddCallback("903", func(e *Event) { - irc.SendRaw("CAP END") result <- &SASLResult{false, nil} }) irc.AddCallback("904", func(e *Event) { diff --git a/irc_sasl_test.go b/irc_sasl_test.go index 6d22736..3c2fe47 100644 --- a/irc_sasl_test.go +++ b/irc_sasl_test.go @@ -34,7 +34,7 @@ func TestConnectionSASL(t *testing.T) { err := irccon.Connect(SASLServer) if err != nil { - t.Fatal("SASL failed") + t.Fatalf("SASL failed: %s", err) } irccon.Loop() } diff --git a/irc_struct.go b/irc_struct.go index b04a6a8..3e51bb8 100644 --- a/irc_struct.go +++ b/irc_struct.go @@ -15,21 +15,26 @@ import ( type Connection struct { sync.Mutex sync.WaitGroup - Debug bool - Error chan error - Password string - UseTLS bool - UseSASL bool - SASLLogin string - SASLPassword string - SASLMech string - WebIRC string - TLSConfig *tls.Config - Version string - Timeout time.Duration - PingFreq time.Duration - KeepAlive time.Duration - Server string + Debug bool + Error chan error + WebIRC string + Password string + UseTLS bool + UseSASL bool + RequestCaps []string + AcknowledgedCaps []string + SASLLogin string + SASLPassword string + SASLMech string + TLSConfig *tls.Config + Version string + Timeout time.Duration + PingFreq time.Duration + KeepAlive time.Duration + Server string + + RealName string // The real name we want to display. + // If zero-value defaults to the user. socket net.Conn pwrite chan string @@ -40,9 +45,11 @@ type Connection struct { user string registered bool events map[string]map[int]func(*Event) + eventsMutex sync.Mutex - QuitMessage string - lastMessage time.Time + QuitMessage string + lastMessage time.Time + lastMessageMutex sync.Mutex VerboseCallbackHandler bool Log *log.Logger @@ -60,6 +67,7 @@ type Event struct { Source string // User string // Arguments []string + Tags map[string]string Connection *Connection }