diff --git a/irc.go b/irc.go index a745761..45701d3 100644 --- a/irc.go +++ b/irc.go @@ -21,7 +21,6 @@ package irc import ( "bufio" "bytes" - "context" "crypto/tls" "errors" "fmt" @@ -80,10 +79,6 @@ func (irc *Connection) readLoop() { irc.lastMessageMutex.Unlock() event, err := parseToEvent(msg) event.Connection = irc - event.Ctx = context.Background() - if irc.CallbackTimeout != 0 { - event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout) - } if err == nil { /* XXX: len(args) == 0: args should be empty */ irc.RunCallbacks(event) @@ -236,7 +231,9 @@ func (irc *Connection) Loop() { errChan := irc.ErrorChan() for !irc.isQuitting() { err := <-errChan - close(irc.end) + if irc.end != nil { + close(irc.end) + } irc.Wait() for !irc.isQuitting() { irc.Log.Printf("Error, disconnected: %s\n", err) @@ -400,13 +397,14 @@ func (irc *Connection) Disconnect() { close(irc.end) } + irc.Wait() + irc.end = nil if irc.pwrite != nil { close(irc.pwrite) } - irc.Wait() if irc.socket != nil { irc.socket.Close() } @@ -473,7 +471,7 @@ func (irc *Connection) Connect(server string) error { irc.Log.Printf("Connected to %s (%s)\n", irc.Server, irc.socket.RemoteAddr()) irc.pwrite = make(chan string, 10) - irc.Error = make(chan error, 2) + irc.Error = make(chan error, 10) irc.Add(3) go irc.readLoop() go irc.writeLoop() diff --git a/irc_callback.go b/irc_callback.go index c95a0cd..08aecbe 100644 --- a/irc_callback.go +++ b/irc_callback.go @@ -1,7 +1,7 @@ package irc import ( - "fmt" + "context" "reflect" "runtime" "strconv" @@ -129,17 +129,20 @@ func (irc *Connection) RunCallbacks(event *Event) { } irc.eventsMutex.Lock() - callbacks := []func(*Event){} + callbacks := make(map[int]func(*Event)) eventCallbacks, ok := irc.events[event.Code] + id := 0 if ok { for _, callback := range eventCallbacks { - callbacks = append(callbacks, callback) + callbacks[id] = callback + id++ } } allCallbacks, ok := irc.events["*"] if ok { for _, callback := range allCallbacks { - callbacks = append(callbacks, callback) + callbacks[id] = callback + id++ } } irc.eventsMutex.Unlock() @@ -148,36 +151,42 @@ func (irc *Connection) RunCallbacks(event *Event) { irc.Log.Printf("%v (%v) >> %#v\n", event.Code, len(callbacks), event) } - done := make(chan bool) - possibleLogs := []string{} - for i, callback := range callbacks { - go func(done chan bool) { - callback(event) - done <- true - }(done) - callbackName := getFunctionName(callback) - start := time.Now() + event.Ctx = context.Background() + if irc.CallbackTimeout != 0 { + event.Ctx, _ = context.WithTimeout(event.Ctx, irc.CallbackTimeout) + } + + done := make(chan int) + for id, callback := range callbacks { + go func(id int, done chan<- int, cb func(*Event), event *Event) { + start := time.Now() + cb(event) + select { + case done <- id: + case <-event.Ctx.Done(): // If we timed out, report how long until we eventually finished + irc.Log.Printf("Canceled callback %s finished in %s >> %#v\n", + getFunctionName(cb), + time.Since(start), + event, + ) + } + }(id, done, callback, event) + } + + for len(callbacks) > 0 { select { + case jobID := <-done: + delete(callbacks, jobID) case <-event.Ctx.Done(): // context timed out! - irc.Log.Printf("TIMEOUT: %s timeout expired while executing %s, abandoning remaining callbacks", irc.CallbackTimeout, callbackName) - - // If we timed out let's include context for how long each previous handler took - for _, logItem := range possibleLogs { - irc.Log.Println(logItem) + timedOutCallbacks := []string{} + for _, cb := range callbacks { // Everything left here did not finish + timedOutCallbacks = append(timedOutCallbacks, getFunctionName(cb)) } - irc.Log.Printf("Callback %s ran for %s prior to timeout", callbackName, time.Since(start)) - if len(callbacks) > i { - for _, callback := range callbacks[i+1:] { - irc.Log.Printf("Callback %s did not run", getFunctionName(callback)) - } - } - - // At this point our context has expired and it's not safe to execute anything else, lets bail. + irc.Log.Printf("Timeout while waiting for %d callback(s) to finish (%s)\n", + len(callbacks), + strings.Join(timedOutCallbacks, ", "), + ) return - case <-done: - elapsed := time.Since(start) - logMsg := fmt.Sprintf("Callback %s took %s", getFunctionName(callback), elapsed) - possibleLogs = append(possibleLogs, logMsg) } } } diff --git a/irc_parse_test.go b/irc_parse_test.go index 447a590..cc7c734 100644 --- a/irc_parse_test.go +++ b/irc_parse_test.go @@ -1,7 +1,6 @@ package irc import ( - "fmt" "testing" ) @@ -37,7 +36,7 @@ func TestParseTags(t *testing.T) { t.Fatal("Parse PRIVMSG with tags failed") } checkResult(t, event) - fmt.Printf("%s", event.Tags) + t.Logf("%s", event.Tags) if _, ok := event.Tags["tag"]; !ok { t.Fatal("Parsing value-less tag failed") } diff --git a/irc_sasl_test.go b/irc_sasl_test.go index 3c2fe47..1b363d4 100644 --- a/irc_sasl_test.go +++ b/irc_sasl_test.go @@ -16,6 +16,9 @@ func TestConnectionSASL(t *testing.T) { if SASLLogin == "" { t.Skip("Define SASLLogin and SASLPasword environment varables to test SASL") } + if testing.Short() { + t.Skip("skipping test in short mode.") + } irccon := IRC("go-eventirc", "go-eventirc") irccon.VerboseCallbackHandler = true irccon.Debug = true diff --git a/irc_test.go b/irc_test.go index 4ef2a2b..b2d3065 100644 --- a/irc_test.go +++ b/irc_test.go @@ -3,6 +3,7 @@ package irc import ( "crypto/tls" "math/rand" + "sort" "testing" "time" ) @@ -115,7 +116,7 @@ func TestRemoveCallback(t *testing.T) { results = append(results, <-done) results = append(results, <-done) - if len(results) != 2 || results[0] == 2 || results[1] == 2 { + if !compareResults(results, 1, 3) { t.Error("Callback 2 not removed") } } @@ -138,7 +139,7 @@ func TestWildcardCallback(t *testing.T) { results = append(results, <-done) results = append(results, <-done) - if len(results) != 2 || !(results[0] == 1 && results[1] == 2) { + if !compareResults(results, 1, 2) { t.Error("Wildcard callback not called") } } @@ -164,7 +165,7 @@ func TestClearCallback(t *testing.T) { results = append(results, <-done) results = append(results, <-done) - if len(results) != 2 || !(results[0] == 2 && results[1] == 3) { + if !compareResults(results, 2, 3) { t.Error("Callbacks not cleared") } } @@ -185,6 +186,9 @@ func TestIRCemptyUser(t *testing.T) { } } func TestConnection(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } rand.Seed(time.Now().UnixNano()) ircnick1 := randStr(8) ircnick2 := randStr(8) @@ -266,8 +270,12 @@ func TestConnection(t *testing.T) { } func TestReconnect(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } ircnick1 := randStr(8) irccon := IRC(ircnick1, "IRCTestRe") + irccon.PingFreq = time.Second * 3 debugTest(irccon) connects := 0 @@ -277,11 +285,11 @@ func TestReconnect(t *testing.T) { connects += 1 if connects > 2 { irccon.Privmsgf(channel, "Connection nr %d (test done)\n", connects) - irccon.Quit() + go irccon.Quit() } else { irccon.Privmsgf(channel, "Connection nr %d\n", connects) time.Sleep(100) //Need to let the thraed actually send before closing socket - irccon.Disconnect() + go irccon.Disconnect() } }) @@ -298,6 +306,9 @@ func TestReconnect(t *testing.T) { } func TestConnectionSSL(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } ircnick1 := randStr(8) irccon := IRC(ircnick1, "IRCTestSSL") debugTest(irccon) @@ -333,3 +344,17 @@ func debugTest(irccon *Connection) *Connection { irccon.Debug = debug_tests return irccon } + +func compareResults(received []int, desired ...int) bool { + if len(desired) != len(received) { + return false + } + sort.IntSlice(desired).Sort() + sort.IntSlice(received).Sort() + for i := 0; i < len(desired); i++ { + if desired[i] != received[i] { + return false + } + } + return true +}