diff --git a/internal/middleware/bearer_auth_test.go b/internal/middleware/bearer_auth_test.go new file mode 100644 index 0000000..30bc332 --- /dev/null +++ b/internal/middleware/bearer_auth_test.go @@ -0,0 +1,160 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/fx" + + "git.eeqj.de/sneak/upaas/internal/config" + "git.eeqj.de/sneak/upaas/internal/database" + "git.eeqj.de/sneak/upaas/internal/globals" + "git.eeqj.de/sneak/upaas/internal/logger" + "git.eeqj.de/sneak/upaas/internal/middleware" + "git.eeqj.de/sneak/upaas/internal/models" + "git.eeqj.de/sneak/upaas/internal/service/auth" +) + +// setupMiddleware creates a Middleware with a real SQLite database for +// integration testing. +func setupMiddleware(t *testing.T) (*middleware.Middleware, *auth.Service, *database.Database) { + t.Helper() + + tmpDir := t.TempDir() + + globals.SetAppname("upaas-test") + globals.SetVersion("test") + + globalsInst, err := globals.New(fx.Lifecycle(nil)) + require.NoError(t, err) + + loggerInst, err := logger.New( + fx.Lifecycle(nil), + logger.Params{Globals: globalsInst}, + ) + require.NoError(t, err) + + cfg := &config.Config{ + Port: 8080, + DataDir: tmpDir, + SessionSecret: "test-secret-key-at-least-32-chars!!", + } + _ = filepath.Join(tmpDir, "upaas.db") + + dbInst, err := database.New(fx.Lifecycle(nil), database.Params{ + Logger: loggerInst, + Config: cfg, + }) + require.NoError(t, err) + + authSvc, err := auth.New(fx.Lifecycle(nil), auth.ServiceParams{ + Logger: loggerInst, + Config: cfg, + Database: dbInst, + }) + require.NoError(t, err) + + mw, err := middleware.New(fx.Lifecycle(nil), middleware.Params{ + Logger: loggerInst, + Globals: globalsInst, + Config: cfg, + Auth: authSvc, + Database: dbInst, + }) + require.NoError(t, err) + + return mw, authSvc, dbInst +} + +func TestAPISessionAuth_BearerTokenSetsUserContext(t *testing.T) { + t.Parallel() + + mw, authSvc, dbInst := setupMiddleware(t) + ctx := t.Context() + + // Create a user. + user, err := authSvc.CreateUser(ctx, "testuser", "password123") + require.NoError(t, err) + require.NotNil(t, user) + + // Create an API token for the user. + rawToken, err := models.GenerateToken() + require.NoError(t, err) + + tokenHash := database.HashAPIToken(rawToken) + apiToken := models.NewAPIToken(dbInst) + apiToken.UserID = user.ID + apiToken.Name = "test-token" + apiToken.TokenHash = tokenHash + + err = apiToken.Save(ctx) + require.NoError(t, err) + + // Build a handler behind APISessionAuth that checks user context. + var gotUser *models.User + + var getUserErr error + + handler := mw.APISessionAuth()(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + gotUser, getUserErr = authSvc.GetCurrentUser(r.Context(), r) + + w.WriteHeader(http.StatusOK) + }, + )) + + // Make request with bearer token. + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.Header.Set("Authorization", "Bearer "+rawToken) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + require.NoError(t, getUserErr) + require.NotNil(t, gotUser, "GetCurrentUser should return the user for bearer auth") + assert.Equal(t, user.ID, gotUser.ID) + assert.Equal(t, "testuser", gotUser.Username) +} + +func TestAPISessionAuth_NoBearerTokenReturns401(t *testing.T) { + t.Parallel() + + mw, _, _ := setupMiddleware(t) + + handler := mw.APISessionAuth()(http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestAPISessionAuth_InvalidBearerTokenReturns401(t *testing.T) { + t.Parallel() + + mw, _, _ := setupMiddleware(t) + + handler := mw.APISessionAuth()(http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 8b36aa6..607fe53 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -386,8 +386,8 @@ func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler { request *http.Request, ) { // Try Bearer token first. - if m.tryBearerAuth(request) { - next.ServeHTTP(writer, request) + if authedReq, ok := m.tryBearerAuth(request); ok { + next.ServeHTTP(writer, authedReq) return } @@ -457,16 +457,19 @@ func (m *Middleware) SetupRequired() func(http.Handler) http.Handler { } // tryBearerAuth checks for a valid Bearer token in the -// Authorization header. -func (m *Middleware) tryBearerAuth(request *http.Request) bool { +// Authorization header. On success it returns a new request +// with the authenticated user set on the context. +func (m *Middleware) tryBearerAuth( + request *http.Request, +) (*http.Request, bool) { authHeader := request.Header.Get("Authorization") if !strings.HasPrefix(authHeader, bearerPrefix) { - return false + return request, false } rawToken := strings.TrimPrefix(authHeader, bearerPrefix) if rawToken == "" { - return false + return request, false } tokenHash := database.HashAPIToken(rawToken) @@ -475,15 +478,26 @@ func (m *Middleware) tryBearerAuth(request *http.Request) bool { request.Context(), m.params.Database, tokenHash, ) if err != nil || apiToken == nil { - return false + return request, false } if apiToken.IsExpired() { - return false + return request, false + } + + // Look up the user associated with the token. + user, err := models.FindUser( + request.Context(), m.params.Database, apiToken.UserID, + ) + if err != nil || user == nil { + return request, false } // Update last_used_at (best effort). _ = apiToken.TouchLastUsed(request.Context()) - return true + // Set the authenticated user on the request context. + ctx := auth.ContextWithUser(request.Context(), user) + + return request.WithContext(ctx), true } diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index 726c2c0..db654b2 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -26,6 +26,21 @@ const ( sessionUserID = "user_id" ) +// contextKeyUser is the context key for storing the authenticated user. +type contextKeyUser struct{} + +// ContextWithUser returns a new context with the given user attached. +func ContextWithUser(ctx context.Context, user *models.User) context.Context { + return context.WithValue(ctx, contextKeyUser{}, user) +} + +// UserFromContext retrieves the user from the context, if set. +func UserFromContext(ctx context.Context) *models.User { + user, _ := ctx.Value(contextKeyUser{}).(*models.User) + + return user +} + // Argon2 parameters. const ( argonTime = 1 @@ -239,6 +254,11 @@ func (svc *Service) GetCurrentUser( ctx context.Context, request *http.Request, ) (*models.User, error) { + // Check context first (set by bearer token auth). + if user := UserFromContext(ctx); user != nil { + return user, nil + } + session, sessionErr := svc.store.Get(request, sessionName) if sessionErr != nil { // Session error means no user - this is not an error condition