feat: add API token authentication (closes #87) #94
160
internal/middleware/bearer_auth_test.go
Normal file
160
internal/middleware/bearer_auth_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx"
|
||||
|
||||
"git.eeqj.de/sneak/upaas/internal/config"
|
||||
"git.eeqj.de/sneak/upaas/internal/database"
|
||||
"git.eeqj.de/sneak/upaas/internal/globals"
|
||||
"git.eeqj.de/sneak/upaas/internal/logger"
|
||||
"git.eeqj.de/sneak/upaas/internal/middleware"
|
||||
"git.eeqj.de/sneak/upaas/internal/models"
|
||||
"git.eeqj.de/sneak/upaas/internal/service/auth"
|
||||
)
|
||||
|
||||
// setupMiddleware creates a Middleware with a real SQLite database for
|
||||
// integration testing.
|
||||
func setupMiddleware(t *testing.T) (*middleware.Middleware, *auth.Service, *database.Database) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
globals.SetAppname("upaas-test")
|
||||
globals.SetVersion("test")
|
||||
|
||||
globalsInst, err := globals.New(fx.Lifecycle(nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
loggerInst, err := logger.New(
|
||||
fx.Lifecycle(nil),
|
||||
logger.Params{Globals: globalsInst},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
DataDir: tmpDir,
|
||||
SessionSecret: "test-secret-key-at-least-32-chars!!",
|
||||
}
|
||||
_ = filepath.Join(tmpDir, "upaas.db")
|
||||
|
||||
dbInst, err := database.New(fx.Lifecycle(nil), database.Params{
|
||||
Logger: loggerInst,
|
||||
Config: cfg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authSvc, err := auth.New(fx.Lifecycle(nil), auth.ServiceParams{
|
||||
Logger: loggerInst,
|
||||
Config: cfg,
|
||||
Database: dbInst,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mw, err := middleware.New(fx.Lifecycle(nil), middleware.Params{
|
||||
Logger: loggerInst,
|
||||
Globals: globalsInst,
|
||||
Config: cfg,
|
||||
Auth: authSvc,
|
||||
Database: dbInst,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return mw, authSvc, dbInst
|
||||
}
|
||||
|
||||
func TestAPISessionAuth_BearerTokenSetsUserContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mw, authSvc, dbInst := setupMiddleware(t)
|
||||
ctx := t.Context()
|
||||
|
||||
// Create a user.
|
||||
user, err := authSvc.CreateUser(ctx, "testuser", "password123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
// Create an API token for the user.
|
||||
rawToken, err := models.GenerateToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenHash := database.HashAPIToken(rawToken)
|
||||
apiToken := models.NewAPIToken(dbInst)
|
||||
apiToken.UserID = user.ID
|
||||
apiToken.Name = "test-token"
|
||||
apiToken.TokenHash = tokenHash
|
||||
|
||||
err = apiToken.Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a handler behind APISessionAuth that checks user context.
|
||||
var gotUser *models.User
|
||||
|
||||
var getUserErr error
|
||||
|
||||
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
gotUser, getUserErr = authSvc.GetCurrentUser(r.Context(), r)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
|
||||
// Make request with bearer token.
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+rawToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NoError(t, getUserErr)
|
||||
require.NotNil(t, gotUser, "GetCurrentUser should return the user for bearer auth")
|
||||
assert.Equal(t, user.ID, gotUser.ID)
|
||||
assert.Equal(t, "testuser", gotUser.Username)
|
||||
}
|
||||
|
||||
func TestAPISessionAuth_NoBearerTokenReturns401(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mw, _, _ := setupMiddleware(t)
|
||||
|
||||
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestAPISessionAuth_InvalidBearerTokenReturns401(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mw, _, _ := setupMiddleware(t)
|
||||
|
||||
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
@@ -386,8 +386,8 @@ func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler {
|
||||
request *http.Request,
|
||||
) {
|
||||
// Try Bearer token first.
|
||||
if m.tryBearerAuth(request) {
|
||||
next.ServeHTTP(writer, request)
|
||||
if authedReq, ok := m.tryBearerAuth(request); ok {
|
||||
next.ServeHTTP(writer, authedReq)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -457,16 +457,19 @@ func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// tryBearerAuth checks for a valid Bearer token in the
|
||||
// Authorization header.
|
||||
func (m *Middleware) tryBearerAuth(request *http.Request) bool {
|
||||
// Authorization header. On success it returns a new request
|
||||
// with the authenticated user set on the context.
|
||||
func (m *Middleware) tryBearerAuth(
|
||||
request *http.Request,
|
||||
) (*http.Request, bool) {
|
||||
authHeader := request.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
return false
|
||||
return request, false
|
||||
}
|
||||
|
||||
rawToken := strings.TrimPrefix(authHeader, bearerPrefix)
|
||||
if rawToken == "" {
|
||||
return false
|
||||
return request, false
|
||||
}
|
||||
|
||||
tokenHash := database.HashAPIToken(rawToken)
|
||||
@@ -475,15 +478,26 @@ func (m *Middleware) tryBearerAuth(request *http.Request) bool {
|
||||
request.Context(), m.params.Database, tokenHash,
|
||||
)
|
||||
if err != nil || apiToken == nil {
|
||||
return false
|
||||
return request, false
|
||||
}
|
||||
|
||||
if apiToken.IsExpired() {
|
||||
return false
|
||||
return request, false
|
||||
}
|
||||
|
||||
// Look up the user associated with the token.
|
||||
user, err := models.FindUser(
|
||||
request.Context(), m.params.Database, apiToken.UserID,
|
||||
)
|
||||
if err != nil || user == nil {
|
||||
return request, false
|
||||
}
|
||||
|
||||
// Update last_used_at (best effort).
|
||||
_ = apiToken.TouchLastUsed(request.Context())
|
||||
|
||||
return true
|
||||
// Set the authenticated user on the request context.
|
||||
ctx := auth.ContextWithUser(request.Context(), user)
|
||||
|
||||
return request.WithContext(ctx), true
|
||||
}
|
||||
|
||||
@@ -26,6 +26,21 @@ const (
|
||||
sessionUserID = "user_id"
|
||||
)
|
||||
|
||||
// contextKeyUser is the context key for storing the authenticated user.
|
||||
type contextKeyUser struct{}
|
||||
|
||||
// ContextWithUser returns a new context with the given user attached.
|
||||
func ContextWithUser(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, contextKeyUser{}, user)
|
||||
}
|
||||
|
||||
// UserFromContext retrieves the user from the context, if set.
|
||||
func UserFromContext(ctx context.Context) *models.User {
|
||||
user, _ := ctx.Value(contextKeyUser{}).(*models.User)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
// Argon2 parameters.
|
||||
const (
|
||||
argonTime = 1
|
||||
@@ -239,6 +254,11 @@ func (svc *Service) GetCurrentUser(
|
||||
ctx context.Context,
|
||||
request *http.Request,
|
||||
) (*models.User, error) {
|
||||
// Check context first (set by bearer token auth).
|
||||
if user := UserFromContext(ctx); user != nil {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
session, sessionErr := svc.store.Get(request, sessionName)
|
||||
if sessionErr != nil {
|
||||
// Session error means no user - this is not an error condition
|
||||
|
||||
Reference in New Issue
Block a user