refactor: use pinned golangci-lint Docker image for linting
All checks were successful
check / check (push) Successful in 1m37s
All checks were successful
check / check (push) Successful in 1m37s
Refactor Dockerfile to use a separate lint stage with a pinned golangci-lint v2.11.3 Docker image instead of installing golangci-lint via curl in the builder stage. This follows the pattern used by sneak/pixa. Changes: - Dockerfile: separate lint stage using golangci/golangci-lint:v2.11.3 (Debian-based, pinned by sha256) with COPY --from=lint dependency - Bump Go from 1.24 to 1.26.1 (golang:1.26.1-bookworm, pinned) - Bump golangci-lint from v1.64.8 to v2.11.3 - Migrate .golangci.yml from v1 to v2 format (same linters, format only) - All Docker images pinned by sha256 digest - Fix all lint issues from the v2 linter upgrade: - Add package comments to all packages - Add doc comments to all exported types, functions, and methods - Fix unchecked errors (errcheck) - Fix unused parameters (revive) - Fix gosec warnings (MaxBytesReader for form parsing) - Fix staticcheck suggestions (fmt.Fprintf instead of WriteString) - Rename DeliveryTask to Task to avoid stutter (delivery.Task) - Rename shadowed builtin 'max' parameter - Update README.md version requirements
This commit is contained in:
@@ -5,41 +5,32 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitState represents the current state of a circuit breaker.
|
||||
// CircuitState represents the current state of a circuit
|
||||
// breaker.
|
||||
type CircuitState int
|
||||
|
||||
const (
|
||||
// CircuitClosed is the normal operating state. Deliveries flow through.
|
||||
// CircuitClosed is the normal operating state.
|
||||
CircuitClosed CircuitState = iota
|
||||
// CircuitOpen means the circuit has tripped. Deliveries are skipped
|
||||
// until the cooldown expires.
|
||||
// CircuitOpen means the circuit has tripped.
|
||||
CircuitOpen
|
||||
// CircuitHalfOpen allows a single probe delivery to test whether
|
||||
// the target has recovered.
|
||||
// CircuitHalfOpen allows a single probe delivery to
|
||||
// test whether the target has recovered.
|
||||
CircuitHalfOpen
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultFailureThreshold is the number of consecutive failures
|
||||
// before a circuit breaker trips open.
|
||||
// defaultFailureThreshold is the number of consecutive
|
||||
// failures before a circuit breaker trips open.
|
||||
defaultFailureThreshold = 5
|
||||
|
||||
// defaultCooldown is how long a circuit stays open before
|
||||
// transitioning to half-open for a probe delivery.
|
||||
// defaultCooldown is how long a circuit stays open
|
||||
// before transitioning to half-open.
|
||||
defaultCooldown = 30 * time.Second
|
||||
)
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for a single
|
||||
// delivery target. It tracks consecutive failures and prevents
|
||||
// hammering a down target by temporarily stopping delivery attempts.
|
||||
//
|
||||
// States:
|
||||
// - Closed (normal): deliveries flow through; consecutive failures
|
||||
// are counted.
|
||||
// - Open (tripped): deliveries are skipped; a cooldown timer is
|
||||
// running. After the cooldown expires the state moves to HalfOpen.
|
||||
// - HalfOpen (probing): one probe delivery is allowed. If it
|
||||
// succeeds the circuit closes; if it fails the circuit reopens.
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
// for a single delivery target.
|
||||
type CircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state CircuitState
|
||||
@@ -49,7 +40,8 @@ type CircuitBreaker struct {
|
||||
lastFailure time.Time
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a circuit breaker with default settings.
|
||||
// NewCircuitBreaker creates a circuit breaker with default
|
||||
// settings.
|
||||
func NewCircuitBreaker() *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
@@ -58,12 +50,7 @@ func NewCircuitBreaker() *CircuitBreaker {
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks whether a delivery attempt should proceed. It returns
|
||||
// true if the delivery should be attempted, false if the circuit is
|
||||
// open and the delivery should be skipped.
|
||||
//
|
||||
// When the circuit is open and the cooldown has elapsed, Allow
|
||||
// transitions to half-open and permits exactly one probe delivery.
|
||||
// Allow checks whether a delivery attempt should proceed.
|
||||
func (cb *CircuitBreaker) Allow() bool {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -73,17 +60,15 @@ func (cb *CircuitBreaker) Allow() bool {
|
||||
return true
|
||||
|
||||
case CircuitOpen:
|
||||
// Check if cooldown has elapsed
|
||||
if time.Since(cb.lastFailure) >= cb.cooldown {
|
||||
cb.state = CircuitHalfOpen
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Only one probe at a time — reject additional attempts while
|
||||
// a probe is in flight. The probe goroutine will call
|
||||
// RecordSuccess or RecordFailure to resolve the state.
|
||||
return false
|
||||
|
||||
default:
|
||||
@@ -91,9 +76,8 @@ func (cb *CircuitBreaker) Allow() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// CooldownRemaining returns how much time is left before an open circuit
|
||||
// transitions to half-open. Returns zero if the circuit is not open or
|
||||
// the cooldown has already elapsed.
|
||||
// CooldownRemaining returns how much time is left before
|
||||
// an open circuit transitions to half-open.
|
||||
func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -106,11 +90,12 @@ func (cb *CircuitBreaker) CooldownRemaining() time.Duration {
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful delivery and resets the circuit
|
||||
// breaker to closed state with zero failures.
|
||||
// RecordSuccess records a successful delivery and resets
|
||||
// the circuit breaker to closed state.
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -119,8 +104,8 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.state = CircuitClosed
|
||||
}
|
||||
|
||||
// RecordFailure records a failed delivery. If the failure count reaches
|
||||
// the threshold, the circuit trips open.
|
||||
// RecordFailure records a failed delivery. If the failure
|
||||
// count reaches the threshold, the circuit trips open.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
@@ -134,20 +119,25 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.state = CircuitOpen
|
||||
}
|
||||
|
||||
case CircuitOpen:
|
||||
// Already open; no state change needed.
|
||||
|
||||
case CircuitHalfOpen:
|
||||
// Probe failed — reopen immediately
|
||||
// Probe failed -- reopen immediately.
|
||||
cb.state = CircuitOpen
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current circuit state. Safe for concurrent use.
|
||||
// State returns the current circuit state.
|
||||
func (cb *CircuitBreaker) State() CircuitState {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// String returns the human-readable name of a circuit state.
|
||||
// String returns the human-readable name of a circuit
|
||||
// state.
|
||||
func (s CircuitState) String() string {
|
||||
switch s {
|
||||
case CircuitClosed:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package delivery
|
||||
package delivery_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
@@ -7,237 +7,304 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
)
|
||||
|
||||
func TestCircuitBreaker_ClosedState_AllowsDeliveries(t *testing.T) {
|
||||
func TestCircuitBreaker_ClosedState_AllowsDeliveries(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
assert.Equal(t, CircuitClosed, cb.State())
|
||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
||||
// Multiple calls should all succeed
|
||||
for i := 0; i < 10; i++ {
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
assert.True(t, cb.Allow(),
|
||||
"closed circuit should allow deliveries",
|
||||
)
|
||||
|
||||
for range 10 {
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_FailureCounting(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
// Record failures below threshold — circuit should stay closed
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for i := range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, CircuitClosed, cb.State(),
|
||||
"circuit should remain closed after %d failures", i+1)
|
||||
assert.True(t, cb.Allow(), "should still allow after %d failures", i+1)
|
||||
|
||||
assert.Equal(t,
|
||||
delivery.CircuitClosed, cb.State(),
|
||||
"circuit should remain closed after %d failures",
|
||||
i+1,
|
||||
)
|
||||
|
||||
assert.True(t, cb.Allow(),
|
||||
"should still allow after %d failures",
|
||||
i+1,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpenTransition(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
// Record exactly threshold failures
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
|
||||
assert.Equal(t, CircuitOpen, cb.State(), "circuit should be open after threshold failures")
|
||||
assert.False(t, cb.Allow(), "open circuit should reject deliveries")
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||
"circuit should be open after threshold failures",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"open circuit should reject deliveries",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Cooldown_StaysOpen(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Use a circuit with a known short cooldown for testing
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip the circuit open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitOpen, cb.State())
|
||||
|
||||
// During cooldown, Allow should return false
|
||||
assert.False(t, cb.Allow(), "should be blocked during cooldown")
|
||||
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"should be blocked during cooldown",
|
||||
)
|
||||
|
||||
// CooldownRemaining should be positive
|
||||
remaining := cb.CooldownRemaining()
|
||||
assert.Greater(t, remaining, time.Duration(0), "cooldown should have remaining time")
|
||||
|
||||
assert.Greater(t, remaining, time.Duration(0),
|
||||
"cooldown should have remaining time",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpen_AfterCooldown(t *testing.T) {
|
||||
func TestCircuitBreaker_HalfOpen_AfterCooldown(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip the circuit open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitOpen, cb.State())
|
||||
|
||||
// Wait for cooldown to expire
|
||||
require.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// CooldownRemaining should be zero after cooldown
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining())
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
)
|
||||
|
||||
// First Allow after cooldown should succeed (probe)
|
||||
assert.True(t, cb.Allow(), "should allow one probe after cooldown")
|
||||
assert.Equal(t, CircuitHalfOpen, cb.State(), "should be half-open after probe allowed")
|
||||
assert.True(t, cb.Allow(),
|
||||
"should allow one probe after cooldown",
|
||||
)
|
||||
|
||||
// Second Allow should be rejected (only one probe at a time)
|
||||
assert.False(t, cb.Allow(), "should reject additional probes while half-open")
|
||||
assert.Equal(t,
|
||||
delivery.CircuitHalfOpen, cb.State(),
|
||||
"should be half-open after probe allowed",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"should reject additional probes while half-open",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(t *testing.T) {
|
||||
func TestCircuitBreaker_ProbeSuccess_ClosesCircuit(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open → wait for cooldown → allow probe
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
||||
|
||||
// Probe succeeds → circuit should close
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, CircuitClosed, cb.State(), "successful probe should close circuit")
|
||||
|
||||
// Should allow deliveries again
|
||||
assert.True(t, cb.Allow(), "closed circuit should allow deliveries")
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||
"successful probe should close circuit",
|
||||
)
|
||||
|
||||
assert.True(t, cb.Allow(),
|
||||
"closed circuit should allow deliveries",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(t *testing.T) {
|
||||
func TestCircuitBreaker_ProbeFailure_ReopensCircuit(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open → wait for cooldown → allow probe
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // probe allowed, state → half-open
|
||||
|
||||
// Probe fails → circuit should reopen
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, CircuitOpen, cb.State(), "failed probe should reopen circuit")
|
||||
assert.False(t, cb.Allow(), "reopened circuit should reject deliveries")
|
||||
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State(),
|
||||
"failed probe should reopen circuit",
|
||||
)
|
||||
|
||||
assert.False(t, cb.Allow(),
|
||||
"reopened circuit should reject deliveries",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
func TestCircuitBreaker_SuccessResetsFailures(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
// Accumulate failures just below threshold
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
require.Equal(t, CircuitClosed, cb.State())
|
||||
|
||||
// Success should reset the failure counter
|
||||
require.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, CircuitClosed, cb.State())
|
||||
|
||||
// Now we should need another full threshold of failures to trip
|
||||
for i := 0; i < defaultFailureThreshold-1; i++ {
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State())
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold - 1 {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
assert.Equal(t, CircuitClosed, cb.State(),
|
||||
"circuit should still be closed — success reset the counter")
|
||||
|
||||
// One more failure should trip it
|
||||
assert.Equal(t, delivery.CircuitClosed, cb.State(),
|
||||
"circuit should still be closed -- "+
|
||||
"success reset the counter",
|
||||
)
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, CircuitOpen, cb.State())
|
||||
|
||||
assert.Equal(t, delivery.CircuitOpen, cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
const goroutines = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
// Concurrent Allow calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.Allow()
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent RecordFailure calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.RecordFailure()
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent RecordSuccess calls
|
||||
for i := 0; i < goroutines; i++ {
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
cb.RecordSuccess()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// No panic or data race — the test passes if -race doesn't flag anything.
|
||||
// State should be one of the valid states.
|
||||
|
||||
state := cb.State()
|
||||
assert.Contains(t, []CircuitState{CircuitClosed, CircuitOpen, CircuitHalfOpen}, state,
|
||||
"state should be valid after concurrent access")
|
||||
|
||||
assert.Contains(t,
|
||||
[]delivery.CircuitState{
|
||||
delivery.CircuitClosed,
|
||||
delivery.CircuitOpen,
|
||||
delivery.CircuitHalfOpen,
|
||||
},
|
||||
state,
|
||||
"state should be valid after concurrent access",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(t *testing.T) {
|
||||
func TestCircuitBreaker_CooldownRemaining_ClosedReturnsZero(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := NewCircuitBreaker()
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
||||
"closed circuit should have zero cooldown remaining")
|
||||
|
||||
cb := delivery.NewCircuitBreaker()
|
||||
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
"closed circuit should have zero cooldown remaining",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(t *testing.T) {
|
||||
func TestCircuitBreaker_CooldownRemaining_HalfOpenReturnsZero(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
cb := &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: defaultFailureThreshold,
|
||||
cooldown: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
// Trip open, wait, transition to half-open
|
||||
for i := 0; i < defaultFailureThreshold; i++ {
|
||||
cb := newShortCooldownCB(t)
|
||||
|
||||
for range delivery.ExportDefaultFailureThreshold {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
require.True(t, cb.Allow()) // → half-open
|
||||
|
||||
assert.Equal(t, time.Duration(0), cb.CooldownRemaining(),
|
||||
"half-open circuit should have zero cooldown remaining")
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
require.True(t, cb.Allow())
|
||||
|
||||
assert.Equal(t, time.Duration(0),
|
||||
cb.CooldownRemaining(),
|
||||
"half-open circuit should have zero cooldown remaining",
|
||||
)
|
||||
}
|
||||
|
||||
func TestCircuitState_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, "closed", CircuitClosed.String())
|
||||
assert.Equal(t, "open", CircuitOpen.String())
|
||||
assert.Equal(t, "half-open", CircuitHalfOpen.String())
|
||||
assert.Equal(t, "unknown", CircuitState(99).String())
|
||||
|
||||
assert.Equal(t, "closed", delivery.CircuitClosed.String())
|
||||
assert.Equal(t, "open", delivery.CircuitOpen.String())
|
||||
assert.Equal(t, "half-open", delivery.CircuitHalfOpen.String())
|
||||
assert.Equal(t, "unknown", delivery.CircuitState(99).String())
|
||||
}
|
||||
|
||||
// newShortCooldownCB creates a CircuitBreaker with a short
|
||||
// cooldown for testing. We use NewCircuitBreaker and
|
||||
// manipulate through the public API.
|
||||
func newShortCooldownCB(t *testing.T) *delivery.CircuitBreaker {
|
||||
t.Helper()
|
||||
|
||||
return delivery.NewTestCircuitBreaker(
|
||||
delivery.ExportDefaultFailureThreshold,
|
||||
50*time.Millisecond,
|
||||
)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
240
internal/delivery/export_test.go
Normal file
240
internal/delivery/export_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package delivery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
)
|
||||
|
||||
// Exported constants for test access.
|
||||
const (
|
||||
ExportDeliveryChannelSize = deliveryChannelSize
|
||||
ExportRetryChannelSize = retryChannelSize
|
||||
ExportDefaultFailureThreshold = defaultFailureThreshold
|
||||
ExportDefaultCooldown = defaultCooldown
|
||||
)
|
||||
|
||||
// ExportIsBlockedIP exposes isBlockedIP for testing.
|
||||
func ExportIsBlockedIP(ip net.IP) bool {
|
||||
return isBlockedIP(ip)
|
||||
}
|
||||
|
||||
// ExportBlockedNetworks exposes blockedNetworks.
|
||||
func ExportBlockedNetworks() []*net.IPNet {
|
||||
return blockedNetworks
|
||||
}
|
||||
|
||||
// ExportIsForwardableHeader exposes isForwardableHeader.
|
||||
func ExportIsForwardableHeader(name string) bool {
|
||||
return isForwardableHeader(name)
|
||||
}
|
||||
|
||||
// ExportTruncate exposes truncate for testing.
|
||||
func ExportTruncate(s string, maxLen int) string {
|
||||
return truncate(s, maxLen)
|
||||
}
|
||||
|
||||
// ExportDeliverHTTP exposes deliverHTTP for testing.
|
||||
func (e *Engine) ExportDeliverHTTP(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
task *Task,
|
||||
) {
|
||||
e.deliverHTTP(ctx, webhookDB, d, task)
|
||||
}
|
||||
|
||||
// ExportDeliverDatabase exposes deliverDatabase.
|
||||
func (e *Engine) ExportDeliverDatabase(
|
||||
webhookDB *gorm.DB, d *database.Delivery,
|
||||
) {
|
||||
e.deliverDatabase(webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportDeliverLog exposes deliverLog for testing.
|
||||
func (e *Engine) ExportDeliverLog(
|
||||
webhookDB *gorm.DB, d *database.Delivery,
|
||||
) {
|
||||
e.deliverLog(webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportDeliverSlack exposes deliverSlack for testing.
|
||||
func (e *Engine) ExportDeliverSlack(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
) {
|
||||
e.deliverSlack(ctx, webhookDB, d)
|
||||
}
|
||||
|
||||
// ExportProcessNewTask exposes processNewTask.
|
||||
func (e *Engine) ExportProcessNewTask(
|
||||
ctx context.Context, task *Task,
|
||||
) {
|
||||
e.processNewTask(ctx, task)
|
||||
}
|
||||
|
||||
// ExportProcessRetryTask exposes processRetryTask.
|
||||
func (e *Engine) ExportProcessRetryTask(
|
||||
ctx context.Context, task *Task,
|
||||
) {
|
||||
e.processRetryTask(ctx, task)
|
||||
}
|
||||
|
||||
// ExportProcessDelivery exposes processDelivery.
|
||||
func (e *Engine) ExportProcessDelivery(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
d *database.Delivery,
|
||||
task *Task,
|
||||
) {
|
||||
e.processDelivery(ctx, webhookDB, d, task)
|
||||
}
|
||||
|
||||
// ExportGetCircuitBreaker exposes getCircuitBreaker.
|
||||
func (e *Engine) ExportGetCircuitBreaker(
|
||||
targetID string,
|
||||
) *CircuitBreaker {
|
||||
return e.getCircuitBreaker(targetID)
|
||||
}
|
||||
|
||||
// ExportParseHTTPConfig exposes parseHTTPConfig.
|
||||
func (e *Engine) ExportParseHTTPConfig(
|
||||
configJSON string,
|
||||
) (*HTTPTargetConfig, error) {
|
||||
return e.parseHTTPConfig(configJSON)
|
||||
}
|
||||
|
||||
// ExportParseSlackConfig exposes parseSlackConfig.
|
||||
func (e *Engine) ExportParseSlackConfig(
|
||||
configJSON string,
|
||||
) (*SlackTargetConfig, error) {
|
||||
return e.parseSlackConfig(configJSON)
|
||||
}
|
||||
|
||||
// ExportDoHTTPRequest exposes doHTTPRequest.
|
||||
func (e *Engine) ExportDoHTTPRequest(
|
||||
ctx context.Context,
|
||||
cfg *HTTPTargetConfig,
|
||||
event *database.Event,
|
||||
) (int, string, int64, error) {
|
||||
return e.doHTTPRequest(ctx, cfg, event)
|
||||
}
|
||||
|
||||
// ExportScheduleRetry exposes scheduleRetry.
|
||||
func (e *Engine) ExportScheduleRetry(
|
||||
task Task, delay time.Duration,
|
||||
) {
|
||||
e.scheduleRetry(task, delay)
|
||||
}
|
||||
|
||||
// ExportRecoverPendingDeliveries exposes
|
||||
// recoverPendingDeliveries.
|
||||
func (e *Engine) ExportRecoverPendingDeliveries(
|
||||
ctx context.Context,
|
||||
webhookDB *gorm.DB,
|
||||
webhookID string,
|
||||
) {
|
||||
e.recoverPendingDeliveries(
|
||||
ctx, webhookDB, webhookID,
|
||||
)
|
||||
}
|
||||
|
||||
// ExportRecoverWebhookDeliveries exposes
|
||||
// recoverWebhookDeliveries.
|
||||
func (e *Engine) ExportRecoverWebhookDeliveries(
|
||||
ctx context.Context, webhookID string,
|
||||
) {
|
||||
e.recoverWebhookDeliveries(ctx, webhookID)
|
||||
}
|
||||
|
||||
// ExportRecoverInFlight exposes recoverInFlight.
|
||||
func (e *Engine) ExportRecoverInFlight(
|
||||
ctx context.Context,
|
||||
) {
|
||||
e.recoverInFlight(ctx)
|
||||
}
|
||||
|
||||
// ExportStart exposes start for testing.
|
||||
func (e *Engine) ExportStart(ctx context.Context) {
|
||||
e.start(ctx)
|
||||
}
|
||||
|
||||
// ExportStop exposes stop for testing.
|
||||
func (e *Engine) ExportStop() {
|
||||
e.stop()
|
||||
}
|
||||
|
||||
// ExportDeliveryCh returns the delivery channel.
|
||||
func (e *Engine) ExportDeliveryCh() chan Task {
|
||||
return e.deliveryCh
|
||||
}
|
||||
|
||||
// ExportRetryCh returns the retry channel.
|
||||
func (e *Engine) ExportRetryCh() chan Task {
|
||||
return e.retryCh
|
||||
}
|
||||
|
||||
// NewTestEngine creates an Engine for unit tests without
|
||||
// database dependencies.
|
||||
func NewTestEngine(
|
||||
log *slog.Logger,
|
||||
client *http.Client,
|
||||
workers int,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
log: log,
|
||||
client: client,
|
||||
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||
retryCh: make(chan Task, retryChannelSize),
|
||||
workers: workers,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestEngineSmallRetry creates an Engine with a tiny
|
||||
// retry channel buffer for overflow testing.
|
||||
func NewTestEngineSmallRetry(
|
||||
log *slog.Logger,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
log: log,
|
||||
retryCh: make(chan Task, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestEngineWithDB creates an Engine with a real
|
||||
// database and dbManager for integration tests.
|
||||
func NewTestEngineWithDB(
|
||||
db *database.Database,
|
||||
dbMgr *database.WebhookDBManager,
|
||||
log *slog.Logger,
|
||||
client *http.Client,
|
||||
workers int,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
database: db,
|
||||
dbManager: dbMgr,
|
||||
log: log,
|
||||
client: client,
|
||||
deliveryCh: make(chan Task, deliveryChannelSize),
|
||||
retryCh: make(chan Task, retryChannelSize),
|
||||
workers: workers,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestCircuitBreaker creates a CircuitBreaker with
|
||||
// custom settings for testing.
|
||||
func NewTestCircuitBreaker(
|
||||
threshold int, cooldown time.Duration,
|
||||
) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
state: CircuitClosed,
|
||||
threshold: threshold,
|
||||
cooldown: cooldown,
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package delivery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -10,14 +11,27 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// dnsResolutionTimeout is the maximum time to wait for DNS resolution
|
||||
// during SSRF validation.
|
||||
// dnsResolutionTimeout is the maximum time to wait for
|
||||
// DNS resolution during SSRF validation.
|
||||
dnsResolutionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// blockedNetworks contains all private/reserved IP ranges that should be
|
||||
// blocked to prevent SSRF attacks. This includes RFC 1918 private
|
||||
// addresses, loopback, link-local, and IPv6 equivalents.
|
||||
// Sentinel errors for SSRF validation.
|
||||
var (
|
||||
errNoHostname = errors.New("URL has no hostname")
|
||||
errNoIPs = errors.New(
|
||||
"hostname resolved to no IP addresses",
|
||||
)
|
||||
errBlockedIP = errors.New(
|
||||
"blocked private/reserved IP range",
|
||||
)
|
||||
errInvalidScheme = errors.New(
|
||||
"only http and https are allowed",
|
||||
)
|
||||
)
|
||||
|
||||
// blockedNetworks contains all private/reserved IP ranges
|
||||
// that should be blocked to prevent SSRF attacks.
|
||||
//
|
||||
//nolint:gochecknoglobals // package-level network list is appropriate here
|
||||
var blockedNetworks []*net.IPNet
|
||||
@@ -25,129 +39,184 @@ var blockedNetworks []*net.IPNet
|
||||
//nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup
|
||||
func init() {
|
||||
cidrs := []string{
|
||||
// IPv4 private/reserved ranges
|
||||
"127.0.0.0/8", // Loopback
|
||||
"10.0.0.0/8", // RFC 1918 Class A private
|
||||
"172.16.0.0/12", // RFC 1918 Class B private
|
||||
"192.168.0.0/16", // RFC 1918 Class C private
|
||||
"169.254.0.0/16", // Link-local (cloud metadata)
|
||||
"0.0.0.0/8", // "This" network
|
||||
"100.64.0.0/10", // Shared address space (CGN)
|
||||
"192.0.0.0/24", // IETF protocol assignments
|
||||
"192.0.2.0/24", // TEST-NET-1
|
||||
"198.18.0.0/15", // Benchmarking
|
||||
"198.51.100.0/24", // TEST-NET-2
|
||||
"203.0.113.0/24", // TEST-NET-3
|
||||
"224.0.0.0/4", // Multicast
|
||||
"240.0.0.0/4", // Reserved for future use
|
||||
|
||||
// IPv6 private/reserved ranges
|
||||
"::1/128", // Loopback
|
||||
"fc00::/7", // Unique local addresses
|
||||
"fe80::/10", // Link-local
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"169.254.0.0/16",
|
||||
"0.0.0.0/8",
|
||||
"100.64.0.0/10",
|
||||
"192.0.0.0/24",
|
||||
"192.0.2.0/24",
|
||||
"198.18.0.0/15",
|
||||
"198.51.100.0/24",
|
||||
"203.0.113.0/24",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
}
|
||||
|
||||
for _, cidr := range cidrs {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err))
|
||||
panic(fmt.Sprintf(
|
||||
"ssrf: failed to parse CIDR %q: %v",
|
||||
cidr, err,
|
||||
))
|
||||
}
|
||||
blockedNetworks = append(blockedNetworks, network)
|
||||
|
||||
blockedNetworks = append(
|
||||
blockedNetworks, network,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// isBlockedIP checks whether an IP address falls within any blocked
|
||||
// private/reserved network range.
|
||||
// isBlockedIP checks whether an IP address falls within
|
||||
// any blocked private/reserved network range.
|
||||
func isBlockedIP(ip net.IP) bool {
|
||||
for _, network := range blockedNetworks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateTargetURL checks that an HTTP delivery target URL is safe
|
||||
// from SSRF attacks. It validates the URL format, resolves the hostname
|
||||
// to IP addresses, and verifies that none of the resolved IPs are in
|
||||
// blocked private/reserved ranges.
|
||||
//
|
||||
// Returns nil if the URL is safe, or an error describing the issue.
|
||||
func ValidateTargetURL(targetURL string) error {
|
||||
// ValidateTargetURL checks that an HTTP delivery target
|
||||
// URL is safe from SSRF attacks.
|
||||
func ValidateTargetURL(
|
||||
ctx context.Context, targetURL string,
|
||||
) error {
|
||||
parsed, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Only allow http and https schemes
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme)
|
||||
err = validateScheme(parsed.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("URL has no hostname")
|
||||
return errNoHostname
|
||||
}
|
||||
|
||||
// Check if the host is a raw IP address first
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isBlockedIP(ip) {
|
||||
return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip)
|
||||
}
|
||||
return nil
|
||||
return checkBlockedIP(ip)
|
||||
}
|
||||
|
||||
// Resolve hostname to IPs and check each one
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout)
|
||||
return validateHostname(ctx, host)
|
||||
}
|
||||
|
||||
func validateScheme(scheme string) error {
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return fmt.Errorf(
|
||||
"unsupported URL scheme %q: %w",
|
||||
scheme, errInvalidScheme,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkBlockedIP(ip net.IP) error {
|
||||
if isBlockedIP(ip) {
|
||||
return fmt.Errorf(
|
||||
"target IP %s is in a blocked "+
|
||||
"private/reserved range: %w",
|
||||
ip, errBlockedIP,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHostname(
|
||||
ctx context.Context, host string,
|
||||
) error {
|
||||
dnsCtx, cancel := context.WithTimeout(
|
||||
ctx, dnsResolutionTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||
dnsCtx, host,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve hostname %q: %w", host, err)
|
||||
return fmt.Errorf(
|
||||
"failed to resolve hostname %q: %w",
|
||||
host, err,
|
||||
)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return fmt.Errorf("hostname %q resolved to no IP addresses", host)
|
||||
return fmt.Errorf(
|
||||
"hostname %q: %w", host, errNoIPs,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP)
|
||||
return fmt.Errorf(
|
||||
"hostname %q resolves to blocked "+
|
||||
"IP %s: %w",
|
||||
host, ipAddr.IP, errBlockedIP,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSSRFSafeTransport creates an http.Transport with a custom DialContext
|
||||
// that blocks connections to private/reserved IP addresses. This provides
|
||||
// defense-in-depth SSRF protection at the network layer, catching cases
|
||||
// where DNS records change between target creation and delivery time
|
||||
// (DNS rebinding attacks).
|
||||
// NewSSRFSafeTransport creates an http.Transport with a
|
||||
// custom DialContext that blocks connections to
|
||||
// private/reserved IP addresses.
|
||||
func NewSSRFSafeTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
||||
}
|
||||
|
||||
// Resolve hostname to IPs
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err)
|
||||
}
|
||||
|
||||
// Check all resolved IPs
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to the first allowed IP
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
|
||||
},
|
||||
DialContext: ssrfDialContext,
|
||||
}
|
||||
}
|
||||
|
||||
func ssrfDialContext(
|
||||
ctx context.Context,
|
||||
network, addr string,
|
||||
) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: invalid address %q: %w",
|
||||
addr, err,
|
||||
)
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(
|
||||
ctx, host,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: DNS resolution failed for %q: %w",
|
||||
host, err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ipAddr := range ips {
|
||||
if isBlockedIP(ipAddr.IP) {
|
||||
return nil, fmt.Errorf(
|
||||
"ssrf: connection to %s (%s) "+
|
||||
"blocked: %w",
|
||||
host, ipAddr.IP, errBlockedIP,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
var dialer net.Dialer
|
||||
|
||||
return dialer.DialContext(
|
||||
ctx, network,
|
||||
net.JoinHostPort(ips[0].IP.String(), port),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package delivery
|
||||
package delivery_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"sneak.berlin/go/webhooker/internal/delivery"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
||||
@@ -16,56 +18,52 @@ func TestIsBlockedIP_PrivateRanges(t *testing.T) {
|
||||
ip string
|
||||
blocked bool
|
||||
}{
|
||||
// Loopback
|
||||
{"loopback 127.0.0.1", "127.0.0.1", true},
|
||||
{"loopback 127.0.0.2", "127.0.0.2", true},
|
||||
{"loopback 127.255.255.255", "127.255.255.255", true},
|
||||
|
||||
// RFC 1918 - Class A
|
||||
{"10.0.0.0", "10.0.0.0", true},
|
||||
{"10.0.0.1", "10.0.0.1", true},
|
||||
{"10.255.255.255", "10.255.255.255", true},
|
||||
|
||||
// RFC 1918 - Class B
|
||||
{"172.16.0.1", "172.16.0.1", true},
|
||||
{"172.31.255.255", "172.31.255.255", true},
|
||||
{"172.15.255.255", "172.15.255.255", false},
|
||||
{"172.32.0.0", "172.32.0.0", false},
|
||||
|
||||
// RFC 1918 - Class C
|
||||
{"192.168.0.1", "192.168.0.1", true},
|
||||
{"192.168.255.255", "192.168.255.255", true},
|
||||
|
||||
// Link-local / cloud metadata
|
||||
{"169.254.0.1", "169.254.0.1", true},
|
||||
{"169.254.169.254", "169.254.169.254", true},
|
||||
|
||||
// Public IPs (should NOT be blocked)
|
||||
{"8.8.8.8", "8.8.8.8", false},
|
||||
{"1.1.1.1", "1.1.1.1", false},
|
||||
{"93.184.216.34", "93.184.216.34", false},
|
||||
|
||||
// IPv6 loopback
|
||||
{"::1", "::1", true},
|
||||
|
||||
// IPv6 unique local
|
||||
{"fd00::1", "fd00::1", true},
|
||||
{"fc00::1", "fc00::1", true},
|
||||
|
||||
// IPv6 link-local
|
||||
{"fe80::1", "fe80::1", true},
|
||||
|
||||
// IPv6 public (should NOT be blocked)
|
||||
{"2607:f8b0:4004:800::200e", "2607:f8b0:4004:800::200e", false},
|
||||
{
|
||||
"2607:f8b0:4004:800::200e",
|
||||
"2607:f8b0:4004:800::200e",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ip := net.ParseIP(tt.ip)
|
||||
require.NotNil(t, ip, "failed to parse IP %s", tt.ip)
|
||||
assert.Equal(t, tt.blocked, isBlockedIP(ip),
|
||||
"isBlockedIP(%s) = %v, want %v", tt.ip, isBlockedIP(ip), tt.blocked)
|
||||
|
||||
require.NotNil(t, ip,
|
||||
"failed to parse IP %s", tt.ip,
|
||||
)
|
||||
|
||||
assert.Equal(t,
|
||||
tt.blocked,
|
||||
delivery.ExportIsBlockedIP(ip),
|
||||
"isBlockedIP(%s) = %v, want %v",
|
||||
tt.ip,
|
||||
delivery.ExportIsBlockedIP(ip),
|
||||
tt.blocked,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -89,8 +87,14 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
||||
for _, u := range blockedURLs {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateTargetURL(u)
|
||||
assert.Error(t, err, "URL %s should be blocked", u)
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), u,
|
||||
)
|
||||
|
||||
assert.Error(t, err,
|
||||
"URL %s should be blocked", u,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -98,7 +102,6 @@ func TestValidateTargetURL_Blocked(t *testing.T) {
|
||||
func TestValidateTargetURL_Allowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// These are public IPs and should be allowed
|
||||
allowedURLs := []string{
|
||||
"https://example.com/hook",
|
||||
"http://93.184.216.34/webhook",
|
||||
@@ -108,35 +111,62 @@ func TestValidateTargetURL_Allowed(t *testing.T) {
|
||||
for _, u := range allowedURLs {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateTargetURL(u)
|
||||
assert.NoError(t, err, "URL %s should be allowed", u)
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), u,
|
||||
)
|
||||
|
||||
assert.NoError(t, err,
|
||||
"URL %s should be allowed", u,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_InvalidScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateTargetURL("ftp://example.com/hook")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported URL scheme")
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "ftp://example.com/hook",
|
||||
)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Contains(t, err.Error(),
|
||||
"unsupported URL scheme",
|
||||
)
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_EmptyHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateTargetURL("http:///path")
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "http:///path",
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateTargetURL_InvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateTargetURL("://invalid")
|
||||
|
||||
err := delivery.ValidateTargetURL(
|
||||
context.Background(), "://invalid",
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBlockedNetworks_Initialized(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.NotEmpty(t, blockedNetworks, "blockedNetworks should be initialized")
|
||||
// Should have at least the main RFC 1918 + loopback + link-local ranges
|
||||
assert.GreaterOrEqual(t, len(blockedNetworks), 8,
|
||||
"should have at least 8 blocked network ranges")
|
||||
|
||||
nets := delivery.ExportBlockedNetworks()
|
||||
|
||||
assert.NotEmpty(t, nets,
|
||||
"blockedNetworks should be initialized",
|
||||
)
|
||||
|
||||
assert.GreaterOrEqual(t, len(nets), 8,
|
||||
"should have at least 8 blocked network ranges",
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user