feat: implement hashcash proof-of-work for session creation (#63)
ステータスチェックはすべて成功しました
check / check (push) Successful in 1m2s

## Summary

Implement SHA-256-based hashcash proof-of-work for `POST /session` to prevent abuse via rapid session creation.

closes #11

## What Changed

### Server
- **New `internal/hashcash` package**: Validates hashcash stamps (format, difficulty bits, date/expiry, resource, replay prevention via in-memory spent set with TTL pruning)
- **Config**: `NEOIRC_HASHCASH_BITS` env var (default 20, set to 0 to disable)
- **`GET /api/v1/server`**: Now includes `hashcash_bits` field when > 0
- **`POST /api/v1/session`**: Validates `X-Hashcash` header when hashcash is enabled; returns HTTP 402 for missing/invalid stamps

### Clients
- **Web SPA**: Fetches `hashcash_bits` from `/server`, computes stamp using Web Crypto API (`crypto.subtle.digest`) with batched parallelism (1024 hashes/batch), shows "Computing proof-of-work..." feedback
- **CLI (`neoirc-cli`)**: `CreateSession()` auto-fetches server info and computes a valid hashcash stamp when required; new `MintHashcash()` function in the API package

### Documentation
- README updated with full hashcash documentation: stamp format, computing stamps, configuration, difficulty table
- Server info and session creation API docs updated with hashcash fields/headers
- Roadmap updated (hashcash marked as implemented)

## Stamp Format

Standard hashcash: `1:bits:YYMMDD:resource::counter`

The SHA-256 hash of the entire stamp string must have at least `bits` leading zero bits.

## Validation Rules
- Version must be `1`
- Claimed bits ≥ required bits
- Resource must match server name
- Date within 48 hours (not expired, not too far in future)
- SHA-256 hash has required leading zero bits
- Stamp not previously used (replay prevention)

## Testing
- All existing tests pass (hashcash disabled in test config with `HashcashBits: 0`)
- `docker build .` passes (lint + test + build)

<!-- session: agent:sdlc-manager:subagent:f98d712e-8a40-4013-b3d7-588cbff670f4 -->

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Co-authored-by: clawbot <clawbot@noreply.eeqj.de>
Co-authored-by: user <user@Mac.lan guest wan>
Co-authored-by: Jeffrey Paul <sneak@noreply.example.org>
Reviewed-on: #63
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
このコミットはプルリクエスト #63 でマージされました。
このコミットが含まれているのは:
2026-03-13 00:38:41 +01:00
committed by Jeffrey Paul
コミット 75cecd9803
14個のファイルの変更1795行の追加1008行の削除

333
internal/cli/api/client.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,333 @@
// Package neoircapi provides a client for the neoirc server API.
package neoircapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
const (
httpTimeout = 30 * time.Second
pollExtraTime = 5
httpErrThreshold = 400
)
var errHTTP = errors.New("HTTP error")
// Client wraps HTTP calls to the neoirc server API.
type Client struct {
BaseURL string
Token string
HTTPClient *http.Client
}
// NewClient creates a new API client.
func NewClient(baseURL string) *Client {
return &Client{ //nolint:exhaustruct // Token set after CreateSession
BaseURL: baseURL,
HTTPClient: &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: httpTimeout,
},
}
}
// CreateSession creates a new session on the server.
// If the server requires hashcash proof-of-work, it
// automatically fetches the difficulty and computes a
// valid stamp.
func (client *Client) CreateSession(
nick string,
) (*SessionResponse, error) {
// Fetch server info to check for hashcash requirement.
info, err := client.GetServerInfo()
var hashcashStamp string
if err == nil && info.HashcashBits > 0 {
resource := info.Name
if resource == "" {
resource = "neoirc"
}
hashcashStamp = MintHashcash(info.HashcashBits, resource)
}
data, err := client.do(
http.MethodPost,
"/api/v1/session",
&SessionRequest{Nick: nick, Hashcash: hashcashStamp},
)
if err != nil {
return nil, err
}
var resp SessionResponse
err = json.Unmarshal(data, &resp)
if err != nil {
return nil, fmt.Errorf("decode session: %w", err)
}
client.Token = resp.Token
return &resp, nil
}
// GetState returns the current user state.
func (client *Client) GetState() (*StateResponse, error) {
data, err := client.do(
http.MethodGet, "/api/v1/state", nil,
)
if err != nil {
return nil, err
}
var resp StateResponse
err = json.Unmarshal(data, &resp)
if err != nil {
return nil, fmt.Errorf("decode state: %w", err)
}
return &resp, nil
}
// SendMessage sends a message (any IRC command).
func (client *Client) SendMessage(msg *Message) error {
_, err := client.do(
http.MethodPost, "/api/v1/messages", msg,
)
return err
}
// PollMessages long-polls for new messages.
func (client *Client) PollMessages(
afterID int64,
timeout int,
) (*PollResult, error) {
pollClient := &http.Client{ //nolint:exhaustruct // defaults fine
Timeout: time.Duration(
timeout+pollExtraTime,
) * time.Second,
}
params := url.Values{}
if afterID > 0 {
params.Set(
"after",
strconv.FormatInt(afterID, 10),
)
}
params.Set("timeout", strconv.Itoa(timeout))
path := "/api/v1/messages?" + params.Encode()
request, err := http.NewRequestWithContext(
context.Background(),
http.MethodGet,
client.BaseURL+path,
nil,
)
if err != nil {
return nil, fmt.Errorf("new request: %w", err)
}
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
resp, err := pollClient.Do(request)
if err != nil {
return nil, fmt.Errorf("poll request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read poll body: %w", err)
}
if resp.StatusCode >= httpErrThreshold {
return nil, fmt.Errorf(
"%w %d: %s",
errHTTP, resp.StatusCode, string(data),
)
}
var wrapped MessagesResponse
err = json.Unmarshal(data, &wrapped)
if err != nil {
return nil, fmt.Errorf(
"decode messages: %w", err,
)
}
return &PollResult{
Messages: wrapped.Messages,
LastID: wrapped.LastID,
}, nil
}
// JoinChannel joins a channel.
func (client *Client) JoinChannel(channel string) error {
return client.SendMessage(
&Message{ //nolint:exhaustruct // only command+to needed
Command: irc.CmdJoin, To: channel,
},
)
}
// PartChannel leaves a channel.
func (client *Client) PartChannel(channel string) error {
return client.SendMessage(
&Message{ //nolint:exhaustruct // only command+to needed
Command: irc.CmdPart, To: channel,
},
)
}
// ListChannels returns all channels on the server.
func (client *Client) ListChannels() (
[]Channel, error,
) {
data, err := client.do(
http.MethodGet, "/api/v1/channels", nil,
)
if err != nil {
return nil, err
}
var channels []Channel
err = json.Unmarshal(data, &channels)
if err != nil {
return nil, fmt.Errorf(
"decode channels: %w", err,
)
}
return channels, nil
}
// GetMembers returns members of a channel.
func (client *Client) GetMembers(
channel string,
) ([]string, error) {
name := strings.TrimPrefix(channel, "#")
data, err := client.do(
http.MethodGet,
"/api/v1/channels/"+url.PathEscape(name)+
"/members",
nil,
)
if err != nil {
return nil, err
}
var members []string
err = json.Unmarshal(data, &members)
if err != nil {
return nil, fmt.Errorf(
"unexpected members format: %w", err,
)
}
return members, nil
}
// GetServerInfo returns server info.
func (client *Client) GetServerInfo() (
*ServerInfo, error,
) {
data, err := client.do(
http.MethodGet, "/api/v1/server", nil,
)
if err != nil {
return nil, err
}
var info ServerInfo
err = json.Unmarshal(data, &info)
if err != nil {
return nil, fmt.Errorf(
"decode server info: %w", err,
)
}
return &info, nil
}
func (client *Client) do(
method, path string,
body any,
) ([]byte, error) {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
bodyReader = bytes.NewReader(data)
}
request, err := http.NewRequestWithContext(
context.Background(),
method,
client.BaseURL+path,
bodyReader,
)
if err != nil {
return nil, fmt.Errorf("request: %w", err)
}
request.Header.Set(
"Content-Type", "application/json",
)
if client.Token != "" {
request.Header.Set(
"Authorization", "Bearer "+client.Token,
)
}
resp, err := client.HTTPClient.Do(request)
if err != nil {
return nil, fmt.Errorf("http: %w", err)
}
defer func() { _ = resp.Body.Close() }()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= httpErrThreshold {
return data, fmt.Errorf(
"%w %d: %s",
errHTTP, resp.StatusCode, string(data),
)
}
return data, nil
}

