// Package ratelimit provides per-IP rate limiting for HTTP endpoints. package ratelimit import ( "sync" "time" "golang.org/x/time/rate" ) const ( // DefaultRate is the default number of allowed requests per second. DefaultRate = 1.0 // DefaultBurst is the default maximum burst size. DefaultBurst = 5 // DefaultSweepInterval controls how often stale entries are pruned. DefaultSweepInterval = 10 * time.Minute // DefaultEntryTTL is how long an unused entry lives before eviction. DefaultEntryTTL = 15 * time.Minute ) // entry tracks a per-IP rate limiter and when it was last used. type entry struct { limiter *rate.Limiter lastSeen time.Time } // Limiter manages per-key rate limiters with automatic cleanup // of stale entries. type Limiter struct { mu sync.Mutex entries map[string]*entry rate rate.Limit burst int entryTTL time.Duration stopCh chan struct{} } // New creates a new per-key rate Limiter. // The ratePerSec parameter sets how many requests per second are // allowed per key. The burst parameter sets the maximum number of // requests that can be made in a single burst. func New(ratePerSec float64, burst int) *Limiter { limiter := &Limiter{ mu: sync.Mutex{}, entries: make(map[string]*entry), rate: rate.Limit(ratePerSec), burst: burst, entryTTL: DefaultEntryTTL, stopCh: make(chan struct{}), } go limiter.sweepLoop() return limiter } // Allow reports whether a request from the given key should be // allowed. It consumes one token from the key's rate limiter. func (l *Limiter) Allow(key string) bool { l.mu.Lock() ent, exists := l.entries[key] if !exists { ent = &entry{ limiter: rate.NewLimiter(l.rate, l.burst), lastSeen: time.Now(), } l.entries[key] = ent } else { ent.lastSeen = time.Now() } l.mu.Unlock() return ent.limiter.Allow() } // Stop terminates the background sweep goroutine. func (l *Limiter) Stop() { close(l.stopCh) } // Len returns the number of tracked keys (for testing). func (l *Limiter) Len() int { l.mu.Lock() defer l.mu.Unlock() return len(l.entries) } // sweepLoop periodically removes entries that haven't been seen // within the TTL. func (l *Limiter) sweepLoop() { ticker := time.NewTicker(DefaultSweepInterval) defer ticker.Stop() for { select { case <-ticker.C: l.sweep() case <-l.stopCh: return } } } // sweep removes stale entries. func (l *Limiter) sweep() { l.mu.Lock() defer l.mu.Unlock() cutoff := time.Now().Add(-l.entryTTL) for key, ent := range l.entries { if ent.lastSeen.Before(cutoff) { delete(l.entries, key) } } }