feat: add API token authentication (closes #87) #94
160
internal/middleware/bearer_auth_test.go
Normal file
160
internal/middleware/bearer_auth_test.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/fx"
|
||||||
|
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/config"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/database"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/globals"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/logger"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/middleware"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/models"
|
||||||
|
"git.eeqj.de/sneak/upaas/internal/service/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupMiddleware creates a Middleware with a real SQLite database for
|
||||||
|
// integration testing.
|
||||||
|
func setupMiddleware(t *testing.T) (*middleware.Middleware, *auth.Service, *database.Database) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
globals.SetAppname("upaas-test")
|
||||||
|
globals.SetVersion("test")
|
||||||
|
|
||||||
|
globalsInst, err := globals.New(fx.Lifecycle(nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loggerInst, err := logger.New(
|
||||||
|
fx.Lifecycle(nil),
|
||||||
|
logger.Params{Globals: globalsInst},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Port: 8080,
|
||||||
|
DataDir: tmpDir,
|
||||||
|
SessionSecret: "test-secret-key-at-least-32-chars!!",
|
||||||
|
}
|
||||||
|
_ = filepath.Join(tmpDir, "upaas.db")
|
||||||
|
|
||||||
|
dbInst, err := database.New(fx.Lifecycle(nil), database.Params{
|
||||||
|
Logger: loggerInst,
|
||||||
|
Config: cfg,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authSvc, err := auth.New(fx.Lifecycle(nil), auth.ServiceParams{
|
||||||
|
Logger: loggerInst,
|
||||||
|
Config: cfg,
|
||||||
|
Database: dbInst,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mw, err := middleware.New(fx.Lifecycle(nil), middleware.Params{
|
||||||
|
Logger: loggerInst,
|
||||||
|
Globals: globalsInst,
|
||||||
|
Config: cfg,
|
||||||
|
Auth: authSvc,
|
||||||
|
Database: dbInst,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return mw, authSvc, dbInst
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPISessionAuth_BearerTokenSetsUserContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mw, authSvc, dbInst := setupMiddleware(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
// Create a user.
|
||||||
|
user, err := authSvc.CreateUser(ctx, "testuser", "password123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
|
||||||
|
// Create an API token for the user.
|
||||||
|
rawToken, err := models.GenerateToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tokenHash := database.HashAPIToken(rawToken)
|
||||||
|
apiToken := models.NewAPIToken(dbInst)
|
||||||
|
apiToken.UserID = user.ID
|
||||||
|
apiToken.Name = "test-token"
|
||||||
|
apiToken.TokenHash = tokenHash
|
||||||
|
|
||||||
|
err = apiToken.Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Build a handler behind APISessionAuth that checks user context.
|
||||||
|
var gotUser *models.User
|
||||||
|
|
||||||
|
var getUserErr error
|
||||||
|
|
||||||
|
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotUser, getUserErr = authSvc.GetCurrentUser(r.Context(), r)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Make request with bearer token.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+rawToken)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NoError(t, getUserErr)
|
||||||
|
require.NotNil(t, gotUser, "GetCurrentUser should return the user for bearer auth")
|
||||||
|
assert.Equal(t, user.ID, gotUser.ID)
|
||||||
|
assert.Equal(t, "testuser", gotUser.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPISessionAuth_NoBearerTokenReturns401(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mw, _, _ := setupMiddleware(t)
|
||||||
|
|
||||||
|
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPISessionAuth_InvalidBearerTokenReturns401(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mw, _, _ := setupMiddleware(t)
|
||||||
|
|
||||||
|
handler := mw.APISessionAuth()(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||||
|
}
|
||||||
@@ -386,8 +386,8 @@ func (m *Middleware) APISessionAuth() func(http.Handler) http.Handler {
|
|||||||
request *http.Request,
|
request *http.Request,
|
||||||
) {
|
) {
|
||||||
// Try Bearer token first.
|
// Try Bearer token first.
|
||||||
if m.tryBearerAuth(request) {
|
if authedReq, ok := m.tryBearerAuth(request); ok {
|
||||||
next.ServeHTTP(writer, request)
|
next.ServeHTTP(writer, authedReq)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -457,16 +457,19 @@ func (m *Middleware) SetupRequired() func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tryBearerAuth checks for a valid Bearer token in the
|
// tryBearerAuth checks for a valid Bearer token in the
|
||||||
// Authorization header.
|
// Authorization header. On success it returns a new request
|
||||||
func (m *Middleware) tryBearerAuth(request *http.Request) bool {
|
// with the authenticated user set on the context.
|
||||||
|
func (m *Middleware) tryBearerAuth(
|
||||||
|
request *http.Request,
|
||||||
|
) (*http.Request, bool) {
|
||||||
authHeader := request.Header.Get("Authorization")
|
authHeader := request.Header.Get("Authorization")
|
||||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||||
return false
|
return request, false
|
||||||
}
|
}
|
||||||
|
|
||||||
rawToken := strings.TrimPrefix(authHeader, bearerPrefix)
|
rawToken := strings.TrimPrefix(authHeader, bearerPrefix)
|
||||||
if rawToken == "" {
|
if rawToken == "" {
|
||||||
return false
|
return request, false
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenHash := database.HashAPIToken(rawToken)
|
tokenHash := database.HashAPIToken(rawToken)
|
||||||
@@ -475,15 +478,26 @@ func (m *Middleware) tryBearerAuth(request *http.Request) bool {
|
|||||||
request.Context(), m.params.Database, tokenHash,
|
request.Context(), m.params.Database, tokenHash,
|
||||||
)
|
)
|
||||||
if err != nil || apiToken == nil {
|
if err != nil || apiToken == nil {
|
||||||
return false
|
return request, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiToken.IsExpired() {
|
if apiToken.IsExpired() {
|
||||||
return false
|
return request, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the user associated with the token.
|
||||||
|
user, err := models.FindUser(
|
||||||
|
request.Context(), m.params.Database, apiToken.UserID,
|
||||||
|
)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return request, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last_used_at (best effort).
|
// Update last_used_at (best effort).
|
||||||
_ = apiToken.TouchLastUsed(request.Context())
|
_ = apiToken.TouchLastUsed(request.Context())
|
||||||
|
|
||||||
return true
|
// Set the authenticated user on the request context.
|
||||||
|
ctx := auth.ContextWithUser(request.Context(), user)
|
||||||
|
|
||||||
|
return request.WithContext(ctx), true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,21 @@ const (
|
|||||||
sessionUserID = "user_id"
|
sessionUserID = "user_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// contextKeyUser is the context key for storing the authenticated user.
|
||||||
|
type contextKeyUser struct{}
|
||||||
|
|
||||||
|
// ContextWithUser returns a new context with the given user attached.
|
||||||
|
func ContextWithUser(ctx context.Context, user *models.User) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyUser{}, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserFromContext retrieves the user from the context, if set.
|
||||||
|
func UserFromContext(ctx context.Context) *models.User {
|
||||||
|
user, _ := ctx.Value(contextKeyUser{}).(*models.User)
|
||||||
|
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
// Argon2 parameters.
|
// Argon2 parameters.
|
||||||
const (
|
const (
|
||||||
argonTime = 1
|
argonTime = 1
|
||||||
@@ -239,6 +254,11 @@ func (svc *Service) GetCurrentUser(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *http.Request,
|
request *http.Request,
|
||||||
) (*models.User, error) {
|
) (*models.User, error) {
|
||||||
|
// Check context first (set by bearer token auth).
|
||||||
|
if user := UserFromContext(ctx); user != nil {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
session, sessionErr := svc.store.Get(request, sessionName)
|
session, sessionErr := svc.store.Get(request, sessionName)
|
||||||
if sessionErr != nil {
|
if sessionErr != nil {
|
||||||
// Session error means no user - this is not an error condition
|
// Session error means no user - this is not an error condition
|
||||||
|
|||||||
Reference in New Issue
Block a user