79
internal/cli/api/hashcash.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,79 @@
package neoircapi
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"math/big"
"time"
)
const (
// bitsPerByte is the number of bits in a byte.
bitsPerByte = 8
// fullByteMask is 0xFF, a mask for all bits in a byte.
fullByteMask = 0xFF
// counterSpace is the range for random counter seeds.
counterSpace = 1 << 48
)
// MintHashcash computes a hashcash stamp with the given
// difficulty (leading zero bits) and resource string.
func MintHashcash(bits int, resource string) string {
date := time.Now().UTC().Format("060102")
prefix := fmt.Sprintf(
"1:%d:%s:%s::", bits, date, resource,
)
for {
counter := randomCounter()
stamp := prefix + counter
hash := sha256.Sum256([]byte(stamp))
if hasLeadingZeroBits(hash[:], bits) {
return stamp
}
}
}
// hasLeadingZeroBits checks if hash has at least numBits
// leading zero bits.
func hasLeadingZeroBits(
hash []byte,
numBits int,
) bool {
fullBytes := numBits / bitsPerByte
remainBits := numBits % bitsPerByte
for idx := range fullBytes {
if hash[idx] != 0 {
return false
}
}
if remainBits > 0 && fullBytes < len(hash) {
mask := byte(
fullByteMask << (bitsPerByte - remainBits),
)
if hash[fullBytes]&mask != 0 {
return false
}
}
return true
}
// randomCounter generates a random hex counter string.
func randomCounter() string {
counterVal, err := rand.Int(
rand.Reader, big.NewInt(counterSpace),
)
if err != nil {
// Fallback to timestamp-based counter on error.
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(counterVal.Bytes())
}

93
internal/cli/api/types.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,93 @@
package neoircapi
import "time"
// SessionRequest is the body for POST /api/v1/session.
type SessionRequest struct {
Nick string `json:"nick"`
Hashcash string `json:"pow_token,omitempty"` //nolint:tagliatelle
}
// SessionResponse is the response from session creation.
type SessionResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
Token string `json:"token"`
}
// StateResponse is the response from GET /api/v1/state.
type StateResponse struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
Channels []string `json:"channels"`
}
// Message represents a neoirc message envelope.
type Message struct {
Command string `json:"command"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Params []string `json:"params,omitempty"`
Body any `json:"body,omitempty"`
ID string `json:"id,omitempty"`
TS string `json:"ts,omitempty"`
Meta any `json:"meta,omitempty"`
}
// BodyLines returns the body as a string slice.
func (m *Message) BodyLines() []string {
switch bodyVal := m.Body.(type) {
case []any:
lines := make([]string, 0, len(bodyVal))
for _, item := range bodyVal {
if str, ok := item.(string); ok {
lines = append(lines, str)
}
}
return lines
case []string:
return bodyVal
default:
return nil
}
}
// Channel represents a channel in the list response.
type Channel struct {
Name string `json:"name"`
Topic string `json:"topic"`
Members int `json:"members"`
CreatedAt string `json:"createdAt"`
}
// ServerInfo is the response from GET /api/v1/server.
type ServerInfo struct {
Name string `json:"name"`
MOTD string `json:"motd"`
Version string `json:"version"`
HashcashBits int `json:"hashcash_bits"` //nolint:tagliatelle
}
// MessagesResponse wraps polling results.
type MessagesResponse struct {
Messages []Message `json:"messages"`
LastID int64 `json:"lastId"`
}
// PollResult wraps the poll response including the cursor.
type PollResult struct {
Messages []Message
LastID int64
}
// ParseTS parses the message timestamp.
func (m *Message) ParseTS() time.Time {
t, err := time.Parse(time.RFC3339Nano, m.TS)
if err != nil {
return time.Now()
}
return t
}

912
internal/cli/app.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,912 @@
// Package cli implements the neoirc-cli terminal client.
package cli
import (
"fmt"
"os"
"strings"
"sync"
"time"
api "git.eeqj.de/sneak/neoirc/internal/cli/api"
"git.eeqj.de/sneak/neoirc/pkg/irc"
)
const (
splitParts = 2
pollTimeout = 15
pollRetry = 2 * time.Second
timeFormat = "15:04"
)
// App holds the application state.
type App struct {
ui *UI
client *api.Client
mu sync.Mutex
nick string
target string
connected bool
lastQID int64
stopPoll chan struct{}
}
// Run creates and runs the CLI application.
func Run() {
app := &App{ //nolint:exhaustruct
ui: NewUI(),
nick: "guest",
}
app.ui.OnInput(app.handleInput)
app.ui.SetStatus(app.nick, "", "disconnected")
app.ui.AddStatus(
"Welcome to neoirc-cli — an IRC-style client",
)
app.ui.AddStatus(
"Type [yellow]/connect <server-url>" +
"[white] to begin, " +
"or [yellow]/help[white] for commands",
)
err := app.ui.Run()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func (a *App) handleInput(text string) {
if strings.HasPrefix(text, "/") {
a.handleCommand(text)
return
}
a.mu.Lock()
target := a.target
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus(
"[red]Not connected. Use /connect <url>",
)
return
}
if target == "" {
a.ui.AddStatus(
"[red]No target. " +
"Use /join #channel or /query nick",
)
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
Command: irc.CmdPrivmsg,
To: target,
Body: []string{text},
})
if err != nil {
a.ui.AddStatus(
"[red]Send error: " + err.Error(),
)
return
}
timestamp := time.Now().Format(timeFormat)
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, nick, text,
))
}
func (a *App) handleCommand(text string) {
parts := strings.SplitN(text, " ", splitParts)
cmd := strings.ToLower(parts[0])
args := ""
if len(parts) > 1 {
args = parts[1]
}
a.dispatchCommand(cmd, args)
}
func (a *App) dispatchCommand(cmd, args string) {
switch cmd {
case "/connect":
a.cmdConnect(args)
case "/nick":
a.cmdNick(args)
case "/join":
a.cmdJoin(args)
case "/part":
a.cmdPart(args)
case "/msg":
a.cmdMsg(args)
case "/query":
a.cmdQuery(args)
case "/topic":
a.cmdTopic(args)
case "/window", "/w":
a.cmdWindow(args)
case "/quit":
a.cmdQuit()
case "/help":
a.cmdHelp()
default:
a.dispatchInfoCommand(cmd, args)
}
}
func (a *App) dispatchInfoCommand(cmd, args string) {
switch cmd {
case "/names":
a.cmdNames()
case "/list":
a.cmdList()
case "/motd":
a.cmdMotd()
case "/who":
a.cmdWho(args)
case "/whois":
a.cmdWhois(args)
default:
a.ui.AddStatus(
"[red]Unknown command: " + cmd,
)
}
}
func (a *App) cmdConnect(serverURL string) {
if serverURL == "" {
a.ui.AddStatus(
"[red]Usage: /connect <server-url>",
)
return
}
serverURL = strings.TrimRight(serverURL, "/")
a.ui.AddStatus("Connecting to " + serverURL + "...")
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
client := api.NewClient(serverURL)
resp, err := client.CreateSession(nick)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Connection failed: %v", err,
))
return
}
a.mu.Lock()
a.client = client
a.nick = resp.Nick
a.connected = true
a.lastQID = 0
a.mu.Unlock()
a.ui.AddStatus(fmt.Sprintf(
"[green]Connected! Nick: %s, Session: %d",
resp.Nick, resp.ID,
))
a.ui.SetStatus(resp.Nick, "", "connected")
a.stopPoll = make(chan struct{})
go a.pollLoop()
}
func (a *App) cmdNick(nick string) {
if nick == "" {
a.ui.AddStatus(
"[red]Usage: /nick <name>",
)
return
}
a.mu.Lock()
connected := a.connected
a.mu.Unlock()
if !connected {
a.mu.Lock()
a.nick = nick
a.mu.Unlock()
a.ui.AddStatus(
"Nick set to " + nick +
" (will be used on connect)",
)
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
Command: irc.CmdNick,
Body: []string{nick},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Nick change failed: %v", err,
))
return
}
a.mu.Lock()
a.nick = nick
target := a.target
a.mu.Unlock()
a.ui.SetStatus(nick, target, "connected")
a.ui.AddStatus("Nick changed to " + nick)
}
func (a *App) cmdJoin(channel string) {
if channel == "" {
a.ui.AddStatus(
"[red]Usage: /join #channel",
)
return
}
if !strings.HasPrefix(channel, "#") {
channel = "#" + channel
}
a.mu.Lock()
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
err := a.client.JoinChannel(channel)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Join failed: %v", err,
))
return
}
a.mu.Lock()
a.target = channel
nick := a.nick
a.mu.Unlock()
a.ui.SwitchToBuffer(channel)
a.ui.AddLine(channel,
"[yellow]*** Joined "+channel,
)
a.ui.SetStatus(nick, channel, "connected")
}
func (a *App) cmdPart(channel string) {
a.mu.Lock()
if channel == "" {
channel = a.target
}
connected := a.connected
a.mu.Unlock()
if channel == "" ||
!strings.HasPrefix(channel, "#") {
a.ui.AddStatus("[red]No channel to part")
return
}
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
err := a.client.PartChannel(channel)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Part failed: %v", err,
))
return
}
a.ui.AddLine(channel,
"[yellow]*** Left "+channel,
)
a.mu.Lock()
if a.target == channel {
a.target = ""
}
nick := a.nick
a.mu.Unlock()
a.ui.SwitchBuffer(0)
a.ui.SetStatus(nick, "", "connected")
}
func (a *App) cmdMsg(args string) {
parts := strings.SplitN(args, " ", splitParts)
if len(parts) < splitParts {
a.ui.AddStatus(
"[red]Usage: /msg <nick> <text>",
)
return
}
target, text := parts[0], parts[1]
a.mu.Lock()
connected := a.connected
nick := a.nick
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
Command: irc.CmdPrivmsg,
To: target,
Body: []string{text},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Send failed: %v", err,
))
return
}
timestamp := time.Now().Format(timeFormat)
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, nick, text,
))
}
func (a *App) cmdQuery(nick string) {
if nick == "" {
a.ui.AddStatus(
"[red]Usage: /query <nick>",
)
return
}
a.mu.Lock()
a.target = nick
myNick := a.nick
a.mu.Unlock()
a.ui.SwitchToBuffer(nick)
a.ui.SetStatus(myNick, nick, "connected")
}
func (a *App) cmdTopic(args string) {
a.mu.Lock()
target := a.target
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
if !strings.HasPrefix(target, "#") {
a.ui.AddStatus("[red]Not in a channel")
return
}
if args == "" {
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
Command: irc.CmdTopic,
To: target,
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Topic query failed: %v", err,
))
}
return
}
err := a.client.SendMessage(&api.Message{ //nolint:exhaustruct
Command: irc.CmdTopic,
To: target,
Body: []string{args},
})
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Topic set failed: %v", err,
))
}
}
func (a *App) cmdNames() {
a.mu.Lock()
target := a.target
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
if !strings.HasPrefix(target, "#") {
a.ui.AddStatus("[red]Not in a channel")
return
}
members, err := a.client.GetMembers(target)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]Names failed: %v", err,
))
return
}
a.ui.AddLine(target, fmt.Sprintf(
"[cyan]*** Members of %s: %s",
target, strings.Join(members, " "),
))
}
func (a *App) cmdList() {
a.mu.Lock()
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
channels, err := a.client.ListChannels()
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]List failed: %v", err,
))
return
}
a.ui.AddStatus("[cyan]*** Channel list:")
for _, ch := range channels {
a.ui.AddStatus(fmt.Sprintf(
" %s (%d members) %s",
ch.Name, ch.Members, ch.Topic,
))
}
a.ui.AddStatus("[cyan]*** End of channel list")
}
func (a *App) cmdMotd() {
a.mu.Lock()
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
err := a.client.SendMessage(
&api.Message{Command: irc.CmdMotd}, //nolint:exhaustruct
)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]MOTD failed: %v", err,
))
}
}
func (a *App) cmdWho(args string) {
a.mu.Lock()
connected := a.connected
target := a.target
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
channel := args
if channel == "" {
channel = target
}
if channel == "" ||
!strings.HasPrefix(channel, "#") {
a.ui.AddStatus(
"[red]Usage: /who #channel",
)
return
}
err := a.client.SendMessage(
&api.Message{ //nolint:exhaustruct
Command: irc.CmdWho, To: channel,
},
)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]WHO failed: %v", err,
))
}
}
func (a *App) cmdWhois(args string) {
a.mu.Lock()
connected := a.connected
a.mu.Unlock()
if !connected {
a.ui.AddStatus("[red]Not connected")
return
}
if args == "" {
a.ui.AddStatus(
"[red]Usage: /whois <nick>",
)
return
}
err := a.client.SendMessage(
&api.Message{ //nolint:exhaustruct
Command: irc.CmdWhois, To: args,
},
)
if err != nil {
a.ui.AddStatus(fmt.Sprintf(
"[red]WHOIS failed: %v", err,
))
}
}
func (a *App) cmdWindow(args string) {
if args == "" {
a.ui.AddStatus(
"[red]Usage: /window <number>",
)
return
}
var bufIndex int
_, _ = fmt.Sscanf(args, "%d", &bufIndex)
a.ui.SwitchBuffer(bufIndex)
a.mu.Lock()
nick := a.nick
a.mu.Unlock()
if bufIndex >= 0 && bufIndex < a.ui.BufferCount() {
buf := a.ui.buffers[bufIndex]
if buf.Name != "(status)" {
a.mu.Lock()
a.target = buf.Name
a.mu.Unlock()
a.ui.SetStatus(
nick, buf.Name, "connected",
)
} else {
a.ui.SetStatus(nick, "", "connected")
}
}
}
func (a *App) cmdQuit() {
a.mu.Lock()
if a.connected && a.client != nil {
_ = a.client.SendMessage(
&api.Message{Command: irc.CmdQuit}, //nolint:exhaustruct
)
}
if a.stopPoll != nil {
close(a.stopPoll)
}
a.mu.Unlock()
a.ui.Stop()
}
func (a *App) cmdHelp() {
help := []string{
"[cyan]*** neoirc-cli commands:",
" /connect <url> — Connect to server",
" /nick <name> — Change nickname",
" /join #channel — Join channel",
" /part [#chan] — Leave channel",
" /msg <nick> <text> — Send DM",
" /query <nick> — Open DM window",
" /topic [text] — View/set topic",
" /names — List channel members",
" /list — List channels",
" /who [#channel] — List users in channel",
" /whois <nick> — Show user info",
" /motd — Show message of the day",
" /window <n> — Switch buffer",
" /quit — Disconnect and exit",
" /help — This help",
" Plain text sends to current target.",
}
for _, line := range help {
a.ui.AddStatus(line)
}
}
// pollLoop long-polls for messages in the background.
func (a *App) pollLoop() {
for {
select {
case <-a.stopPoll:
return
default:
}
a.mu.Lock()
client := a.client
lastQID := a.lastQID
a.mu.Unlock()
if client == nil {
return
}
result, err := client.PollMessages(
lastQID, pollTimeout,
)
if err != nil {
time.Sleep(pollRetry)
continue
}
if result.LastID > 0 {
a.mu.Lock()
a.lastQID = result.LastID
a.mu.Unlock()
}
for i := range result.Messages {
a.handleServerMessage(&result.Messages[i])
}
}
}
func (a *App) handleServerMessage(msg *api.Message) {
timestamp := a.formatTS(msg)
a.mu.Lock()
myNick := a.nick
a.mu.Unlock()
switch msg.Command {
case irc.CmdPrivmsg:
a.handlePrivmsgEvent(msg, timestamp, myNick)
case irc.CmdJoin:
a.handleJoinEvent(msg, timestamp)
case irc.CmdPart:
a.handlePartEvent(msg, timestamp)
case irc.CmdQuit:
a.handleQuitEvent(msg, timestamp)
case irc.CmdNick:
a.handleNickEvent(msg, timestamp, myNick)
case irc.CmdNotice:
a.handleNoticeEvent(msg, timestamp)
case irc.CmdTopic:
a.handleTopicEvent(msg, timestamp)
default:
a.handleDefaultEvent(msg, timestamp)
}
}
func (a *App) formatTS(msg *api.Message) string {
if msg.TS != "" {
return msg.ParseTS().UTC().Format(timeFormat)
}
return time.Now().Format(timeFormat)
}
func (a *App) handlePrivmsgEvent(
msg *api.Message, timestamp, myNick string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
if msg.From == myNick {
return
}
target := msg.To
if !strings.HasPrefix(target, "#") {
target = msg.From
}
a.ui.AddLine(target, fmt.Sprintf(
"[gray]%s [green]<%s>[white] %s",
timestamp, msg.From, text,
))
}
func (a *App) handleJoinEvent(
msg *api.Message, timestamp string,
) {
if msg.To == "" {
return
}
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has joined %s",
timestamp, msg.From, msg.To,
))
}
func (a *App) handlePartEvent(
msg *api.Message, timestamp string,
) {
if msg.To == "" {
return
}
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if reason != "" {
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s (%s)",
timestamp, msg.From, msg.To, reason,
))
} else {
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [yellow]*** %s has left %s",
timestamp, msg.From, msg.To,
))
}
}
func (a *App) handleQuitEvent(
msg *api.Message, timestamp string,
) {
lines := msg.BodyLines()
reason := strings.Join(lines, " ")
if reason != "" {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit (%s)",
timestamp, msg.From, reason,
))
} else {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s has quit",
timestamp, msg.From,
))
}
}
func (a *App) handleNickEvent(
msg *api.Message, timestamp, myNick string,
) {
lines := msg.BodyLines()
newNick := ""
if len(lines) > 0 {
newNick = lines[0]
}
if msg.From == myNick && newNick != "" {
a.mu.Lock()
a.nick = newNick
target := a.target
a.mu.Unlock()
a.ui.SetStatus(newNick, target, "connected")
}
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [yellow]*** %s is now known as %s",
timestamp, msg.From, newNick,
))
}
func (a *App) handleNoticeEvent(
msg *api.Message, timestamp string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [magenta]--%s-- %s",
timestamp, msg.From, text,
))
}
func (a *App) handleTopicEvent(
msg *api.Message, timestamp string,
) {
if msg.To == "" {
return
}
lines := msg.BodyLines()
text := strings.Join(lines, " ")
a.ui.AddLine(msg.To, fmt.Sprintf(
"[gray]%s [cyan]*** %s set topic: %s",
timestamp, msg.From, text,
))
}
func (a *App) handleDefaultEvent(
msg *api.Message, timestamp string,
) {
lines := msg.BodyLines()
text := strings.Join(lines, " ")
if text != "" {
a.ui.AddStatus(fmt.Sprintf(
"[gray]%s [white][%s] %s",
timestamp, msg.Command, text,
))
}
}

294
internal/cli/ui.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,294 @@
package cli
import (
"fmt"
"strings"
"time"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
)
// Buffer holds messages for a channel/DM/status window.
type Buffer struct {
Name string
Lines []string
Unread int
}
// UI manages the terminal interface.
type UI struct {
app *tview.Application
messages *tview.TextView
statusBar *tview.TextView
input *tview.InputField
layout *tview.Flex
buffers []*Buffer
currentBuffer int
onInput func(string)
}
// NewUI creates the tview-based IRC-like UI.
func NewUI() *UI {
ui := &UI{ //nolint:exhaustruct,varnamelen // fields set below; ui is idiomatic
app: tview.NewApplication(),
buffers: []*Buffer{
{Name: "(status)", Lines: nil, Unread: 0},
},
}
ui.initMessages()
ui.initStatusBar()
ui.initInput()
ui.initKeyCapture()
ui.layout = tview.NewFlex().
SetDirection(tview.FlexRow).
AddItem(ui.messages, 0, 1, false).
AddItem(ui.statusBar, 1, 0, false).
AddItem(ui.input, 1, 0, true)
ui.app.SetRoot(ui.layout, true)
ui.app.SetFocus(ui.input)
return ui
}
// Run starts the UI event loop (blocks).
func (ui *UI) Run() error {
err := ui.app.Run()
if err != nil {
return fmt.Errorf("run ui: %w", err)
}
return nil
}
// Stop stops the UI.
func (ui *UI) Stop() {
ui.app.Stop()
}
// OnInput sets the callback for user input.
func (ui *UI) OnInput(fn func(string)) {
ui.onInput = fn
}
// AddLine adds a line to the specified buffer.
func (ui *UI) AddLine(bufferName, line string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(bufferName)
buf.Lines = append(buf.Lines, line)
cur := ui.buffers[ui.currentBuffer]
if cur != buf {
buf.Unread++
ui.refreshStatusBar()
}
if cur == buf {
_, _ = fmt.Fprintln(ui.messages, line)
}
})
}
// AddStatus adds a line to the status buffer.
func (ui *UI) AddStatus(line string) {
ts := time.Now().Format("15:04")
ui.AddLine(
"(status)",
"[gray]"+ts+"[white] "+line,
)
}
// SwitchBuffer switches to the buffer at index n.
func (ui *UI) SwitchBuffer(bufIndex int) {
ui.app.QueueUpdateDraw(func() {
if bufIndex < 0 || bufIndex >= len(ui.buffers) {
return
}
ui.currentBuffer = bufIndex
buf := ui.buffers[bufIndex]
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
_, _ = fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatusBar()
})
}
// SwitchToBuffer switches to named buffer, creating if
// needed.
func (ui *UI) SwitchToBuffer(name string) {
ui.app.QueueUpdateDraw(func() {
buf := ui.getOrCreateBuffer(name)
for i, b := range ui.buffers {
if b == buf {
ui.currentBuffer = i
break
}
}
buf.Unread = 0
ui.messages.Clear()
for _, line := range buf.Lines {
_, _ = fmt.Fprintln(ui.messages, line)
}
ui.messages.ScrollToEnd()
ui.refreshStatusBar()
})
}
// SetStatus updates the status bar text.
func (ui *UI) SetStatus(
nick, target, connStatus string,
) {
ui.app.QueueUpdateDraw(func() {
ui.renderStatusBar(nick, target, connStatus)
})
}
// BufferCount returns the number of buffers.
func (ui *UI) BufferCount() int {
return len(ui.buffers)
}
// BufferIndex returns the index of a named buffer.
func (ui *UI) BufferIndex(name string) int {
for i, buf := range ui.buffers {
if buf.Name == name {
return i
}
}
return -1
}
func (ui *UI) initMessages() {
ui.messages = tview.NewTextView().
SetDynamicColors(true).
SetScrollable(true).
SetWordWrap(true).
SetChangedFunc(func() {
ui.app.Draw()
})
ui.messages.SetBorder(false)
}
func (ui *UI) initStatusBar() {
ui.statusBar = tview.NewTextView().
SetDynamicColors(true)
ui.statusBar.SetBackgroundColor(tcell.ColorNavy)
ui.statusBar.SetTextColor(tcell.ColorWhite)
}
func (ui *UI) initInput() {
ui.input = tview.NewInputField().
SetFieldBackgroundColor(tcell.ColorBlack).
SetFieldTextColor(tcell.ColorWhite)
ui.input.SetDoneFunc(func(key tcell.Key) {
if key != tcell.KeyEnter {
return
}
text := ui.input.GetText()
if text == "" {
return
}
ui.input.SetText("")
if ui.onInput != nil {
ui.onInput(text)
}
})
}
func (ui *UI) initKeyCapture() {
ui.app.SetInputCapture(
func(event *tcell.EventKey) *tcell.EventKey {
if event.Modifiers()&tcell.ModAlt == 0 {
return event
}
r := event.Rune()
if r >= '0' && r <= '9' {
idx := int(r - '0')
ui.SwitchBuffer(idx)
return nil
}
return event
},
)
}
func (ui *UI) refreshStatusBar() {
// Placeholder; full refresh needs nick/target context.
}
func (ui *UI) renderStatusBar(
nick, target, connStatus string,
) {
var unreadParts []string
for i, buf := range ui.buffers {
if buf.Unread > 0 {
unreadParts = append(unreadParts,
fmt.Sprintf(
"%d:%s(%d)",
i, buf.Name, buf.Unread,
),
)
}
}
unread := ""
if len(unreadParts) > 0 {
unread = " [Act: " +
strings.Join(unreadParts, ",") + "]"
}
bufInfo := fmt.Sprintf(
"[%d:%s]",
ui.currentBuffer,
ui.buffers[ui.currentBuffer].Name,
)
ui.statusBar.Clear()
_, _ = fmt.Fprintf(ui.statusBar,
" [%s] %s %s %s%s",
connStatus, nick, bufInfo, target, unread,
)
}
func (ui *UI) getOrCreateBuffer(name string) *Buffer {
for _, buf := range ui.buffers {
if buf.Name == name {
return buf
}
}
buf := &Buffer{Name: name, Lines: nil, Unread: 0}
ui.buffers = append(ui.buffers, buf)
return buf
}

ファイルの表示

@@ -45,6 +45,7 @@ type Config struct {
ServerName string
FederationKey string
SessionIdleTimeout string
HashcashBits int
params *Params
log *slog.Logger
}
@@ -76,6 +77,7 @@ func New(
viper.SetDefault("SERVER_NAME", "")
viper.SetDefault("FEDERATION_KEY", "")
viper.SetDefault("SESSION_IDLE_TIMEOUT", "720h")
viper.SetDefault("NEOIRC_HASHCASH_BITS", "20")
err := viper.ReadInConfig()
if err != nil {
@@ -101,6 +103,7 @@ func New(
ServerName: viper.GetString("SERVER_NAME"),
FederationKey: viper.GetString("FEDERATION_KEY"),
SessionIdleTimeout: viper.GetString("SESSION_IDLE_TIMEOUT"),
HashcashBits: viper.GetInt("NEOIRC_HASHCASH_BITS"),
log: log,
params: &params,
}

ファイルの表示

@@ -146,7 +146,8 @@ func (hdlr *Handlers) handleCreateSession(
request *http.Request,
) {
type createRequest struct {
Nick string `json:"nick"`
Nick string `json:"nick"`
Hashcash string `json:"pow_token,omitempty"` //nolint:tagliatelle
}
var payload createRequest
@@ -162,6 +163,32 @@ func (hdlr *Handlers) handleCreateSession(
return
}
// Validate hashcash proof-of-work if configured.
if hdlr.params.Config.HashcashBits > 0 {
if payload.Hashcash == "" {
hdlr.respondError(
writer, request,
"hashcash proof-of-work required",
http.StatusPaymentRequired,
)
return
}
err = hdlr.hashcashVal.Validate(
payload.Hashcash, hdlr.params.Config.HashcashBits,
)
if err != nil {
hdlr.respondError(
writer, request,
"invalid hashcash stamp: "+err.Error(),
http.StatusPaymentRequired,
)
return
}
}
payload.Nick = strings.TrimSpace(payload.Nick)
if !validNickRe.MatchString(payload.Nick) {
@@ -2392,11 +2419,19 @@ func (hdlr *Handlers) HandleServerInfo() http.HandlerFunc {
return
}
hdlr.respondJSON(writer, request, map[string]any{
resp := map[string]any{
"name": hdlr.params.Config.ServerName,
"version": hdlr.params.Globals.Version,
"motd": hdlr.params.Config.MOTD,
"users": users,
}, http.StatusOK)
}
if hdlr.params.Config.HashcashBits > 0 {
resp["hashcash_bits"] = hdlr.params.Config.HashcashBits
}
hdlr.respondJSON(
writer, request, resp, http.StatusOK,
)
}
}

ファイルの表示

@@ -85,6 +85,7 @@ func newTestServer(
cfg.DBURL = dbURL
cfg.Port = 0
cfg.HashcashBits = 0
return cfg, nil
},

ファイルの表示

@@ -13,6 +13,7 @@ import (
"git.eeqj.de/sneak/neoirc/internal/config"
"git.eeqj.de/sneak/neoirc/internal/db"
"git.eeqj.de/sneak/neoirc/internal/globals"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
"git.eeqj.de/sneak/neoirc/internal/healthcheck"
"git.eeqj.de/sneak/neoirc/internal/logger"
"go.uber.org/fx"
@@ -39,6 +40,7 @@ type Handlers struct {
log *slog.Logger
hc *healthcheck.Healthcheck
broker *broker.Broker
hashcashVal *hashcash.Validator
cancelCleanup context.CancelFunc
}
@@ -47,11 +49,17 @@ func New(
lifecycle fx.Lifecycle,
params Params,
) (*Handlers, error) {
resource := params.Config.ServerName
if resource == "" {
resource = "neoirc"
}
hdlr := &Handlers{ //nolint:exhaustruct // cancelCleanup set in startCleanup
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: broker.New(),
params: &params,
log: params.Logger.Get(),
hc: params.Healthcheck,
broker: broker.New(),
hashcashVal: hashcash.NewValidator(resource),
}
lifecycle.Append(fx.Hook{

277
internal/hashcash/hashcash.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,277 @@
// Package hashcash implements SHA-256-based hashcash
// proof-of-work validation for abuse prevention.
//
// Stamp format: 1:bits:YYMMDD:resource::counter.
//
// The SHA-256 hash of the entire stamp string must have
// at least `bits` leading zero bits.
package hashcash
import (
"crypto/sha256"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
)
const (
// stampVersion is the only supported hashcash version.
stampVersion = "1"
// stampFields is the number of fields in a stamp.
stampFields = 6
// maxStampAge is how old a stamp can be before
// rejection.
maxStampAge = 48 * time.Hour
// maxFutureSkew allows stamps slightly in the future.
maxFutureSkew = 1 * time.Hour
// pruneInterval controls how often expired stamps are
// removed from the spent set.
pruneInterval = 10 * time.Minute
// dateFormatShort is the YYMMDD date layout.
dateFormatShort = "060102"
// dateFormatLong is the YYMMDDHHMMSS date layout.
dateFormatLong = "060102150405"
// dateShortLen is the length of YYMMDD.
dateShortLen = 6
// dateLongLen is the length of YYMMDDHHMMSS.
dateLongLen = 12
// bitsPerByte is the number of bits in a byte.
bitsPerByte = 8
// fullByteMask is 0xFF, a mask for all bits in a byte.
fullByteMask = 0xFF
)
var (
errInvalidFields = errors.New("invalid stamp field count")
errBadVersion = errors.New("unsupported stamp version")
errInsufficientBits = errors.New("insufficient difficulty")
errWrongResource = errors.New("wrong resource")
errStampExpired = errors.New("stamp expired")
errStampFuture = errors.New("stamp date in future")
errProofFailed = errors.New("proof-of-work failed")
errStampReused = errors.New("stamp already used")
errBadDateFormat = errors.New("unrecognized date format")
)
// Validator checks hashcash stamps for validity and
// prevents replay attacks via an in-memory spent set.
type Validator struct {
resource string
mu sync.Mutex
spent map[string]time.Time
}
// NewValidator creates a Validator for the given resource.
func NewValidator(resource string) *Validator {
validator := &Validator{
resource: resource,
mu: sync.Mutex{},
spent: make(map[string]time.Time),
}
go validator.pruneLoop()
return validator
}
// Validate checks a hashcash stamp. It returns nil if the
// stamp is valid and has not been seen before.
func (v *Validator) Validate(
stamp string,
requiredBits int,
) error {
if requiredBits <= 0 {
return nil
}
parts := strings.Split(stamp, ":")
if len(parts) != stampFields {
return fmt.Errorf(
"%w: expected %d, got %d",
errInvalidFields, stampFields, len(parts),
)
}
version := parts[0]
bitsStr := parts[1]
dateStr := parts[2]
resource := parts[3]
if err := v.validateHeader(
version, bitsStr, resource, requiredBits,
); err != nil {
return err
}
stampTime, err := parseStampDate(dateStr)
if err != nil {
return err
}
if err := validateTime(stampTime); err != nil {
return err
}
if err := validateProof(
stamp, requiredBits,
); err != nil {
return err
}
return v.checkAndRecordStamp(stamp, stampTime)
}
func (v *Validator) validateHeader(
version, bitsStr, resource string,
requiredBits int,
) error {
if version != stampVersion {
return fmt.Errorf(
"%w: %s", errBadVersion, version,
)
}
claimedBits, err := strconv.Atoi(bitsStr)
if err != nil || claimedBits < requiredBits {
return fmt.Errorf(
"%w: need %d bits",
errInsufficientBits, requiredBits,
)
}
if resource != v.resource {
return fmt.Errorf(
"%w: got %q, want %q",
errWrongResource, resource, v.resource,
)
}
return nil
}
func validateTime(stampTime time.Time) error {
now := time.Now()
if now.Sub(stampTime) > maxStampAge {
return errStampExpired
}
if stampTime.Sub(now) > maxFutureSkew {
return errStampFuture
}
return nil
}
func validateProof(stamp string, requiredBits int) error {
hash := sha256.Sum256([]byte(stamp))
if !hasLeadingZeroBits(hash[:], requiredBits) {
return fmt.Errorf(
"%w: need %d leading zero bits",
errProofFailed, requiredBits,
)
}
return nil
}
func (v *Validator) checkAndRecordStamp(
stamp string,
stampTime time.Time,
) error {
v.mu.Lock()
defer v.mu.Unlock()
if _, ok := v.spent[stamp]; ok {
return errStampReused
}
v.spent[stamp] = stampTime
return nil
}
// hasLeadingZeroBits checks if the hash has at least n
// leading zero bits.
func hasLeadingZeroBits(hash []byte, numBits int) bool {
fullBytes := numBits / bitsPerByte
remainBits := numBits % bitsPerByte
for idx := range fullBytes {
if hash[idx] != 0 {
return false
}
}
if remainBits > 0 && fullBytes < len(hash) {
mask := byte(
fullByteMask << (bitsPerByte - remainBits),
)
if hash[fullBytes]&mask != 0 {
return false
}
}
return true
}
// parseStampDate parses a hashcash date stamp.
// Supports YYMMDD and YYMMDDHHMMSS formats.
func parseStampDate(dateStr string) (time.Time, error) {
switch len(dateStr) {
case dateShortLen:
parsed, err := time.Parse(
dateFormatShort, dateStr,
)
if err != nil {
return time.Time{}, fmt.Errorf(
"parse date: %w", err,
)
}
return parsed, nil
case dateLongLen:
parsed, err := time.Parse(
dateFormatLong, dateStr,
)
if err != nil {
return time.Time{}, fmt.Errorf(
"parse date: %w", err,
)
}
return parsed, nil
default:
return time.Time{}, fmt.Errorf(
"%w: %q", errBadDateFormat, dateStr,
)
}
}
// pruneLoop periodically removes expired stamps from the
// spent set.
func (v *Validator) pruneLoop() {
ticker := time.NewTicker(pruneInterval)
defer ticker.Stop()
for range ticker.C {
v.prune()
}
}
func (v *Validator) prune() {
cutoff := time.Now().Add(-maxStampAge)
v.mu.Lock()
defer v.mu.Unlock()
for stamp, stampTime := range v.spent {
if stampTime.Before(cutoff) {
delete(v.spent, stamp)
}
}
}

261
internal/hashcash/hashcash_test.go ノーマルファイル
ファイルの表示

@@ -0,0 +1,261 @@
package hashcash_test
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"math/big"
"testing"
"time"
"git.eeqj.de/sneak/neoirc/internal/hashcash"
)
const testBits = 2
// mintStampWithDate creates a valid hashcash stamp using
// the given date string.
func mintStampWithDate(
tb testing.TB,
bits int,
resource string,
date string,
) string {
tb.Helper()
prefix := fmt.Sprintf(
"1:%d:%s:%s::", bits, date, resource,
)
for {
counterVal, err := rand.Int(
rand.Reader, big.NewInt(1<<48),
)
if err != nil {
tb.Fatalf("random counter: %v", err)
}
stamp := prefix + hex.EncodeToString(
counterVal.Bytes(),
)
hash := sha256.Sum256([]byte(stamp))
if hasLeadingZeroBits(hash[:], bits) {
return stamp
}
}
}
// hasLeadingZeroBits checks if hash has at least numBits
// leading zero bits. Duplicated here for test minting.
func hasLeadingZeroBits(
hash []byte,
numBits int,
) bool {
fullBytes := numBits / 8
remainBits := numBits % 8
for idx := range fullBytes {
if hash[idx] != 0 {
return false
}
}
if remainBits > 0 && fullBytes < len(hash) {
mask := byte(0xFF << (8 - remainBits))
if hash[fullBytes]&mask != 0 {
return false
}
}
return true
}
func todayDate() string {
return time.Now().UTC().Format("060102")
}
func TestMintAndValidate(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
stamp := mintStampWithDate(
t, testBits, "test-resource", todayDate(),
)
err := validator.Validate(stamp, testBits)
if err != nil {
t.Fatalf("valid stamp rejected: %v", err)
}
}
func TestReplayDetection(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
stamp := mintStampWithDate(
t, testBits, "test-resource", todayDate(),
)
err := validator.Validate(stamp, testBits)
if err != nil {
t.Fatalf("first use failed: %v", err)
}
err = validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("replay not detected")
}
}
func TestResourceMismatch(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("correct-resource")
stamp := mintStampWithDate(
t, testBits, "wrong-resource", todayDate(),
)
err := validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("expected resource mismatch error")
}
}
func TestInvalidStampFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
err := validator.Validate(
"not:a:valid:stamp", testBits,
)
if err == nil {
t.Fatal("expected error for bad format")
}
}
func TestBadVersion(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
stamp := fmt.Sprintf(
"2:%d:%s:%s::abc123",
testBits, todayDate(), "test-resource",
)
err := validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("expected bad version error")
}
}
func TestInsufficientDifficulty(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
// Claimed bits=1, but we require testBits=2.
stamp := fmt.Sprintf(
"1:1:%s:%s::counter",
todayDate(), "test-resource",
)
err := validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("expected insufficient bits error")
}
}
func TestExpiredStamp(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
oldDate := time.Now().Add(-72 * time.Hour).
UTC().Format("060102")
stamp := mintStampWithDate(
t, testBits, "test-resource", oldDate,
)
err := validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("expected expired stamp error")
}
}
func TestZeroBitsSkipsValidation(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
err := validator.Validate("garbage", 0)
if err != nil {
t.Fatalf("zero bits should skip: %v", err)
}
}
func TestLongDateFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
longDate := time.Now().UTC().Format("060102150405")
stamp := mintStampWithDate(
t, testBits, "test-resource", longDate,
)
err := validator.Validate(stamp, testBits)
if err != nil {
t.Fatalf("long date stamp rejected: %v", err)
}
}
func TestBadDateFormat(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
stamp := fmt.Sprintf(
"1:%d:BADDATE:%s::counter",
testBits, "test-resource",
)
err := validator.Validate(stamp, testBits)
if err == nil {
t.Fatal("expected bad date error")
}
}
func TestMultipleUniqueStamps(t *testing.T) {
t.Parallel()
validator := hashcash.NewValidator("test-resource")
for range 5 {
stamp := mintStampWithDate(
t, testBits, "test-resource", todayDate(),
)
err := validator.Validate(stamp, testBits)
if err != nil {
t.Fatalf("unique stamp rejected: %v", err)
}
}
}
func TestHigherBitsStillValid(t *testing.T) {
t.Parallel()
// Mint with bits=4 but validate requiring only 2.
validator := hashcash.NewValidator("test-resource")
stamp := mintStampWithDate(
t, 4, "test-resource", todayDate(),
)
err := validator.Validate(stamp, testBits)
if err != nil {
t.Fatalf(
"higher-difficulty stamp rejected: %v",
err,
)
}
}