diff --git a/BACKUP_CONFIG_GUIDE.md b/BACKUP_CONFIG_GUIDE.md deleted file mode 100644 index e0549f1f9..000000000 --- a/BACKUP_CONFIG_GUIDE.md +++ /dev/null @@ -1,215 +0,0 @@ -# 🚀 OPTIMIERTE BACKUP CONFIGURATION - PTERODACTYL WINGS - -## 🎯 EMPFOHLENE PRODUCTION CONFIG - -```yaml -# config.yml - Optimierte Backup-Einstellungen - -system: - # Standard System-Einstellungen... - data: "/var/lib/pterodactyl/volumes" - - # OPTIMIERTE BACKUP CONFIGURATION - backups: - # I/O Write-Limit in MiB/s (0 = unlimited) - write_limit: 0 - - # ZSTD für bessere Performance (EMPFOHLEN!) - format: "zstd" - - # Compression Level - compression_level: "best_speed" # Oder "best_compression" für mehr Platz - - # SMART SFTP SECURITY (Brute-Force Protection) - sftp: - bind_address: "0.0.0.0" - bind_port: 2022 - read_only: false - - # INTELLIGENT SECURITY - security: - enabled: true - - thresholds: - attempts_per_minute: 6 # 6+ = 5min block - attempts_per_hour: 15 # Eskalation - attempts_per_day: 50 # Reputation impact - - blocking: - base_block_minutes: 5 # Smart: 5min start - escalation_factor: 2.0 # 2x bei Wiederholung - max_block_hours: 24 # Max 24h block - decay_factor: 0.8 # Forgiveness over time - - reputation: - enabled: true - memory_days: 7 - block_threshold: -50 - good_behavior_bonus: 5 - bad_behavior_penalty: -10 -``` - -## ⚡ PERFORMANCE VERGLEICH - -### ZSTD vs GZIP Backups: -```yaml -# ALTE CONFIG (langram) -format: "gzip" # ❌ Langsam, alte Technologie -compression_level: "best_compression" # ❌ Sehr langsam - -# NEUE CONFIG (optimal) -format: "zstd" # ✅ 3-5x schneller -compression_level: "best_speed" # ✅ Balance Speed/Size -``` - -### REAL-WORLD PERFORMANCE: -- **10GB Server Backup:** - - GZIP: ~45 Minuten - - ZSTD: ~15 Minuten ⚡ **3x SCHNELLER** - -- **Backup Sizes:** - - GZIP: ~3.2GB - - ZSTD: ~2.8GB ⚡ **12% KLEINER** - -## 🎛️ VERSCHIEDENE PERFORMANCE PROFILES - -### 1. MAXIMUM SPEED (Empfohlen für große Server) -```yaml -backups: - format: "zstd" - compression_level: "best_speed" - write_limit: 0 # Unlimited I/O -``` -**Use Case:** Große Game-Server, wo Backup-Zeit kritisch ist - -### 2. BALANCED (Empfohlen für die meisten) -```yaml -backups: - format: "zstd" - compression_level: "best_speed" - write_limit: 100 # 100 MiB/s limit -``` -**Use Case:** Standard Production-Setup - -### 3. MAXIMUM COMPRESSION (für limitierten Speicher) -```yaml -backups: - format: "zstd" - compression_level: "best_compression" - write_limit: 50 # Langsamer I/O -``` -**Use Case:** Wenn Speicherplatz sehr limitiert ist - -### 4. LEGACY COMPATIBILITY (nur wenn nötig) -```yaml -backups: - format: "gzip" # Nur für Backward-Compatibility - compression_level: "best_speed" -``` -**Use Case:** Wenn alte Restore-Tools ZSTD nicht unterstützen - -## 🔧 MIGRATION STRATEGY - -### Phase 1: Vorbereitung (Woche 1) -```yaml -# Erstmal sicher bleiben -format: "gzip" -compression_level: "best_speed" -``` - -### Phase 2: ZSTD Rollout (Woche 2) -```yaml -# Schrittweise auf ZSTD umstellen -format: "zstd" -compression_level: "best_speed" -``` - -### Phase 3: Optimierung (Woche 3+) -```yaml -# Performance nach Bedarf anpassen -format: "zstd" -compression_level: "best_speed" # oder "best_compression" -write_limit: 0 # je nach I/O-Kapazität -``` - -## 🛡️ SECURITY HARDENING - -### Für High-Security Environments: -```yaml -sftp: - security: - enabled: true - thresholds: - attempts_per_minute: 3 # Stricter: nur 3 Versuche - attempts_per_hour: 8 # Weniger Toleranz - blocking: - base_block_minutes: 10 # Längere initiale Blocks - max_block_hours: 48 # Bis zu 2 Tage Block - reputation: - block_threshold: -30 # Schneller blocken - bad_behavior_penalty: -15 # Härtere Bestrafung -``` - -### Für Development/Testing: -```yaml -sftp: - security: - enabled: true - thresholds: - attempts_per_minute: 10 # Mehr Toleranz - attempts_per_hour: 25 - blocking: - base_block_minutes: 2 # Kurze Blocks - max_block_hours: 4 # Maximal 4h - reputation: - block_threshold: -70 # Mehr Geduld - decay_factor: 0.9 # Schneller vergeben -``` - -## 📊 MONITORING CONFIG - -```yaml -# Zusätzlich in deiner config.yml für besseres Logging: -debug: false # true nur für Development -log_level: "info" # "debug" für detaillierte Security-Logs - -# Environment-specific: -api: - host: "0.0.0.0" - port: 8080 - ssl: - enabled: true # HTTPS für Production - cert: "/etc/ssl/certs/wings.crt" - key: "/etc/ssl/private/wings.key" -``` - -## 🎯 FINAL RECOMMENDATION - -**Für Production (empfohlen):** -```yaml -system: - backups: - format: "zstd" - compression_level: "best_speed" - write_limit: 0 - - sftp: - security: - enabled: true - thresholds: - attempts_per_minute: 6 - attempts_per_hour: 15 - blocking: - base_block_minutes: 5 - escalation_factor: 2.0 - max_block_hours: 24 -``` - -**Benefits:** -- ⚡ **3x schnellere Backups** -- 💾 **12% kleinere Files** -- 🛡️ **Intelligent Brute-Force Protection** -- 🔄 **100% Backward Compatibility** -- 📈 **Smart Escalation System** - -**Diese Config macht deine Wings-Installation maximal performant und sicher! 🚀** \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a918186c6..58c781b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v1.13.1 +### Security +* Backup restore downloads are now hardened against SSRF: remote restore links are validated and may not resolve to private, loopback, link-local or other blocked address ranges unless permitted via the new `restore_host_allowlist` config option. +* Backup identifiers are now strictly validated as UUIDs, preventing path traversal in backup paths and operations. +* Hardening fixes merged from upstream security advisories (filesystem quota, disk space and SFTP handling). + +### Added +* `system.backups.restore_host_allowlist` to allow backup restore downloads to reach otherwise blocked private/internal destinations. + +### Fixed +* Improved quota, server, registry, filesystem and backup handling (upstream v1.13.1). + ## v1.13.0 ### Fixed * Empty folders in uploaded artifacts are now preserved ([#325](https://github.com/pterodactyl/wings/pull/325)) diff --git a/COMPRESSION_UPGRADE.md b/COMPRESSION_UPGRADE.md deleted file mode 100644 index 1375663f5..000000000 --- a/COMPRESSION_UPGRADE.md +++ /dev/null @@ -1,168 +0,0 @@ -# ZSTD Compression Implementation - -## Overview - -This implementation adds support for ZSTD compression in Pterodactyl Wings backup system while maintaining 100% backward compatibility with existing GZIP backups. - -## Changes Made - -### 1. Configuration Support -- Added `format` field to `config.yaml` under `system.backups` -- Supported values: `"gzip"` (default), `"zstd"`, `"none"` -- Maintains backward compatibility - defaults to `"gzip"` - -### 2. Archive Creation (`server/filesystem/archive.go`) -- Refactored compression logic into pluggable system -- Added `createCompressor()` method that chooses format based on config -- Added `createZstdWriter()` with adaptive threading (2-4 threads max) -- Added `createGzipWriter()` with existing logic preserved -- Proper compression level mapping from config - -### 3. Archive Restoration (`server/filesystem/archive_restore.go`) -- Added automatic format detection via magic bytes -- ZSTD magic: `0x28B52FFD` -- GZIP magic: `0x1F8B` -- Graceful fallback to GZIP for unknown formats - -### 4. Backup Integration (`server/backup.go`) -- Updated `RestoreBackup()` to auto-detect compression format -- Seamless decompression without API changes -- Full backward compatibility with existing backups - -### 5. S3 Backup Fixes (`server/backup/backup_s3.go`) -- Fixed critical bug: S3 backups no longer self-delete on failure -- Fixed context cancellation bug in `generateRemoteRequest()` -- Proper success/failure handling - -### 6. Path Generation (`server/backup/backup.go`) -- Updated `Path()` method to generate appropriate file extensions: - - ZSTD: `.tar.zst` - - GZIP: `.tar.gz` - - None: `.tar` - -## Performance Benefits - -### ZSTD vs GZIP Comparison: -- **Compression Speed**: 3-5x faster than GZIP -- **Decompression Speed**: 2-3x faster than GZIP -- **Compression Ratio**: 10-20% better than GZIP -- **Memory Usage**: Lower with `LowerEncoderMem` option -- **Threading**: 2-4 threads (adaptive based on CPU count) - -## Configuration - -```yaml -system: - backups: - # Existing options remain unchanged - write_limit: 0 - compression_level: "best_speed" - - # NEW: Compression format - format: "zstd" # Options: "gzip", "zstd", "none" -``` - -## Backward Compatibility Guarantees - -### 1. Existing Backups -- All existing `.tar.gz` backups remain fully restorable -- Auto-detection handles format seamlessly -- No database migrations required -- No panel changes required - -### 2. API Compatibility -- All backup APIs remain identical -- JSON responses unchanged -- WebSocket events unchanged -- File structure unchanged - -### 3. Gradual Migration -- Default remains `"gzip"` for safety -- Can be enabled per-installation basis -- Old and new backups can coexist -- Instant rollback by changing config - -## Testing - -### Unit Tests -- Format detection tests (`archive_test.go`) -- GZIP decompression tests -- ZSTD decompression tests -- All existing filesystem tests still pass - -### Integration Tests -- Tested with real backup/restore cycles -- Verified S3 upload compatibility -- Confirmed file extension handling - -## Deployment Strategy - -### Phase 1: Deploy (Week 1) -```yaml -format: "gzip" # No change in behavior -``` - -### Phase 2: Enable ZSTD (Week 2-3) -```yaml -format: "zstd" # New backups use ZSTD -``` - -### Phase 3: Monitor & Scale (Week 4+) -- Monitor performance metrics -- Validate backup integrity -- Scale to all installations - -## Monitoring - -### Key Metrics to Track: -- Backup creation time (expect 50-70% reduction) -- Backup file sizes (expect 10-20% reduction) -- Memory usage during backups -- CPU utilization (should be similar with threading) -- S3 upload times (faster due to smaller files) - -### Success Criteria: -- Zero backup failures -- Faster backup/restore times -- Smaller storage usage -- 100% restore success rate - -## Rollback Plan - -If issues arise, instant rollback: - -```yaml -format: "gzip" # Back to original behavior -``` - -- No code changes needed -- All new backups will use GZIP -- Existing ZSTD backups remain restorable -- Zero downtime rollback - -## Technical Notes - -### Thread Management -- Maximum 4 threads for ZSTD compression -- Adaptive scaling: 2 threads (1-4 CPUs), 3 threads (5-8 CPUs), 4 threads (9+ CPUs) -- Memory-efficient encoding with `LowerEncoderMem` - -### File Extensions -- Extensions now reflect actual compression format -- Backward compatibility maintained for existing files -- Future-proof for additional formats - -### Error Handling -- Compression failures fall back to GZIP -- Decompression auto-detects format -- Graceful handling of corrupted files - -## Future Enhancements - -### Possible Additions: -- Dictionary compression for game servers -- LZ4 support for ultra-fast compression -- Backup format conversion tools -- Compression benchmarking tools - -This implementation provides significant performance improvements while maintaining production stability and backward compatibility. \ No newline at end of file diff --git a/FORK_CHANGES.md b/FORK_CHANGES.md new file mode 100644 index 000000000..69190136f --- /dev/null +++ b/FORK_CHANGES.md @@ -0,0 +1,144 @@ +# EmeraldHost Wings — Divergences from Upstream Pterodactyl Wings + +This file tracks **which changes are our own** (EmeraldHost-specific) versus upstream +[`pterodactyl/wings`](https://github.com/pterodactyl/wings). Use it during upgrades so +our customizations are **not accidentally reverted** when pulling in upstream changes. + +- **Baseline for this comparison:** upstream tag **`v1.13.1`** (`e771816`) +- **Last reviewed:** 2026-06-30 +- **Module path:** this fork is `github.com/Rene-Roscher/wings` (upstream is + `github.com/pterodactyl/wings`). Version is injected at build time via ldflags + (`-X .../system.Version=`); `system/const.go` stays `develop` and is **not** a divergence. +- **How to regenerate the picture:** + ```bash + git fetch upstream --tags + git diff --name-status HEAD # what differs + git diff HEAD -- # inspect a single file + ``` + Diff direction: in `git diff HEAD`, a `+` line is **in our fork**, + a `-` line is **upstream**. + +> ⚠️ The biggest divergence by far is the **backup subsystem** (~6 000+ lines): an +> operation registry/queue, retry, WebSocket progress and a +> heavily customized restore path. Upstream merges in `server/backup*`, `router/router_server_backup.go`, +> `server/server.go` and `sftp/server.go` will almost always conflict — resolve by **keeping ours** +> and grafting upstream's functional/security changes on top (that is exactly how v1.13.1 was merged). + +--- + +## 1. Our own changes — PRESERVE on every upgrade + +### 1.1 Module rename (mechanical, but must stay) + +| Path | What | Note | +|------|------|------| +| `go.mod` | `module github.com/pterodactyl/wings` → `github.com/Rene-Roscher/wings` | Root of the rename; every Go import of the module changes accordingly. | +| `Dockerfile`, `Makefile`, `.github/workflows/{push,release}.yaml`, `wings.go` | ldflags / `SRC_PATH` / import path use the `Rene-Roscher` module | Required so `system.Version` is injected and builds resolve. | +| ~every `*.go` file | `github.com/pterodactyl/wings/...` → `github.com/Rene-Roscher/wings/...` | **Incidental noise** — appears as conflicts on most merges but carries no behavior. Always "keep ours (Rene-Roscher)". | + +### 1.2 Backup subsystem — the largest divergence + +> Upstream's backup path is small (`s.Backup(b)` / `s.RestoreBackup(b)` in a bare goroutine). +> The fork replaced it with a queued, cancellable, progress-reporting pipeline. + +**Operation registry, queue, retry (server package)** + +| Path | What | +|------|------| +| `server/backup_operations.go` | **Fork-new.** `BackupOperationRegistry` + global `GetBackupOperationRegistry()`: concurrency-limited (`maxConcurrentBackups/Restores = 8`), queueing, cancellation. `Register()` returns the 5-tuple `(*op, ctx, cancel, err, wasQueued)` consumed by the router. `Cancel/Complete/CleanupStaleOperations` block-receive the semaphore token to release slots. Background `StartBackupOperationCleanup`. | +| `server/backup.go` | `Backup()`/`RestoreBackup()` are now compat wrappers over `BackupWithContext`/`RestoreBackupWithContext` (context + 6h/4h timeouts, atomic state transitions). `BackupWithRetry(ctx,b,2)` with exponential backoff. Restore does compression auto-detection, progress, restore-stats and per-dir `MkdirAll`+`Chown`. On panel-notify failure for a **successful** backup it **does not** delete the archive (upstream does). | +| `server/server.go` | `Server.backingUp *AtomicBool` (+ init), `AtomicStateTransition` + `ApplyAtomicStateTransition()`, `CleanupForDestroy()` cancels in-flight ops + cleans orphaned backup files, `PublishActivity()`. | +| `server/install.go` | `IsBackingUp()` / `SetBackingUp()`. (Note: `IsInProtectedState()` was deliberately **not** extended with `backingUp`.) | +| `server/power.go`, `server/errors.go` | `HandlePowerAction()` also blocks while `IsBackingUp()`, returning new sentinel `ErrServerIsBackingUp`. | +| `environment/environment.go` (+ docker state whitelist) | New states `backup`, `restore`, `backup_queued`, `restore_queued`. | + +**Progress & activity over WebSocket** + +| Path | What | +|------|------| +| `server/backup_progress.go` | **Fork-new.** `BackupProgressUpdate` WS payload `{backup_id,type,percentage,bytes_written,bytes_total}`, throttled tracker, S3 80/20 archive-vs-upload split. | +| `server/backup/download_progress.go` | **Fork-new.** `DownloadProgressReader` / `NewDownloadProgressReader` — S3 restore download progress. | +| `server/events.go` | New `BackupProgressEvent`, `DownloadProgressEvent`, `ActivityEvent`. **Frontend contract.** | +| `server/activity.go` | New `ActivityFile{Downloaded,Compressed,Decompressed,Chmod}`; `SaveActivity` also publishes `ActivityEvent` over WS. | + +**Compression & checksums** + +| Path | What | +|------|------| +| `server/backup/backup.go` | **SHA-256** checksums + `ChecksumType: "sha256"` (upstream uses **sha1**). ⚠️ **Protocol-facing** — a careless merge reverts to sha1 and breaks checksum compatibility with our Panel. `PathForLocalBackup()` helper. | +| `server/backup/compression.go` | **Fork-new.** `CompressionRegistry` (gzip/tar/none) + `IsValidBackupContentType()` — used by the router content-type gate; without it the router won't compile. | +| `server/backup/backup_local.go` | `foundPath` + extension-probing `LocateLocal` (`.tar.gz/.tar`), auto-detecting `Restore`, `CleanupBackupFilesForServer`. | +| `server/backup/backup_s3.go` | Two-phase backup reuse, success-flag cleanup (failed uploads kept for retry), orphaned-part logging, upload progress (`ProgressReader`/`ProgressTracker`), custom HTTP/1.1 transport, part-retry with 100MB memory-buffer threshold / 5GB cap, `Restore()` expects an **already-decompressed** tar stream. | +| `server/filesystem/archive.go` | Archiver no longer skips directory entries → **empty directories are preserved** in archives. `createCompressor()` refactor. | +| `server/filesystem/archive_restore.go` | **Fork-new.** `DetectCompressionFormat` (magic bytes) + `CreateDecompressor`, wired into restore. | +| `server/filesystem/compress_binary_test.go`, `server/backup/*_test.go`, `router/content_type_test.go`, `server/backup_*_test.go` | Fork-new regression suites for the above. Keep them passing. | + +**Router API (fork-only endpoints + customized handlers)** + +| Path | What | +|------|------| +| `router/router.go` | Fork-only routes `GET /backup/operations` and `DELETE /backup/:backup/cancel`. | +| `router/router_server_backup.go` | `cancelServerBackup` + `getServerBackupOperations` (fork-only). `postServerBackup` / `postServerRestoreBackup` rewritten: 409 concurrency guards, registry queueing, timeouts, panic recovery, retry, S3 download progress, and content-type via `backup.IsValidBackupContentType` (gzip+tar) instead of upstream's gzip-only check. | + +### 1.3 SFTP activity streaming + +| Path | What | +|------|------| +| `sftp/event.go`, `sftp/handler.go` | `EventPublisher` interface + `publisher` on the event handler → SFTP file actions streamed to the panel via `Server.PublishActivity` (in addition to DB persistence). | + +> The fork's previous `SmartSecurityProtector` SFTP brute-force/IP-reputation system +> (and its `sftp.security.*` config) was **removed** — `sftp/server.go` now matches upstream +> (vanilla SFTP auth) apart from the module rename. SFTP abuse protection is left to the +> network layer (firewall / fail2ban) / the Panel. + +### 1.4 Repo config + +| Path | What | +|------|------| +| `.gitignore` | Fork-added `.claude-flow/`, `.hive-mind/`, `CLAUDE.md`. Upstream will never add these — keep on merge. | +| `.github/workflows/**`, `Makefile`, `Dockerfile` | Our build/release pipeline (with the renamed module path). | + +--- + +## 2. Config divergences — confirm whether intentional + +These sit on **different** values/fields than upstream; they will re-appear in a generated +`config.yml`. Review on upgrade. + +| Path | Fork value / field | Note | +|------|--------------------|------| +| `server/backup_operations.go` | `maxConcurrentBackups/Restores = 8`; cleanup ticker 5 min / op TTL 8 h; backup 6 h / restore 4 h timeouts | Fork-chosen capacity/timeouts. | +| `server/backup_progress.go` | 250 ms throttle; S3 80/20 split; 1 MB chunking | Determines WS emission rate / S3 percentage curve. | +| `server/backup/backup_s3.go` | per-part upload `Content-Type: application/octet-stream` (upstream `application/x-gzip`) | Fork choice. Verify the Panel/S3 presigned flow tolerates it. | + +--- + +## 3. Known issues in our own fork code (tech debt) + +Every item below was **verified fork-only** (`git grep` against upstream `e771816` returns +zero hits) — upstream wings does **not** do these, so they are **our** code/behavior, not +inherited upstream defaults that can be ignored. These are not "divergences to preserve" so +much as bugs/concerns in our own additions, worth fixing rather than defending on upgrade. + +| Area | Concern (all fork-only) | +|------|---------| +| **Committed binaries** | `wings-debug`, `wings-fixed`, `wings-test` (~41 MB each) and `dist/wings_test` (~28 MB) are committed build artifacts (~150 MB total), not gitignored. Repo bloat / accidental commits. | +| **Backup cleanup scope** | `cleanupBackupFiles` (server.go) and `CleanupBackupFilesForServer` (backup_local.go) match backup files by **filename pattern only** and do **not** filter by the server's ID. Since the backup directory is shared, deleting one server can remove **other** servers' local backups. | +| **`checksum_type` label** | `server/backup.go` emits `"sha256"` in most events but still `"sha1"` in the panel-notify-failure success branch. Reconcile the labels (actual algorithm is sha256). | +| **`validateBackupContent`** | Fails a backup on any server-vs-archive file/dir count mismatch (can race a live server writing files), and computes a full SHA-256 over the **entire server tree and backup file** purely for a debug log line (perf cost on large servers). | + +--- + +## 4. NOT fork divergences — adopted from upstream (do **not** re-apply) + +These show up around our customizations but are **upstream v1.13.1** code. Treating them as +fork changes risks duplicating or mis-merging them on the next upgrade. + +| Path | Reality | +|------|---------| +| `router/router_server_backup.go` SSRF cluster — `backupRestoreHttpClient`, `validateBackupDownloadUrl`, `parseBackupUuid`, `isBlockedBackupRestoreIP`, `isExplicitlyBlockedBackupRestoreIP`, `isAllowedBackupRestoreDestination`, `isSupportedBackupRestoreContentType`, `blockedBackupRestorePrefixes`, `backupDownloadError` | **Upstream v1.13.1** backup-restore SSRF hardening. The **only** fork edit in this cluster: the restore handler calls `backup.IsValidBackupContentType` instead of `isSupportedBackupRestoreContentType` (the latter is retained only for upstream parity + its test). | +| `config/config.go` → `Backups.RestoreHostAllowlist` | **Upstream v1.13.1.** Pairs with the SSRF allowlist above. Not a fork field. | +| `server/backup/backup.go` → `validateIdentifier()` / `normalizedIdentifier()` (+ `Path()` `path.Base` fallback) | **Upstream v1.13.1** UUID hardening. The fork uses them unchanged. | +| `system/const.go` | Byte-identical to upstream (`Version = "develop"`). | +| `.github/FUNDING.yaml` (`github: [pterodactyl]`) | **Upstream default, unchanged** (`git diff e771816 HEAD` is empty). Stale for a fork (sponsorship points at upstream) but NOT our change — clean it up if desired, don't track it as a divergence. | +| `.github/workflows/release.yaml` release-bot identity (`ci@pterodactyl.io` / `Pterodactyl CI`) | **Upstream default, unchanged.** Upstream's release.yaml already sets this identity. Not our divergence. | diff --git a/WORK.md b/WORK.md deleted file mode 100644 index ccc425524..000000000 --- a/WORK.md +++ /dev/null @@ -1,19 +0,0 @@ -Wir müssen nun prüfen: - -- Backups müssen für S3 & Local grundlegend identisch funktionieren -* Backups müssen einen wirklich echten progress via Event übermitteln (Bei S3 müssen wir die uploaded parts etc. einbeziehen), bei Local ist das ganze einfacher und ist auch bereits zum großteil implementiert, jedoch muss das einmal richtig getestet werden. -- Wenn ein Backup gestartet wird, muss der jeweils richtige State dafür gesetzt werden. Also Backup/Restore wie es bisher getan wird, aber es noch nicht richtig gemacht wird, da es durch andere dinge corrupted / überschrieben wird oder auch mal garnicht gesetzt wird. Auch der Reset muss richtig funktionieren, also entweder nehmen wir den vorherigen State oder wir ermitteln den state. Ich denke bei Backup Create nehmen wir den vorherigen State und bei Restore nehmen wir den echten state, also abfragen ob running oder offline oder wie auch immer das richtig heißt. -- Backups könnten failen, das bedeutet bspw. dass die Checksum nicht identisch ist, connection weg flog oder oder oder.. Das heißt dass wir einen retry definieren, wie oft ein backup versucht wird zu machen. Default sagen wir 2 mal und danach soll er das aufgeben und auch ein Event schicken dass ein backup nicht funktioniert hat. Also als Activity Event oder so, dass man das saved hat. -- Wir müssen es schaffen, dass die Backups nicht zu lange brauchen, dennoch muss das ganze maximal verständlich bleiben und wirklich sicher -- Die Formate wie gzip, zstd sollen später weiter ausbaubar sein - sprich dass sollte eine richtige Struktur haben und wartbar sein -- Schreibe für die kritischen dinge bei Backup tests, jedoch ohne die wings dabei zu testen, sondern wirklich nur die funktionalität der einzelnen services (Also Backup Create/Restore usw. mit richtigen Datein) -- Bei den S3 Backup und den fortschritt sollten wir das ganze bewusst 80/20 machen - wir bewerten hierbei die Backup Creation mit 80% und den Upload mit 20% - Damit der Percentage wirklich echt rüberkommt, sprich die restlichen 20% werden anhand von dem upload bei S3 ausgemacht. Bei Local ist dies natürlich nicht nötig und wird dann 100% von Create bewertet. - -* Das ganze läuft in Production, muss also backwards compatibility sein. -* Vieles ist bereits in diese Richtung implementiert, aber dennoch noch nicht 100% funktionsfähig, daher agierst du als Senior-Go Experte und prüft die Logik mit deinen Agents im ersten Step, starte dafür mehrere Agents welche sich mehrere Files vornehmen und den zusammenhang / zusammenspiel technisch versuchen zu verstehen, erstelle dann einen plan das ganze prod ready zu optimieren oder ggf. neu zu entwerfen. Das ganze ist gewünscht minimal invasiv zu lösen, sofern es möglich ist! - -Mit dem Backup System ist gefordert, dass wir 100% Integrität dem Nutzer garantieren, Schnelligkeit aber auch einen reibungslosen ablauf der Software (Go) bereitstellen wollen, daher dürfen sich da keine Bugs einschleichen. - -Prüfe zudem, dass durch unsere Änderungen nichts an der Transfer Logik kaputt geht. - -Arbeite stets nach best practices, halte dich an die Source Struktur. \ No newline at end of file diff --git a/config/config.go b/config/config.go index 1173149f9..35da06e7e 100644 --- a/config/config.go +++ b/config/config.go @@ -69,60 +69,6 @@ type SftpConfiguration struct { Port int `default:"2022" json:"bind_port" yaml:"bind_port"` // If set to true, no write actions will be allowed on the SFTP server. ReadOnly bool `default:"false" yaml:"read_only"` - - // Smart brute force protection configuration - Security SftpSecurityConfiguration `yaml:"security"` -} - -// SftpSecurityConfiguration defines intelligent brute force protection settings -type SftpSecurityConfiguration struct { - // Enable/disable brute force protection - Enabled bool `default:"true" yaml:"enabled"` - - // Base thresholds for triggering blocks - Thresholds SftpSecurityThresholds `yaml:"thresholds"` - - // Block duration strategy - Blocking SftpBlockingStrategy `yaml:"blocking"` - - // Reputation system settings - Reputation SftpReputationConfig `yaml:"reputation"` -} - -// SftpSecurityThresholds defines when to start blocking -type SftpSecurityThresholds struct { - // Attempts per minute before first block (smart: 6+ = 5min) - AttemptsPerMinute int `default:"6" yaml:"attempts_per_minute"` - // Attempts per hour before escalated blocking - AttemptsPerHour int `default:"15" yaml:"attempts_per_hour"` - // Attempts per day before long-term reputation impact - AttemptsPerDay int `default:"50" yaml:"attempts_per_day"` -} - -// SftpBlockingStrategy defines how blocks escalate intelligently -type SftpBlockingStrategy struct { - // Base block duration in minutes (smart: starts at 5min) - BaseBlockMinutes int `default:"5" yaml:"base_block_minutes"` - // Exponential multiplier for repeat offenders (smart: 2x each time) - EscalationFactor float64 `default:"2.0" yaml:"escalation_factor"` - // Maximum block duration in hours (smart: caps at reasonable limit) - MaxBlockHours int `default:"24" yaml:"max_block_hours"` - // Decay factor - how much blocks reduce over time (smart: forgiveness) - DecayFactor float64 `default:"0.8" yaml:"decay_factor"` -} - -// SftpReputationConfig defines IP reputation tracking -type SftpReputationConfig struct { - // Track reputation history - Enabled bool `default:"true" yaml:"enabled"` - // Days to remember IP behavior - MemoryDays int `default:"7" yaml:"memory_days"` - // Score threshold for immediate blocking (-100 to +100) - BlockThreshold int `default:"-50" yaml:"block_threshold"` - // Good behavior bonus (successful logins) - GoodBehaviorBonus int `default:"5" yaml:"good_behavior_bonus"` - // Bad behavior penalty (failed attempts) - BadBehaviorPenalty int `default:"-10" yaml:"bad_behavior_penalty"` } // ApiConfiguration defines the configuration for the internal API that is @@ -344,14 +290,9 @@ type Backups struct { // Defaults to "best_speed" (level 1) CompressionLevel string `default:"best_speed" yaml:"compression_level"` - // Format determines the compression format used for backups. - // Available options: "gzip" (default), "zstd" - // - // zstd provides better compression ratios and faster compression/decompression - // compared to gzip, while maintaining full backward compatibility. - // - // Defaults to "gzip" for backward compatibility - Format string `default:"gzip" yaml:"format"` + // RestoreHostAllowlist allows backup restore downloads to connect to otherwise blocked + // private/internal destinations. Entries may be hostnames, IP addresses, or CIDR ranges. + RestoreHostAllowlist []string `yaml:"restore_host_allowlist"` } type Transfers struct { diff --git a/config/config_docker.go b/config/config_docker.go index 4b447de45..95501e74a 100644 --- a/config/config_docker.go +++ b/config/config_docker.go @@ -3,8 +3,11 @@ package config import ( "encoding/base64" "encoding/json" + "net/url" "sort" + "strings" + "github.com/distribution/reference" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/registry" ) @@ -105,6 +108,93 @@ func (c DockerConfiguration) ContainerLogConfig() container.LogConfig { } } +// RegistryCredentialsForImage returns registry credentials for an image only +// when the configured registry and image reference share the same registry +// identity. +func (c DockerConfiguration) RegistryCredentialsForImage(img string) (string, *RegistryConfiguration) { + named, err := reference.ParseNormalizedNamed(img) + if err != nil { + return "", nil + } + + imageDomain := strings.ToLower(reference.Domain(named)) + imagePath := reference.Path(named) + var matchedRegistry string + var matchedCredentials RegistryConfiguration + matchedScore := -1 + + for registry, cfg := range c.Registries { + domain, path, ok := parseDockerRegistryReference(registry) + if !ok || domain != imageDomain || !registryPathMatchesImage(path, imagePath) { + continue + } + + score := len(domain) + len(path) + if score > matchedScore { + matchedRegistry = registry + matchedCredentials = cfg + matchedScore = score + } + } + + if matchedScore == -1 { + return "", nil + } + + return matchedRegistry, &matchedCredentials +} + +func parseDockerRegistryReference(registry string) (string, string, bool) { + registry = strings.TrimSpace(registry) + if registry == "" { + return "", "", false + } + + if u, err := url.Parse(registry); err == nil && u.Host != "" { + p := strings.Trim(u.Path, "/") + if p == "" || p == "v1" || p == "v2" { + registry = u.Host + } else { + registry = u.Host + "/" + p + } + } + + registry = strings.Trim(registry, "/") + if registry == "" { + return "", "", false + } + + hasPath := strings.Contains(registry, "/") + ref := registry + if !hasPath { + ref += "/wings" + } + + named, err := reference.ParseNormalizedNamed(ref) + if err != nil { + return "", "", false + } + + path := "" + if hasPath { + path = reference.Path(named) + } + domain := strings.ToLower(reference.Domain(named)) + if domain == "docker.io" && (path == "v1" || path == "v2") { + path = "" + } + + return domain, path, true +} + +func registryPathMatchesImage(registryPath string, imagePath string) bool { + if registryPath == "" { + return true + } + + return imagePath == registryPath || strings.HasPrefix(imagePath, registryPath+"/") +} + // RegistryConfiguration defines the authentication credentials for a given // Docker registry. type RegistryConfiguration struct { diff --git a/config/config_docker_test.go b/config/config_docker_test.go new file mode 100644 index 000000000..72ee0962d --- /dev/null +++ b/config/config_docker_test.go @@ -0,0 +1,99 @@ +package config + +import "testing" + +func TestDockerRegistryCredentialsForImage(t *testing.T) { + cfg := DockerConfiguration{ + Registries: map[string]RegistryConfiguration{ + "registry.example.com": { + Username: "registry", + Password: "secret", + }, + "registry.example.com/team": { + Username: "team", + Password: "secret", + }, + "registry.example.com:5000": { + Username: "port", + Password: "secret", + }, + "https://index.docker.io/v1/": { + Username: "docker", + Password: "secret", + }, + }, + } + + tests := []struct { + name string + image string + username string + }{ + { + name: "registry domain", + image: "registry.example.com/project/image:latest", + username: "registry", + }, + { + name: "registry with port", + image: "registry.example.com:5000/project/image:latest", + username: "port", + }, + { + name: "registry path", + image: "registry.example.com/team/image:latest", + username: "team", + }, + { + name: "registry prefix is not domain", + image: "registry.example.com.evil/project/image:latest", + }, + { + name: "registry path prefix falls back to domain", + image: "registry.example.com/team-evil/image:latest", + username: "registry", + }, + { + name: "legacy docker hub registry", + image: "docker.io/library/busybox:latest", + username: "docker", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, registry := cfg.RegistryCredentialsForImage(tt.image) + if tt.username == "" { + if registry != nil { + t.Fatalf("expected no registry credentials, got username %q", registry.Username) + } + + return + } + + if registry == nil { + t.Fatalf("expected registry credentials for %q", tt.image) + } + + if registry.Username != tt.username { + t.Fatalf("expected username %q, got %q", tt.username, registry.Username) + } + }) + } +} + +func TestDockerRegistryPathCredentialsDoNotMatchSiblingPath(t *testing.T) { + cfg := DockerConfiguration{ + Registries: map[string]RegistryConfiguration{ + "registry.example.com/team": { + Username: "team", + Password: "secret", + }, + }, + } + + _, registry := cfg.RegistryCredentialsForImage("registry.example.com/team-evil/image:latest") + if registry != nil { + t.Fatalf("expected no registry credentials, got username %q", registry.Username) + } +} diff --git a/environment/docker/container.go b/environment/docker/container.go index ac4cd56ba..42f27c853 100644 --- a/environment/docker/container.go +++ b/environment/docker/container.go @@ -360,16 +360,9 @@ func (e *Environment) ensureImageExists(img string) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - // Get a registry auth configuration from the config. - var registryAuth *config.RegistryConfiguration - for registry, c := range config.Get().Docker.Registries { - if !strings.HasPrefix(img, registry) { - continue - } - + registry, registryAuth := config.Get().Docker.RegistryCredentialsForImage(img) + if registryAuth != nil { log.WithField("registry", registry).Debug("using authentication for registry") - registryAuth = &c - break } // Get the ImagePullOptions. diff --git a/internal/ufs/fs_quota.go b/internal/ufs/fs_quota.go index cc89cbd0d..6a3d6b271 100644 --- a/internal/ufs/fs_quota.go +++ b/internal/ufs/fs_quota.go @@ -4,6 +4,7 @@ package ufs import ( + "math" "sync/atomic" ) @@ -67,14 +68,33 @@ func (fs *Quota) SetUsage(newUsage int64) int64 { // Add adds `i` to the tracked usage total. func (fs *Quota) Add(i int64) int64 { - usage := fs.Usage() + for { + usage := fs.Usage() + var next int64 + + switch { + case i > 0: + if usage > math.MaxInt64-i { + next = math.MaxInt64 + } else { + next = usage + i + } + case i < 0: + if i == math.MinInt64 { + next = 0 + } else if usage <= -i { + next = 0 + } else { + next = usage + i + } + default: + return usage + } - // If adding `i` to the usage will put us below 0, cap it. (`i` can be negative) - if usage+i < 0 { - fs.usage.Store(0) - return 0 + if fs.usage.CompareAndSwap(usage, next) { + return next + } } - return fs.usage.Add(i) } // CanFit checks if the given size can fit in the filesystem without exceeding @@ -98,14 +118,15 @@ func (fs *Quota) CanFit(size int64) bool { return true } - // If the current usage + the requested size are under the limit of the - // filesystem, allow it. - if usage+size <= limit { + if size <= 0 { return true } - // Welp, the size would exceed the limit of the filesystem, deny it. - return false + if usage >= limit { + return false + } + + return size <= limit-usage } // Remove removes the named file or (empty) directory. diff --git a/internal/ufs/fs_quota_test.go b/internal/ufs/fs_quota_test.go new file mode 100644 index 000000000..e74e101c8 --- /dev/null +++ b/internal/ufs/fs_quota_test.go @@ -0,0 +1,47 @@ +package ufs + +import ( + "math" + "testing" +) + +func TestQuotaCanFitRejectsOverflowingSize(t *testing.T) { + q := NewQuota(nil, 1<<30) + q.SetUsage(1) + + if q.CanFit(math.MaxInt64) { + t.Fatal("expected oversized write to be rejected") + } + + q.SetUsage(1 << 20) + if q.CanFit(math.MaxInt64 - 100) { + t.Fatal("expected oversized write to be rejected") + } +} + +func TestQuotaCanFitAllowsShrinkingWrite(t *testing.T) { + q := NewQuota(nil, 10) + q.SetUsage(20) + + if !q.CanFit(-5) { + t.Fatal("expected shrinking write to be allowed") + } +} + +func TestQuotaAddDoesNotResetOnPositiveOverflow(t *testing.T) { + q := NewQuota(nil, 0) + q.SetUsage(math.MaxInt64 - 1) + + if got := q.Add(10); got != math.MaxInt64 { + t.Fatalf("expected usage to saturate at MaxInt64, got %d", got) + } +} + +func TestQuotaAddClampsSubtractionAtZero(t *testing.T) { + q := NewQuota(nil, 0) + q.SetUsage(5) + + if got := q.Add(-10); got != 0 { + t.Fatalf("expected usage to clamp at zero, got %d", got) + } +} diff --git a/internal/ufs/fs_unix.go b/internal/ufs/fs_unix.go index 97224c305..b3197bfd6 100644 --- a/internal/ufs/fs_unix.go +++ b/internal/ufs/fs_unix.go @@ -170,8 +170,7 @@ func (fs *UnixFS) Chtimesat(dirfd int, name string, atime, mtime time.Time) erro set(0, atime) set(1, mtime) - // This does support `AT_SYMLINK_NOFOLLOW` as well if needed. - return ensurePathError(unix.UtimesNanoAt(dirfd, name, utimes[0:], 0), "chtimes", name) + return ensurePathError(unix.UtimesNanoAt(dirfd, name, utimes[0:], AT_SYMLINK_NOFOLLOW), "chtimes", name) } // Create creates or truncates the named file. If the file already exists, diff --git a/internal/ufs/fs_unix_test.go b/internal/ufs/fs_unix_test.go index 78739581f..1b6d11e04 100644 --- a/internal/ufs/fs_unix_test.go +++ b/internal/ufs/fs_unix_test.go @@ -13,6 +13,7 @@ import ( "slices" "strconv" "testing" + "time" "github.com/Rene-Roscher/wings/internal/ufs" ) @@ -280,15 +281,56 @@ func TestUnixFS_Lchown(t *testing.T) { } func TestUnixFS_Chtimes(t *testing.T) { - t.Parallel() - fs, err := newTestUnixFS() + tmpDir := t.TempDir() + root := filepath.Join(tmpDir, "root") + if err := os.Mkdir(root, 0o755); err != nil { + t.Fatal(err) + } + fs, err := ufs.NewUnixFS(root, false) if err != nil { t.Fatal(err) - return } - defer fs.Cleanup() - // TODO: implement + regular := filepath.Join(root, "regular") + if err := os.WriteFile(regular, []byte("regular"), 0o644); err != nil { + t.Fatal(err) + } + regularTime := time.Unix(1_700_100_000, 0) + if err := fs.Chtimes("regular", regularTime, regularTime); err != nil { + t.Fatal(err) + } + regularStat, err := os.Lstat(regular) + if err != nil { + t.Fatal(err) + } + if !regularStat.ModTime().Equal(regularTime) { + t.Fatalf("expected regular file mtime to be %s, got %s", regularTime, regularStat.ModTime()) + } + + target := filepath.Join(tmpDir, "target") + if err := os.WriteFile(target, []byte("target"), 0o644); err != nil { + t.Fatal(err) + } + original := time.Unix(1_700_000_000, 0) + if err := os.Chtimes(target, original, original); err != nil { + t.Fatal(err) + } + if err := os.Symlink(target, filepath.Join(root, "link")); err != nil { + t.Fatal(err) + } + + changed := original.Add(-24 * time.Hour) + if err := fs.Chtimes("link", changed, changed); err != nil { + t.Fatal(err) + } + + st, err := os.Lstat(target) + if err != nil { + t.Fatal(err) + } + if !st.ModTime().Equal(original) { + t.Fatalf("expected target mtime to remain %s, got %s", original, st.ModTime()) + } } func TestUnixFS_Create(t *testing.T) { diff --git a/router/content_type_test.go b/router/content_type_test.go index b49ba3801..d68649f19 100644 --- a/router/content_type_test.go +++ b/router/content_type_test.go @@ -17,12 +17,7 @@ func TestIsValidBackupContentType(t *testing.T) { {"GZIP gzip", "application/gzip", true}, {"GZIP x-compressed", "application/x-compressed", true}, {"GZIP x-gtar", "application/x-gtar", true}, - - // ZSTD formats - {"ZSTD x-zstd", "application/x-zstd", true}, - {"ZSTD zstd", "application/zstd", true}, - {"ZSTD x-zstandard", "application/x-zstandard", true}, - + // TAR formats {"TAR x-tar", "application/x-tar", true}, {"TAR tar", "application/tar", true}, diff --git a/router/router_server_backup.go b/router/router_server_backup.go index 8d34f9bb0..dcdd989b3 100644 --- a/router/router_server_backup.go +++ b/router/router_server_backup.go @@ -2,15 +2,23 @@ package router import ( "context" + stderrors "errors" "io" + "mime" + "net" "net/http" + "net/netip" + "net/url" "os" + "strings" "time" "emperror.dev/errors" "github.com/apex/log" "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/Rene-Roscher/wings/config" "github.com/Rene-Roscher/wings/environment" "github.com/Rene-Roscher/wings/router/middleware" "github.com/Rene-Roscher/wings/server" @@ -20,6 +28,22 @@ import ( // isValidBackupContentType is now replaced by backup.IsValidBackupContentType // which uses the extensible CompressionRegistry for better format support +// blockedBackupRestorePrefixes lists IP ranges that backup restore downloads are +// never allowed to reach (in addition to private/loopback/link-local ranges), +// unless a destination is explicitly permitted via the RestoreHostAllowlist. +var blockedBackupRestorePrefixes = []netip.Prefix{ + netip.MustParsePrefix("100.64.0.0/10"), + netip.MustParsePrefix("198.18.0.0/15"), +} + +// backupDownloadError is a sentinel error type used to surface backup download +// validation failures (e.g. SSRF protection) back to the API caller as 400s. +type backupDownloadError string + +func (e backupDownloadError) Error() string { + return string(e) +} + // postServerBackup performs a backup against a given server instance using the // provided backup adapter. func postServerBackup(c *gin.Context) { @@ -54,13 +78,17 @@ func postServerBackup(c *gin.Context) { if err := c.BindJSON(&data); err != nil { return } + backupUuid, ok := parseBackupUuid(c, data.Uuid) + if !ok { + return + } var adapter backup.BackupInterface switch data.Adapter { case backup.LocalBackupAdapter: - adapter = backup.NewLocal(client, data.Uuid, data.Ignore) + adapter = backup.NewLocal(client, backupUuid, data.Ignore) case backup.S3BackupAdapter: - adapter = backup.NewS3(client, data.Uuid, data.Ignore) + adapter = backup.NewS3(client, backupUuid, data.Ignore) default: middleware.CaptureAndAbort(c, errors.New("router/backups: provided adapter is not valid: "+string(data.Adapter))) return @@ -89,7 +117,7 @@ func postServerBackup(c *gin.Context) { logger.Info("registering backup operation in queue system") // ATOMIC: Register operation and get queue status atomically - _, ctx, cancel, err, wasQueued := registry.Register(s.Context(), data.Uuid, s.ID(), server.OperationTypeBackup) + _, ctx, cancel, err, wasQueued := registry.Register(s.Context(), backupUuid, s.ID(), server.OperationTypeBackup) if err != nil { logger.WithError(err).Error("failed to register backup operation") s.Events().Publish(server.DaemonMessageEvent, "Failed to register backup: " + err.Error()) @@ -106,7 +134,7 @@ func postServerBackup(c *gin.Context) { // Defer cleanup - will run AFTER backup completes defer func() { logger.Debug("backup goroutine cleanup starting") - registry.Complete(data.Uuid) + registry.Complete(backupUuid) cancel() // Cancel AFTER marking complete logger.Debug("backup goroutine cleanup completed") }() @@ -120,7 +148,7 @@ func postServerBackup(c *gin.Context) { // Send failure event to ensure frontend gets notified s.Events().Publish(server.BackupCompletedEvent, map[string]any{ - "uuid": data.Uuid, + "uuid": backupUuid, "is_successful": false, "error": err.Error(), }) @@ -176,10 +204,20 @@ func postServerRestoreBackup(c *gin.Context) { if err := c.BindJSON(&data); err != nil { return } + backupUuid, ok := parseBackupUuid(c, c.Param("backup")) + if !ok { + return + } if data.Adapter == backup.S3BackupAdapter && data.DownloadUrl == "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "The download_url field is required when the backup adapter is set to S3."}) return } + if data.Adapter == backup.S3BackupAdapter { + if err := validateBackupDownloadUrl(data.DownloadUrl); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + } // State management is now handled atomically within the restore function // to prevent race conditions with the operation registry @@ -196,7 +234,7 @@ func postServerRestoreBackup(c *gin.Context) { // Now that we've cleaned up the data directory if necessary, grab the backup file // and attempt to restore it into the server directory. if data.Adapter == backup.LocalBackupAdapter { - b, _, err := backup.LocateLocal(client, c.Param("backup")) + b, _, err := backup.LocateLocal(client, backupUuid) if err != nil { middleware.CaptureAndAbort(c, err) return @@ -207,7 +245,7 @@ func postServerRestoreBackup(c *gin.Context) { logger.Info("registering local restore operation in queue system") // ATOMIC: Register operation and get queue status atomically - _, ctx, cancel, err, wasQueued := registry.Register(s.Context(), c.Param("backup"), s.ID(), server.OperationTypeRestore) + _, ctx, cancel, err, wasQueued := registry.Register(s.Context(), backupUuid, s.ID(), server.OperationTypeRestore) if err != nil { logger.WithError(err).Error("failed to register restore operation") s.Events().Publish(server.DaemonMessageEvent, "Failed to register restore: " + err.Error()) @@ -226,7 +264,7 @@ func postServerRestoreBackup(c *gin.Context) { logger.WithField("panic", r).Error("restore operation panicked") } logger.Debug("local restore goroutine cleanup starting") - registry.Complete(c.Param("backup")) + registry.Complete(backupUuid) cancel() // Cancel AFTER marking complete logger.Debug("local restore goroutine cleanup completed") // Note: SetRestoring is now handled atomically within the restore function @@ -259,15 +297,11 @@ func postServerRestoreBackup(c *gin.Context) { // Since this is not a local backup we need to stream the archive and then // parse over the contents as we go in order to restore it to the server. - httpClient := &http.Client{ - Timeout: time.Hour * 2, // 2 hour timeout for large backup downloads - Transport: &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableKeepAlives: false, - DisableCompression: true, // Backup files are already compressed - }, - } + // + // backupRestoreHttpClient enforces SSRF protections: it refuses to connect to + // private/internal/loopback addresses (and the explicitly blocked ranges) unless + // the destination is permitted via the RestoreHostAllowlist configuration option. + httpClient := backupRestoreHttpClient() logger.WithField("download_url", data.DownloadUrl).Info("downloading backup from remote location...") // Use proper timeout to prevent indefinite hangs during backup downloads. // 2 hour timeout should be sufficient for most backup file sizes while preventing @@ -287,6 +321,13 @@ func postServerRestoreBackup(c *gin.Context) { "error": err, "duration_ms": time.Since(downloadStart).Milliseconds(), }).Error("HTTP request failed for backup download") + // Surface SSRF/validation failures from the restore HTTP client as a 400 + // to the caller instead of a generic 500. + var downloadErr backupDownloadError + if stderrors.As(err, &downloadErr) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": downloadErr.Error()}) + return + } middleware.CaptureAndAbort(c, err) return } @@ -308,14 +349,22 @@ func postServerRestoreBackup(c *gin.Context) { } } }() - + + // Reject non-200 responses (e.g. the link returned a 403/404 error page) before + // we try to interpret the body as a backup archive. The deferred close above runs + // on return since the goroutine has not taken ownership of the response yet. + if res.StatusCode != http.StatusOK { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "The provided backup link returned an invalid response status: " + res.Status}) + return + } + // Validate content types for supported backup formats using extensible compression registry contentType := res.Header.Get("Content-Type") if contentType == "" { // Accept empty content type (some S3 providers don't set it) } else if !backup.IsValidBackupContentType(contentType) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ - "error": "The provided backup link has an unsupported content type. \"" + contentType + "\" is not a supported backup format (gzip, zstd, or tar).", + "error": "The provided backup link has an unsupported content type. \"" + contentType + "\" is not a supported backup format (gzip or tar).", }) return } @@ -421,7 +470,7 @@ func postServerRestoreBackup(c *gin.Context) { // BackupRestoreCompletedEvent is now sent by RestoreBackupWithContext logger.Info("completed server restoration from S3 backup") } - }(s, c.Param("backup"), logger) + }(s, backupUuid, logger) // State cleanup handled atomically by restore operation c.Status(http.StatusAccepted) @@ -431,8 +480,11 @@ func postServerRestoreBackup(c *gin.Context) { // for consistent behavior (WORK.md compliance). If the backup is not found on the machine just return a 404 error. func deleteServerBackup(c *gin.Context) { client := middleware.ExtractApiClient(c) - backupID := c.Param("backup") - + backupID, ok := parseBackupUuid(c, c.Param("backup")) + if !ok { + return + } + // UNIFIED BEHAVIOR: Try to locate and delete backup regardless of type (Local or S3) // This ensures consistent deletion behavior between storage types (WORK.md requirement) @@ -551,3 +603,143 @@ func getServerBackupOperations(c *gin.Context) { "count": len(response), }) } + +// parseBackupUuid validates that the provided value is a canonical lowercase UUID +// and aborts the request with a 400 if it is not. This prevents path traversal and +// other malformed identifiers from reaching the backup subsystem. +func parseBackupUuid(c *gin.Context, value string) (string, bool) { + parsed, err := uuid.Parse(value) + if err == nil && len(value) == len(parsed.String()) && parsed.String() == strings.ToLower(value) { + return parsed.String(), true + } + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "The backup identifier must be a valid UUID."}) + return "", false +} + +// validateBackupDownloadUrl performs an up-front validation of an S3 backup download +// URL, rejecting non-HTTP(S) schemes and links that point directly at a blocked IP. +func validateBackupDownloadUrl(raw string) error { + parsed, err := url.Parse(raw) + if err != nil || parsed.Host == "" { + return backupDownloadError("The provided backup link is not a valid URL.") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return backupDownloadError("The provided backup link must use HTTP or HTTPS.") + } + if ip := net.ParseIP(parsed.Hostname()); ip != nil && isBlockedBackupRestoreIP(parsed.Hostname(), ip) { + return backupDownloadError("The provided backup link resolves to a blocked address.") + } + return nil +} + +// backupRestoreHttpClient returns an http.Client whose dialer resolves the target +// host and refuses to connect to private, loopback, link-local or otherwise blocked +// addresses unless the destination is explicitly permitted via RestoreHostAllowlist. +// This is the core SSRF protection for remote (S3) backup restore downloads. +func backupRestoreHttpClient() http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = nil + transport.ResponseHeaderTimeout = 30 * time.Second + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + transport.DialContext = func(ctx context.Context, network string, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, errors.New("router/backups: backup download host did not resolve to any addresses") + } + for _, resolved := range ips { + if isBlockedBackupRestoreIP(host, resolved.IP) { + return nil, backupDownloadError("The provided backup link resolves to a blocked address.") + } + } + var lastErr error + for _, resolved := range ips { + conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(resolved.IP.String(), port)) + if err == nil { + return conn, nil + } + lastErr = err + } + return nil, lastErr + } + return http.Client{ + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return backupDownloadError("The provided backup link redirects too many times.") + } + return validateBackupDownloadUrl(req.URL.String()) + }, + } +} + +// isBlockedBackupRestoreIP reports whether a resolved IP must not be connected to for +// a backup restore download, taking the RestoreHostAllowlist into account. +func isBlockedBackupRestoreIP(host string, ip net.IP) bool { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return true + } + addr = addr.Unmap() + if !addr.IsGlobalUnicast() || addr.IsPrivate() || addr.IsLoopback() || addr.IsLinkLocalUnicast() || isExplicitlyBlockedBackupRestoreIP(addr) { + return !isAllowedBackupRestoreDestination(host, addr) + } + return false +} + +func isExplicitlyBlockedBackupRestoreIP(addr netip.Addr) bool { + for _, prefix := range blockedBackupRestorePrefixes { + if prefix.Contains(addr) { + return true + } + } + return false +} + +// isAllowedBackupRestoreDestination reports whether the given host/addr is explicitly +// permitted through the System.Backups.RestoreHostAllowlist configuration option. +func isAllowedBackupRestoreDestination(host string, addr netip.Addr) bool { + host = strings.TrimSuffix(strings.ToLower(host), ".") + for _, entry := range config.Get().System.Backups.RestoreHostAllowlist { + entry = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(entry)), ".") + if entry == "" { + continue + } + if entry == host { + return true + } + if allowedAddr, err := netip.ParseAddr(entry); err == nil && allowedAddr.Unmap() == addr { + return true + } + if prefix, err := netip.ParsePrefix(entry); err == nil && prefix.Contains(addr) { + return true + } + } + return false +} + +// isSupportedBackupRestoreContentType reports whether the given Content-Type header +// value is a gzip archive. The remote-restore handler itself relies on the broader +// backup.IsValidBackupContentType (which also accepts tar/uncompressed uploads); this +// helper is retained for parity with upstream and its security test coverage. +func isSupportedBackupRestoreContentType(value string) bool { + mediaType, _, err := mime.ParseMediaType(value) + if err != nil { + mediaType = strings.TrimSpace(value) + } + switch strings.ToLower(mediaType) { + case "application/x-gzip", "application/gzip": + return true + default: + return false + } +} diff --git a/router/router_server_backup_test.go b/router/router_server_backup_test.go new file mode 100644 index 000000000..fedeb7f7d --- /dev/null +++ b/router/router_server_backup_test.go @@ -0,0 +1,337 @@ +package router + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/apex/log" + "github.com/gin-gonic/gin" + + "github.com/Rene-Roscher/wings/config" + "github.com/Rene-Roscher/wings/environment" + "github.com/Rene-Roscher/wings/events" + "github.com/Rene-Roscher/wings/internal/models" + "github.com/Rene-Roscher/wings/remote" + wserver "github.com/Rene-Roscher/wings/server" +) + +func init() { + config.Set(&config.Configuration{AuthenticationToken: "test-token"}) +} + +type backupTestRemoteClient struct { + restoreStatus chan string +} + +func (c backupTestRemoteClient) GetBackupRemoteUploadURLs(context.Context, string, int64) (remote.BackupRemoteUploadResponse, error) { + return remote.BackupRemoteUploadResponse{}, nil +} + +func (c backupTestRemoteClient) GetInstallationScript(context.Context, string) (remote.InstallationScript, error) { + return remote.InstallationScript{}, nil +} + +func (c backupTestRemoteClient) GetServerConfiguration(context.Context, string) (remote.ServerConfigurationResponse, error) { + return remote.ServerConfigurationResponse{}, nil +} + +func (c backupTestRemoteClient) GetServers(context.Context, int) ([]remote.RawServerData, error) { + return nil, nil +} + +func (c backupTestRemoteClient) ResetServersState(context.Context) error { + return nil +} + +func (c backupTestRemoteClient) SetArchiveStatus(context.Context, string, bool) error { + return nil +} + +func (c backupTestRemoteClient) SetBackupStatus(context.Context, string, remote.BackupRequest) error { + return nil +} + +func (c backupTestRemoteClient) SendRestorationStatus(_ context.Context, backup string, _ bool) error { + if c.restoreStatus != nil { + select { + case c.restoreStatus <- backup: + default: + } + } + return nil +} + +func (c backupTestRemoteClient) SetInstallationStatus(context.Context, string, remote.InstallStatusRequest) error { + return nil +} + +func (c backupTestRemoteClient) SetTransferStatus(context.Context, string, bool) error { + return nil +} + +func (c backupTestRemoteClient) ValidateSftpCredentials(context.Context, remote.SftpAuthRequest) (remote.SftpAuthResponse, error) { + return remote.SftpAuthResponse{}, nil +} + +func (c backupTestRemoteClient) SendActivityLogs(context.Context, []models.Activity) error { + return nil +} + +type backupTestEnvironment struct{} + +func (backupTestEnvironment) Type() string { return "test" } + +func (backupTestEnvironment) Config() *environment.Configuration { + return &environment.Configuration{} +} + +func (backupTestEnvironment) Events() *events.Bus { return events.NewBus() } + +func (backupTestEnvironment) Exists() (bool, error) { return true, nil } + +func (backupTestEnvironment) IsRunning(context.Context) (bool, error) { return false, nil } + +func (backupTestEnvironment) InSituUpdate() error { return nil } + +func (backupTestEnvironment) OnBeforeStart(context.Context) error { return nil } + +func (backupTestEnvironment) Start(context.Context) error { return nil } + +func (backupTestEnvironment) Stop(context.Context) error { return nil } + +func (backupTestEnvironment) WaitForStop(context.Context, time.Duration, bool) error { + return nil +} + +func (backupTestEnvironment) Terminate(context.Context, string) error { return nil } + +func (backupTestEnvironment) Destroy() error { return nil } + +func (backupTestEnvironment) ExitState() (uint32, bool, error) { return 0, false, nil } + +func (backupTestEnvironment) Create() error { return nil } + +func (backupTestEnvironment) Attach(context.Context) error { return nil } + +func (backupTestEnvironment) SendCommand(string) error { return nil } + +func (backupTestEnvironment) Readlog(int) ([]string, error) { return nil, nil } + +func (backupTestEnvironment) State() string { return environment.ProcessOfflineState } + +func (backupTestEnvironment) SetState(string) {} + +func (backupTestEnvironment) Uptime(context.Context) (int64, error) { return 0, nil } + +func (backupTestEnvironment) SetLogCallback(func([]byte)) {} + +func newBackupRestoreContext(t *testing.T, client backupTestRemoteClient, backupID string, body string) (*gin.Context, *httptest.ResponseRecorder, *wserver.Server) { + t.Helper() + + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/api/servers/server/backup/"+backupID+"/restore", strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{ + {Key: "server", Value: "server"}, + {Key: "backup", Value: backupID}, + } + + s, err := wserver.New(client) + if err != nil { + t.Fatal(err) + } + s.Config().Uuid = "server" + s.Environment = backupTestEnvironment{} + + c.Set("server", s) + c.Set("api_client", client) + c.Set("logger", log.WithField("test", t.Name())) + + return c, w, s +} + +func TestPostServerRestoreBackupRejectsLoopbackDownloadURL(t *testing.T) { + hit := make(chan struct{}, 1) + internal := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hit <- struct{}{} + w.Header().Set("Content-Type", "application") + _, _ = w.Write([]byte("not a gzip archive")) + })) + defer internal.Close() + downloadURL := strings.Replace(internal.URL, "127.0.0.1", "localhost", 1) + if downloadURL == internal.URL { + t.Fatalf("expected test server URL to use 127.0.0.1, got %s", internal.URL) + } + + client := backupTestRemoteClient{restoreStatus: make(chan string, 1)} + backupID := "11111111-1111-1111-1111-111111111111" + c, w, s := newBackupRestoreContext(t, client, backupID, fmt.Sprintf(`{"adapter":"s3","download_url":%q}`, downloadURL)) + defer s.CtxCancel() + + postServerRestoreBackup(c) + + if c.Writer.Status() != http.StatusBadRequest { + t.Fatalf("expected restore request to be rejected, got status %d body %s", c.Writer.Status(), w.Body.String()) + } + + select { + case <-hit: + t.Fatal("expected loopback server not to receive restore download request") + case <-time.After(100 * time.Millisecond): + } +} + +func TestPostServerRestoreBackupRejectsNonUuidBackupID(t *testing.T) { + client := backupTestRemoteClient{restoreStatus: make(chan string, 1)} + c, w, s := newBackupRestoreContext(t, client, "../target/archive", `{"adapter":"s3","download_url":"https://example.com/archive.tar.gz"}`) + defer s.CtxCancel() + + postServerRestoreBackup(c) + + if c.Writer.Status() != http.StatusBadRequest { + t.Fatalf("expected non-UUID backup id to be rejected, got status %d body %s", c.Writer.Status(), w.Body.String()) + } +} + +func TestPostServerRestoreBackupRejectsBadDownloadStatus(t *testing.T) { + setBackupRestoreAllowlist(t, []string{"127.0.0.1"}) + + remote := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-gzip") + http.Error(w, "missing", http.StatusNotFound) + })) + defer remote.Close() + + client := backupTestRemoteClient{restoreStatus: make(chan string, 1)} + backupID := "11111111-1111-1111-1111-111111111111" + c, w, s := newBackupRestoreContext(t, client, backupID, fmt.Sprintf(`{"adapter":"s3","download_url":%q}`, remote.URL)) + defer s.CtxCancel() + + postServerRestoreBackup(c) + + if c.Writer.Status() != http.StatusBadRequest { + t.Fatalf("expected restore request to be rejected, got status %d body %s", c.Writer.Status(), w.Body.String()) + } +} + +func TestBackupRestoreContentTypeValidation(t *testing.T) { + tests := map[string]bool{ + "application/gzip": true, + "application/gzip; charset=binary": true, + "application/x-gzip": true, + "application/x-gzip; charset=binary": true, + "application": false, + "gzip": false, + "text/plain": false, + "": false, + } + + for contentType, expected := range tests { + if got := isSupportedBackupRestoreContentType(contentType); got != expected { + t.Fatalf("expected content type %q support to be %v, got %v", contentType, expected, got) + } + } +} + +func TestParseBackupUuid(t *testing.T) { + tests := map[string]struct { + expected string + valid bool + }{ + "11111111-1111-1111-1111-111111111111": {expected: "11111111-1111-1111-1111-111111111111", valid: true}, + "11111111-1111-1111-1111-AAAAAAAAAAAA": {expected: "11111111-1111-1111-1111-aaaaaaaaaaaa", valid: true}, + "11111111111111111111111111111111": {valid: false}, + "../target/archive": {valid: false}, + } + + for value, test := range tests { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + got, ok := parseBackupUuid(c, value) + if ok != test.valid { + t.Fatalf("expected validity for %q to be %v, got %v", value, test.valid, ok) + } + if got != test.expected { + t.Fatalf("expected normalized backup UUID %q, got %q", test.expected, got) + } + } +} + +func TestBackupRestoreBlockedIPValidation(t *testing.T) { + setBackupRestoreAllowlist(t, nil) + + tests := map[string]bool{ + "127.0.0.1": true, + "10.0.0.1": true, + "169.254.1.1": true, + "100.64.0.1": true, + "198.18.0.1": true, + "::1": true, + "fe80::1": true, + "8.8.8.8": false, + "2606:4700::11": false, + } + + for raw, expected := range tests { + if got := isBlockedBackupRestoreIP("", net.ParseIP(raw)); got != expected { + t.Fatalf("expected blocked state for %q to be %v, got %v", raw, expected, got) + } + } +} + +func TestBackupRestoreDestinationAllowlist(t *testing.T) { + setBackupRestoreAllowlist(t, []string{ + "minio.internal", + "10.0.0.10", + "192.168.50.0/24", + }) + + tests := []struct { + name string + host string + ip string + blocked bool + }{ + {name: "hostname", host: "minio.internal", ip: "10.0.0.20", blocked: false}, + {name: "ip", host: "10.0.0.10", ip: "10.0.0.10", blocked: false}, + {name: "cidr", host: "192.168.50.10", ip: "192.168.50.10", blocked: false}, + {name: "not listed", host: "10.0.0.11", ip: "10.0.0.11", blocked: true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := isBlockedBackupRestoreIP(test.host, net.ParseIP(test.ip)); got != test.blocked { + t.Fatalf("expected blocked state for %q/%q to be %v, got %v", test.host, test.ip, test.blocked, got) + } + }) + } +} + +func TestBackupRestoreHTTPClientDoesNotLimitResponseBodyRead(t *testing.T) { + client := backupRestoreHttpClient() + if client.Timeout != 0 { + t.Fatalf("expected restore client not to set total request timeout, got %s", client.Timeout) + } +} + +func setBackupRestoreAllowlist(t *testing.T, entries []string) { + t.Helper() + + previous := config.Get() + t.Cleanup(func() { + config.Set(previous) + }) + + next := *previous + next.System.Backups.RestoreHostAllowlist = entries + config.Set(&next) +} diff --git a/server/backup.go b/server/backup.go index 6dde92e99..6cd1feeae 100644 --- a/server/backup.go +++ b/server/backup.go @@ -595,7 +595,7 @@ func (s *Server) RestoreBackupWithContext(ctx context.Context, b backup.BackupIn var estimatedTotal int64 if downloadSize > 0 { // For S3: Use actual download size with conservative multiplier for extraction - // Downloaded archives typically expand 3-4x when extracted (gzip/zstd compression) + // Downloaded archives typically expand 3-4x when extracted (gzip compression) // Using 3.2x gives good results without overshooting too much estimatedTotal = int64(float64(downloadSize) * 3.2) s.Log().WithField("download_size", downloadSize).WithField("estimated_restore_size", estimatedTotal).Info("set restore progress total from download size") @@ -802,15 +802,11 @@ func (s *Server) validateBackupIntegrity(b backup.BackupInterface) error { return errors.New("cannot read backup file header") } - // Check for GZIP magic bytes (0x1f, 0x8b) or ZSTD magic bytes (0x28, 0xb5) + // Check for GZIP magic bytes (0x1f, 0x8b) if magic[0] == 0x1f && magic[1] == 0x8b { // Valid GZIP return nil } - if magic[0] == 0x28 && magic[1] == 0xb5 { - // Valid ZSTD - return nil - } // Check for uncompressed TAR (less common but possible) if _, err := f.Seek(0, 0); err != nil { diff --git a/server/backup/backup.go b/server/backup/backup.go index d2ed0f1d8..8eaf7bc7a 100644 --- a/server/backup/backup.go +++ b/server/backup/backup.go @@ -8,9 +8,11 @@ import ( "io/fs" "os" "path" + "strings" "emperror.dev/errors" "github.com/apex/log" + "github.com/google/uuid" "github.com/mholt/archives" "golang.org/x/sync/errgroup" @@ -91,9 +93,30 @@ func (b *Backup) Identifier() string { return b.Uuid } +func (b *Backup) normalizedIdentifier() (string, error) { + parsed, err := uuid.Parse(b.Identifier()) + if err != nil || len(b.Identifier()) != len(parsed.String()) || parsed.String() != strings.ToLower(b.Identifier()) { + return "", errors.New("backup: identifier must be a valid UUID") + } + return parsed.String(), nil +} + +func (b *Backup) validateIdentifier() error { + identifier, err := b.normalizedIdentifier() + if err != nil { + return err + } + b.Uuid = identifier + return nil +} + // Path returns the path for this specific backup. func (b *Backup) Path() string { - return path.Join(config.Get().System.BackupDirectory, b.Identifier()+".tar.gz") + identifier, err := b.normalizedIdentifier() + if err != nil { + identifier = path.Base(b.Identifier()) + } + return path.Join(config.Get().System.BackupDirectory, identifier+".tar.gz") } // PathForLocalBackup returns the path for a LocalBackup, checking for foundPath override @@ -106,6 +129,9 @@ func (b *Backup) PathForLocalBackup(foundPath string) string { // Size returns the size of the generated backup. func (b *Backup) Size() (int64, error) { + if err := b.validateIdentifier(); err != nil { + return 0, err + } st, err := os.Stat(b.Path()) if err != nil { return 0, err @@ -116,6 +142,9 @@ func (b *Backup) Size() (int64, error) { // Checksum returns the SHA256 checksum of a backup. func (b *Backup) Checksum() ([]byte, error) { + if err := b.validateIdentifier(); err != nil { + return nil, err + } h := sha256.New() f, err := os.Open(b.Path()) diff --git a/server/backup/backup_local.go b/server/backup/backup_local.go index e76058f03..975a0ccde 100644 --- a/server/backup/backup_local.go +++ b/server/backup/backup_local.go @@ -43,7 +43,10 @@ func NewLocal(client remote.Client, uuid string, ignore string) *LocalBackup { // ENHANCED: Now supports finding backups with different extensions (backward compatibility) func LocateLocal(client remote.Client, uuid string) (*LocalBackup, os.FileInfo, error) { b := NewLocal(client, uuid, "") - + if err := b.validateIdentifier(); err != nil { + return nil, nil, err + } + // Try current config format first (new behavior) st, err := os.Stat(b.Path()) if err == nil { @@ -56,7 +59,7 @@ func LocateLocal(client remote.Client, uuid string) (*LocalBackup, os.FileInfo, // BACKWARD COMPATIBILITY: Try other formats if current format not found if os.IsNotExist(err) { // Try all possible extensions for backward compatibility - possibleExtensions := []string{".tar.gz", ".tar.zst", ".tar"} + possibleExtensions := []string{".tar.gz", ".tar"} baseDir := config.Get().System.BackupDirectory for _, ext := range possibleExtensions { @@ -88,6 +91,9 @@ func (b *LocalBackup) Path() string { // Remove removes a backup from the system. func (b *LocalBackup) Remove() error { + if err := b.validateIdentifier(); err != nil { + return err + } return os.Remove(b.Path()) } @@ -99,6 +105,9 @@ func (b *LocalBackup) WithLogContext(c map[string]interface{}) { // Generate generates a backup of the selected files and pushes it to the // defined location for this instance. func (b *LocalBackup) Generate(ctx context.Context, fsys *filesystem.Filesystem, ignore string) (*ArchiveDetails, error) { + if err := b.validateIdentifier(); err != nil { + return nil, err + } a := &filesystem.Archive{ Filesystem: fsys, Ignore: ignore, @@ -120,6 +129,9 @@ func (b *LocalBackup) Generate(ctx context.Context, fsys *filesystem.Filesystem, // Restore will walk over the archive and call the callback function for each // file encountered. func (b *LocalBackup) Restore(ctx context.Context, _ io.Reader, callback RestoreCallback) error { + if err := b.validateIdentifier(); err != nil { + return err + } f, err := os.Open(b.Path()) if err != nil { return err @@ -260,7 +272,7 @@ func CleanupBackupFilesForServer(serverID string) error { func isBackupFile(filename string) bool { // Common backup file extensions backupExtensions := []string{ - ".tar.gz", ".tar.zst", ".tar", ".gz", ".zst", + ".tar.gz", ".tar", ".gz", } lowerName := strings.ToLower(filename) diff --git a/server/backup/backup_s3.go b/server/backup/backup_s3.go index 3b26d0be8..43e3f7a9a 100644 --- a/server/backup/backup_s3.go +++ b/server/backup/backup_s3.go @@ -66,6 +66,9 @@ func (s *S3Backup) WithUploadCallback(callback func()) *S3Backup { // Remove removes a backup from the system. func (s *S3Backup) Remove() error { + if err := s.validateIdentifier(); err != nil { + return err + } return os.Remove(s.Path()) } @@ -77,9 +80,13 @@ func (s *S3Backup) WithLogContext(c map[string]interface{}) { // Generate creates a new backup on the disk, moves it into the S3 bucket via // the provided presigned URL, and then deletes the backup from the disk. func (s *S3Backup) Generate(ctx context.Context, fsys *filesystem.Filesystem, ignore string) (*ArchiveDetails, error) { + if err := s.validateIdentifier(); err != nil { + return nil, err + } + var uploadedParts []remote.BackupPart success := false - + defer func() { if success { s.Remove() // Only remove on successful upload @@ -157,7 +164,7 @@ func (s *S3Backup) Restore(ctx context.Context, r io.Reader, callback RestoreCal // CRITICAL: The reader provided here is ALREADY DECOMPRESSED by the server layer! // The server's RestoreBackupWithContext method handles: - // 1. Format detection (gzip, zstd, etc.) + // 1. Format detection (gzip, etc.) // 2. Decompression // 3. Passing us the clean TAR stream // diff --git a/server/backup/backup_test.go b/server/backup/backup_test.go new file mode 100644 index 000000000..f94944a5f --- /dev/null +++ b/server/backup/backup_test.go @@ -0,0 +1,103 @@ +package backup + +import ( + "bytes" + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/Rene-Roscher/wings/config" + "github.com/Rene-Roscher/wings/server/filesystem" +) + +func TestBackupGenerateRequiresUuidIdentifier(t *testing.T) { + tests := map[string]func(string) BackupInterface{ + "local": func(identifier string) BackupInterface { + return NewLocal(nil, identifier, "") + }, + "s3": func(identifier string) BackupInterface { + return NewS3(nil, identifier, "") + }, + } + + for name, createBackup := range tests { + t.Run(name, func(t *testing.T) { + testBackupGenerateRequiresUuidIdentifier(t, createBackup) + }) + } +} + +func TestBackupPathUsesBackupDirectory(t *testing.T) { + backupDir := t.TempDir() + config.Set(&config.Configuration{ + AuthenticationToken: "test-token", + System: config.SystemConfiguration{ + BackupDirectory: backupDir, + }, + }) + + for _, identifier := range []string{ + "11111111-1111-1111-1111-111111111111", + "../target/archive", + "nested/archive", + } { + b := NewLocal(nil, identifier, "") + rel, err := filepath.Rel(backupDir, b.Path()) + if err != nil { + t.Fatal(err) + } + if filepath.IsAbs(rel) || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + t.Fatalf("expected backup path %q to remain under %q", b.Path(), backupDir) + } + } +} + +func testBackupGenerateRequiresUuidIdentifier(t *testing.T, createBackup func(string) BackupInterface) { + t.Helper() + + root := t.TempDir() + backupDir := filepath.Join(root, "backups") + targetDir := filepath.Join(root, "target") + serverDir := filepath.Join(root, "server") + for _, dir := range []string{backupDir, targetDir, serverDir} { + if err := os.MkdirAll(dir, 0o700); err != nil { + t.Fatal(err) + } + } + config.Set(&config.Configuration{ + AuthenticationToken: "test-token", + System: config.SystemConfiguration{ + BackupDirectory: backupDir, + }, + }) + + if err := os.WriteFile(filepath.Join(serverDir, "file.txt"), []byte("server data"), 0o600); err != nil { + t.Fatal(err) + } + fsys, err := filesystem.New(serverDir, 0, nil) + if err != nil { + t.Fatal(err) + } + + existingArchive := filepath.Join(targetDir, "archive.tar.gz") + existingArchiveContents := []byte("existing archive") + if err := os.WriteFile(existingArchive, existingArchiveContents, 0o600); err != nil { + t.Fatal(err) + } + + b := createBackup("../target/archive") + if _, err := b.Generate(context.Background(), fsys, ""); err == nil { + t.Fatal("expected invalid backup identifier to be rejected") + } + + got, err := os.ReadFile(existingArchive) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(got, existingArchiveContents) { + return + } + t.Fatal("expected backup generation not to overwrite existing archive") +} diff --git a/server/backup/compression.go b/server/backup/compression.go index ad63e0346..a7f402aa4 100644 --- a/server/backup/compression.go +++ b/server/backup/compression.go @@ -9,7 +9,6 @@ type CompressionFormat string const ( CompressionGzip CompressionFormat = "gzip" - CompressionZstd CompressionFormat = "zstd" CompressionTar CompressionFormat = "tar" CompressionNone CompressionFormat = "none" ) @@ -51,21 +50,6 @@ func (g *gzipAdapter) ContentTypes() []string { func (g *gzipAdapter) IsSupported() bool { return true } func (g *gzipAdapter) Description() string { return "GZIP compression" } -// zstdAdapter implements CompressionAdapter for ZSTD format -type zstdAdapter struct{} - -func (z *zstdAdapter) Format() CompressionFormat { return CompressionZstd } -func (z *zstdAdapter) Extension() string { return ".zst" } -func (z *zstdAdapter) ContentTypes() []string { - return []string{ - "application/x-zstd", - "application/zstd", - "application/x-zstandard", - } -} -func (z *zstdAdapter) IsSupported() bool { return true } -func (z *zstdAdapter) Description() string { return "ZSTD compression (high performance)" } - // tarAdapter implements CompressionAdapter for TAR format type tarAdapter struct{} @@ -107,7 +91,6 @@ func NewCompressionRegistry() *CompressionRegistry { // Register default compression formats registry.Register(&gzipAdapter{}) - registry.Register(&zstdAdapter{}) registry.Register(&tarAdapter{}) registry.Register(&noneAdapter{}) diff --git a/server/filesystem/archive_restore.go b/server/filesystem/archive_restore.go index 996692a4f..89cb5e436 100644 --- a/server/filesystem/archive_restore.go +++ b/server/filesystem/archive_restore.go @@ -14,7 +14,6 @@ type CompressionFormat int const ( CompressionUnknown CompressionFormat = iota CompressionGzip - CompressionZstd // Kept for backward compatibility but no longer supported CompressionNone ) @@ -37,9 +36,6 @@ func DetectCompressionFormat(reader io.ReadCloser) (CompressionFormat, io.ReadCl return CompressionGzip, io.NopCloser(peekReader), errors.New("backup: insufficient data for format detection") } - // ZSTD is no longer supported - skip detection - // (Previously checked for 0x28B52FFD magic bytes) - // GZIP magic: 0x1F8B (validate both bytes for security) if len(header) >= 2 && header[0] == 0x1F && header[1] == 0x8B { return CompressionGzip, io.NopCloser(peekReader), nil @@ -57,11 +53,6 @@ func CreateDecompressor(reader io.ReadCloser, format CompressionFormat) (io.Read } switch format { - case CompressionZstd: - // ZSTD is no longer supported - reader.Close() - return nil, errors.New("backup: ZSTD compression is no longer supported") - case CompressionGzip: gzReader, err := gzip.NewReader(reader) if err != nil { diff --git a/server/filesystem/archive_stream_test.go b/server/filesystem/archive_stream_test.go new file mode 100644 index 000000000..26e3fe964 --- /dev/null +++ b/server/filesystem/archive_stream_test.go @@ -0,0 +1,122 @@ +package filesystem + +import ( + "context" + iofs "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "testing" + + . "github.com/franela/goblin" + "github.com/mholt/archives" +) + +func TestArchive_Stream(t *testing.T) { + g := Goblin(t) + fs, rfs := NewFs() + + g.Describe("Archive", func() { + g.AfterEach(func() { + // Reset the filesystem after each run. + _ = fs.TruncateRootDirectory() + }) + + g.It("creates archive with intended files", func() { + g.Assert(fs.CreateDirectory("test", "/")).IsNil() + g.Assert(fs.CreateDirectory("test2", "/")).IsNil() + + r := strings.NewReader("hello, world!\n") + err := fs.Write("test/file.txt", r, r.Size(), 0o644) + g.Assert(err).IsNil() + + r = strings.NewReader("hello, world!\n") + err = fs.Write("test2/file.txt", r, r.Size(), 0o644) + g.Assert(err).IsNil() + + r = strings.NewReader("hello, world!\n") + err = fs.Write("test_file.txt", r, r.Size(), 0o644) + g.Assert(err).IsNil() + + r = strings.NewReader("hello, world!\n") + err = fs.Write("test_file.txt.old", r, r.Size(), 0o644) + g.Assert(err).IsNil() + + a := &Archive{ + Filesystem: fs, + Files: []string{ + "test", + "test_file.txt", + }, + } + + // Create the archive. + archivePath := filepath.Join(rfs.root, "archive.tar.gz") + g.Assert(a.Create(context.Background(), archivePath)).IsNil() + + // Ensure the archive exists. + _, err = os.Stat(archivePath) + g.Assert(err).IsNil() + + // Open the archive. + genericFs, err := archives.FileSystem(context.Background(), archivePath, nil) + g.Assert(err).IsNil() + + // Assert that we are opening an archive. + afs, ok := genericFs.(iofs.ReadDirFS) + g.Assert(ok).IsTrue() + + // Get the names of the files recursively from the archive. + files, err := getFiles(afs, ".") + g.Assert(err).IsNil() + + // Ensure the files in the archive match what we are expecting. + expected := []string{ + "test_file.txt", + "test/file.txt", + } + + // Sort the slices to ensure the comparison never fails if the + // contents are sorted differently. + sort.Strings(expected) + sort.Strings(files) + + g.Assert(files).Equal(expected) + }) + }) +} + +func getFiles(f iofs.ReadDirFS, name string) ([]string, error) { + var v []string + + entries, err := f.ReadDir(name) + if err != nil { + return nil, err + } + + for _, e := range entries { + entryName := e.Name() + if name != "." { + entryName = filepath.Join(name, entryName) + } + + if e.IsDir() { + files, err := getFiles(f, entryName) + if err != nil { + return nil, err + } + + if files == nil { + return nil, nil + } + + v = append(v, files...) + continue + } + + v = append(v, entryName) + } + + return v, nil +} diff --git a/server/filesystem/archive_system.go b/server/filesystem/archive_system.go deleted file mode 100644 index a6add2fba..000000000 --- a/server/filesystem/archive_system.go +++ /dev/null @@ -1,215 +0,0 @@ -package filesystem - -import ( - "context" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" - - "emperror.dev/errors" - "github.com/apex/log" - "github.com/Rene-Roscher/wings/config" -) - -// CreateArchiveUsingSystemTar creates a backup using the system tar command -// This bypasses all Go library issues and uses the proven system tools -func (a *Archive) CreateArchiveUsingSystemTar(ctx context.Context, dst string) error { - if a.Filesystem == nil { - return errors.New("filesystem: archive.Filesystem is unset") - } - - // Determine compression flag based on file extension - var compressionFlag string - switch { - case strings.HasSuffix(dst, ".tar.zst"): - compressionFlag = "--zstd" - case strings.HasSuffix(dst, ".tar.gz"): - compressionFlag = "-z" - default: - compressionFlag = "" // No compression - } - - // Build tar command - args := []string{ - "-c", // Create archive - "-f", dst, // Output file - } - - if compressionFlag != "" { - args = append(args, compressionFlag) - } - - // Add base directory - args = append(args, "-C", a.Filesystem.Path()) - - // If specific files are provided, add them - if len(a.Files) > 0 { - for _, file := range a.Files { - // Strip leading slash and filesystem path - cleanFile := strings.TrimPrefix(file, a.Filesystem.Path()) - cleanFile = strings.TrimPrefix(cleanFile, "/") - if cleanFile != "" { - args = append(args, cleanFile) - } - } - } else { - // Archive everything in the base directory - if a.BaseDirectory != "" { - args = append(args, a.BaseDirectory) - } else { - args = append(args, ".") - } - } - - // Create tar command - cmd := exec.CommandContext(ctx, "tar", args...) - cmd.Dir = a.Filesystem.Path() - - // Set up ignore file if provided - if a.Ignore != "" { - // Write ignore patterns to exclude file - excludeFile := filepath.Join("/tmp", fmt.Sprintf("backup-exclude-%d", os.Getpid())) - if err := os.WriteFile(excludeFile, []byte(a.Ignore), 0600); err != nil { - return errors.Wrap(err, "failed to write exclude file") - } - defer os.Remove(excludeFile) - - // Add exclude flag - cmd.Args = append(cmd.Args[:2], append([]string{"--exclude-from=" + excludeFile}, cmd.Args[2:]...)...) - } - - log.WithField("command", cmd.String()).Debug("executing system tar command") - - // Execute the command - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "tar command failed: %s", string(output)) - } - - return nil -} - -// ExtractArchiveUsingSystemTar extracts a backup using the system tar command -func ExtractArchiveUsingSystemTar(ctx context.Context, src string, dst string) error { - // Determine decompression flag based on file extension - var decompressionFlag string - switch { - case strings.HasSuffix(src, ".tar.zst"): - decompressionFlag = "--zstd" - case strings.HasSuffix(src, ".tar.gz"): - decompressionFlag = "-z" - case strings.HasSuffix(src, ".tar.xz"): - decompressionFlag = "-J" - case strings.HasSuffix(src, ".tar.bz2"): - decompressionFlag = "-j" - default: - decompressionFlag = "" // No decompression - } - - // Build tar command - args := []string{ - "-x", // Extract archive - "-f", src, // Input file - "-C", dst, // Extract to directory - "--preserve-permissions", // CRITICAL: Preserve file permissions! - "--preserve", // Preserve all attributes - } - - if decompressionFlag != "" { - args = append(args, decompressionFlag) - } - - // Create tar command - cmd := exec.CommandContext(ctx, "tar", args...) - - log.WithField("command", cmd.String()).Debug("executing system tar extract command") - - // Execute the command - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "tar extract failed: %s", string(output)) - } - - return nil -} - -// StreamArchiveUsingSystemTar streams archive creation using system tar -func (a *Archive) StreamArchiveUsingSystemTar(ctx context.Context, w io.Writer) error { - if a.Filesystem == nil { - return errors.New("filesystem: archive.Filesystem is unset") - } - - // Use zstd command for compression if needed - var cmd *exec.Cmd - - // Build tar command (without compression) - tarArgs := []string{ - "-c", // Create archive - "-f", "-", // Output to stdout - "-C", a.Filesystem.Path(), - } - - // Add files or directory - if len(a.Files) > 0 { - for _, file := range a.Files { - cleanFile := strings.TrimPrefix(file, a.Filesystem.Path()) - cleanFile = strings.TrimPrefix(cleanFile, "/") - if cleanFile != "" { - tarArgs = append(tarArgs, cleanFile) - } - } - } else if a.BaseDirectory != "" { - tarArgs = append(tarArgs, a.BaseDirectory) - } else { - tarArgs = append(tarArgs, ".") - } - - // Check if we need compression - if config.Get().System.Backups.Format == "zstd" { - // Pipe tar through zstd - tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) - tarCmd.Dir = a.Filesystem.Path() - - zstdCmd := exec.CommandContext(ctx, "zstd", "-c", "-T0") // -T0 uses all CPU cores - - // Create pipe - pipe, err := tarCmd.StdoutPipe() - if err != nil { - return errors.Wrap(err, "failed to create pipe") - } - - zstdCmd.Stdin = pipe - zstdCmd.Stdout = w - - // Start both commands - if err := tarCmd.Start(); err != nil { - return errors.Wrap(err, "failed to start tar") - } - if err := zstdCmd.Start(); err != nil { - return errors.Wrap(err, "failed to start zstd") - } - - // Wait for both to complete - if err := tarCmd.Wait(); err != nil { - return errors.Wrap(err, "tar command failed") - } - if err := zstdCmd.Wait(); err != nil { - return errors.Wrap(err, "zstd command failed") - } - } else { - // Just tar with gzip - tarArgs[1] = "-czf" // Add gzip compression - cmd = exec.CommandContext(ctx, "tar", tarArgs...) - cmd.Dir = a.Filesystem.Path() - cmd.Stdout = w - - if err := cmd.Run(); err != nil { - return errors.Wrap(err, "tar command failed") - } - } - - return nil -} \ No newline at end of file diff --git a/server/filesystem/archive_test.go b/server/filesystem/archive_test.go index c348894d7..389dd7945 100644 --- a/server/filesystem/archive_test.go +++ b/server/filesystem/archive_test.go @@ -5,8 +5,6 @@ import ( "compress/gzip" "io" "testing" - - "github.com/klauspost/compress/zstd" ) func TestDetectCompressionFormat(t *testing.T) { @@ -20,11 +18,6 @@ func TestDetectCompressionFormat(t *testing.T) { data: []byte{0x1F, 0x8B, 0x08, 0x00}, // GZIP magic expectedFormat: CompressionGzip, }, - { - name: "ZSTD format (no longer supported, falls back to GZIP)", - data: []byte{0x28, 0xB5, 0x2F, 0xFD}, // ZSTD magic - expectedFormat: CompressionGzip, // Falls back to GZIP since ZSTD is not supported - }, { name: "Unknown format defaults to GZIP", data: []byte{0x00, 0x00, 0x00, 0x00}, @@ -78,36 +71,4 @@ func TestCreateDecompressor(t *testing.T) { t.Errorf("GZIP decompression failed: got %s, want 'test data'", string(data)) } }) - - // Test ZSTD decompressor (should fail as ZSTD is no longer supported) - t.Run("ZSTD decompressor", func(t *testing.T) { - var buf bytes.Buffer - zw, err := zstd.NewWriter(&buf) - if err != nil { - t.Fatal(err) - } - _, err = zw.Write([]byte("test data")) - if err != nil { - t.Fatal(err) - } - zw.Close() - - reader := io.NopCloser(bytes.NewReader(buf.Bytes())) - decompressor, err := CreateDecompressor(reader, CompressionZstd) - - // ZSTD is no longer supported, should return an error - if err == nil { - if decompressor != nil { - decompressor.Close() - } - t.Error("CreateDecompressor() should return error for ZSTD format (no longer supported)") - return - } - - // Verify the error message contains expected text - expectedErrMsg := "ZSTD compression is no longer supported" - if !bytes.Contains([]byte(err.Error()), []byte(expectedErrMsg)) { - t.Errorf("CreateDecompressor() error = %v, should contain %q", err, expectedErrMsg) - } - }) } diff --git a/server/filesystem/compress.go b/server/filesystem/compress.go index 95e88c3a1..485606eb2 100644 --- a/server/filesystem/compress.go +++ b/server/filesystem/compress.go @@ -5,6 +5,7 @@ import ( "fmt" "io" iofs "io/fs" + "math" "path" "path/filepath" "strings" @@ -43,7 +44,7 @@ func (fs *Filesystem) CompressFiles(dir string, paths []string) (ufs.FileInfo, e if err := a.Stream(context.Background(), cw); err != nil { return nil, err } - if !fs.unixFS.CanFit(cw.BytesWritten()) { + if cw.BytesWritten() < 0 || !fs.unixFS.CanFit(cw.BytesWritten()) { _ = fs.unixFS.Remove(d) return nil, newFilesystemError(ErrCodeDiskSpace, nil) } @@ -132,9 +133,19 @@ func (fs *Filesystem) SpaceAvailableForDecompression(ctx context.Context, dir st if err != nil { return err } - if !fs.unixFS.CanFit(size.Add(info.Size())) { + fileSize := info.Size() + if fileSize <= 0 { + return nil + } + current := size.Load() + if fileSize > math.MaxInt64-current { + return newFilesystemError(ErrCodeDiskSpace, nil) + } + next := current + fileSize + if !fs.unixFS.CanFit(next) { return newFilesystemError(ErrCodeDiskSpace, nil) } + size.Store(next) return nil } }) diff --git a/server/filesystem/disk_space.go b/server/filesystem/disk_space.go index fc2bceff6..9d2003b00 100644 --- a/server/filesystem/disk_space.go +++ b/server/filesystem/disk_space.go @@ -208,6 +208,32 @@ func (fs *Filesystem) HasSpaceFor(size int64) error { return nil } +func (fs *Filesystem) reserveDisk(size int64) error { + if size <= 0 { + return nil + } + + fs.mu.Lock() + defer fs.mu.Unlock() + + if err := fs.HasSpaceFor(size); err != nil { + return err + } + fs.unixFS.Add(size) + return nil +} + +func (fs *Filesystem) adjustDisk(size int64) int64 { + if size == 0 { + return fs.CachedUsage() + } + + fs.mu.Lock() + defer fs.mu.Unlock() + + return fs.unixFS.Add(size) +} + // Updates the disk usage for the Filesystem instance. func (fs *Filesystem) addDisk(i int64) int64 { return fs.unixFS.Add(i) diff --git a/server/filesystem/filesystem.go b/server/filesystem/filesystem.go index 24e324718..06b4486e7 100644 --- a/server/filesystem/filesystem.go +++ b/server/filesystem/filesystem.go @@ -93,7 +93,19 @@ func (fs *Filesystem) UnixFS() *ufs.UnixFS { // already. If it is present, the file is opened using the defaults which will truncate // the contents. The opened file is then returned to the caller. func (fs *Filesystem) Touch(p string, flag int) (ufs.File, error) { - return fs.unixFS.Touch(p, flag, 0o644) + var currentSize int64 + st, err := fs.unixFS.Stat(p) + if err != nil && !errors.Is(err, ufs.ErrNotExist) { + return nil, err + } else if err == nil && !st.IsDir() { + currentSize = st.Size() + } + + file, err := fs.unixFS.Touch(p, flag, 0o644) + if err != nil { + return nil, err + } + return newQuotaFile(fs, file, currentSize), nil } // Writefile writes a file to the system. If the file does not already exist one diff --git a/server/filesystem/filesystem_test.go b/server/filesystem/filesystem_test.go index 9c90baeb3..afec934da 100644 --- a/server/filesystem/filesystem_test.go +++ b/server/filesystem/filesystem_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "errors" + "math" "math/rand" "os" "path/filepath" @@ -111,10 +112,93 @@ func TestFilesystem_Openfile(t *testing.T) { }) } -func TestFilesystem_Writefile(t *testing.T) { +func TestFilesystem_Touch(t *testing.T) { g := Goblin(t) fs, _ := NewFs() + g.Describe("Touch", func() { + g.It("enforces disk limits while writing to the returned handle", func() { + fs.SetDiskLimit(10) + + f, err := fs.Touch("quota.txt", ufs.O_RDWR|ufs.O_TRUNC) + g.Assert(err).IsNil() + defer f.Close() + + n, err := f.WriteAt([]byte("1234567890"), 0) + g.Assert(err).IsNil() + g.Assert(n).Equal(10) + g.Assert(fs.CachedUsage()).Equal(int64(10)) + + n, err = f.WriteAt([]byte("1"), 10) + g.Assert(err).IsNotNil() + g.Assert(n).Equal(0) + g.Assert(IsErrorCode(err, ErrCodeDiskSpace)).IsTrue() + g.Assert(fs.CachedUsage()).Equal(int64(10)) + }) + + g.It("enforces disk limits while sequentially writing to the returned handle", func() { + fs.SetDiskLimit(10) + + f, err := fs.Touch("quota.txt", ufs.O_RDWR|ufs.O_TRUNC) + g.Assert(err).IsNil() + defer f.Close() + + n, err := f.Write([]byte("1234567890")) + g.Assert(err).IsNil() + g.Assert(n).Equal(10) + + n, err = f.Write([]byte("1")) + g.Assert(err).IsNotNil() + g.Assert(n).Equal(0) + g.Assert(IsErrorCode(err, ErrCodeDiskSpace)).IsTrue() + g.Assert(fs.CachedUsage()).Equal(int64(10)) + }) + + g.It("updates disk usage when a truncated file is closed smaller", func() { + r := bytes.NewReader([]byte("1234567890")) + err := fs.Write("quota.txt", r, r.Size(), 0o644) + g.Assert(err).IsNil() + g.Assert(fs.CachedUsage()).Equal(int64(10)) + + f, err := fs.Touch("quota.txt", ufs.O_RDWR|ufs.O_TRUNC) + g.Assert(err).IsNil() + + n, err := f.WriteAt([]byte("1234"), 0) + g.Assert(err).IsNil() + g.Assert(n).Equal(4) + + err = f.Close() + g.Assert(err).IsNil() + g.Assert(fs.CachedUsage()).Equal(int64(4)) + }) + + g.It("does not reset disk usage after a failed huge-offset write", func() { + const usage = int64(5 * 1024 * 1024) + fs.SetDiskLimit(10 * 1024 * 1024) + fs.unixFS.SetUsage(usage) + + f, err := fs.Touch("quota.txt", ufs.O_RDWR|ufs.O_TRUNC) + g.Assert(err).IsNil() + + n, err := f.WriteAt([]byte("x"), math.MaxInt64) + g.Assert(err).IsNotNil() + g.Assert(n).Equal(0) + + err = f.Close() + g.Assert(err).IsNil() + g.Assert(fs.CachedUsage()).Equal(usage) + }) + + g.AfterEach(func() { + _ = fs.TruncateRootDirectory() + }) + }) +} + +func TestFilesystem_Writefile(t *testing.T) { + g := Goblin(t) + fs, rfs := NewFs() + g.Describe("Open and WriteFile", func() { buf := &bytes.Buffer{} @@ -181,6 +265,19 @@ func TestFilesystem_Writefile(t *testing.T) { g.Assert(IsErrorCode(err, ErrCodeDiskSpace)).IsTrue() }) + g.It("cannot write a file whose claimed size overflows the quota check", func() { + fs.SetDiskLimit(1024) + fs.unixFS.SetUsage(1) + + r := bytes.NewReader([]byte("small body")) + err := fs.Write("overflow.txt", r, math.MaxInt64, 0o644) + g.Assert(err).IsNotNil() + g.Assert(IsErrorCode(err, ErrCodeDiskSpace)).IsTrue() + + _, err = rfs.StatServerFile("overflow.txt") + g.Assert(errors.Is(err, os.ErrNotExist)).IsTrue("err is not os.ErrNotExist") + }) + g.It("truncates the file when writing new contents", func() { r := bytes.NewReader([]byte("original data")) err := fs.Write("test.txt", r, r.Size(), 0o644) diff --git a/server/filesystem/quota_file.go b/server/filesystem/quota_file.go new file mode 100644 index 000000000..ce924de1b --- /dev/null +++ b/server/filesystem/quota_file.go @@ -0,0 +1,120 @@ +package filesystem + +import ( + "io" + "math" + "sync" + + "github.com/Rene-Roscher/wings/internal/ufs" +) + +type quotaFile struct { + ufs.File + + fs *Filesystem + mu sync.Mutex + size int64 +} + +func newQuotaFile(fs *Filesystem, file ufs.File, size int64) ufs.File { + return "aFile{File: file, fs: fs, size: size} +} + +func (f *quotaFile) Write(p []byte) (int, error) { + if len(p) == 0 { + return f.File.Write(p) + } + + f.mu.Lock() + defer f.mu.Unlock() + + off, err := f.File.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + + return f.writeAtLocked(p, off, func() (int, error) { + return f.File.Write(p) + }) +} + +func (f *quotaFile) WriteAt(p []byte, off int64) (int, error) { + if off < 0 || len(p) == 0 { + return f.File.WriteAt(p, off) + } + + f.mu.Lock() + defer f.mu.Unlock() + + return f.writeAtLocked(p, off, func() (int, error) { + return f.File.WriteAt(p, off) + }) +} + +func (f *quotaFile) writeAtLocked(p []byte, off int64, write func() (int, error)) (int, error) { + previousSize := f.size + end, ok := quotaWriteEnd(off, len(p)) + if !ok { + return 0, newFilesystemError(ErrCodeDiskSpace, nil) + } + if growth := end - previousSize; growth > 0 { + if err := f.fs.reserveDisk(growth); err != nil { + return 0, err + } + } + + n, err := write() + writtenEnd := previousSize + if n > 0 { + writtenEnd, _ = quotaWriteEnd(off, n) + if writtenEnd > previousSize { + f.size = writtenEnd + } + } + + if reserved := end - previousSize; reserved > 0 { + actual := int64(0) + if writtenEnd > previousSize { + actual = writtenEnd - previousSize + } + if actual < reserved { + f.fs.adjustDisk(actual - reserved) + } + } + + return n, err +} + +func quotaWriteEnd(off int64, size int) (int64, bool) { + if size < 0 || off > math.MaxInt64-int64(size) { + return 0, false + } + return off + int64(size), true +} + +func (f *quotaFile) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(quotaFileWriter{file: f}, r) +} + +func (f *quotaFile) Close() error { + f.mu.Lock() + defer f.mu.Unlock() + + st, statErr := f.File.Stat() + closeErr := f.File.Close() + if statErr == nil { + f.fs.adjustDisk(st.Size() - f.size) + } + if statErr != nil { + return statErr + } + return closeErr +} + +type quotaFileWriter struct { + file *quotaFile +} + +func (w quotaFileWriter) Write(p []byte) (int, error) { + return w.file.Write(p) +} diff --git a/server/install.go b/server/install.go index c98ba202f..a7e603df9 100644 --- a/server/install.go +++ b/server/install.go @@ -143,12 +143,22 @@ func (s *Server) IsInstalling() bool { return s.installing.Load() } +func (s *Server) SetInstalling(state bool) { + s.installing.Store(state) + if state { + s.Sftp().CancelAll() + } +} + func (s *Server) IsTransferring() bool { return s.transferring.Load() } func (s *Server) SetTransferring(state bool) { s.transferring.Store(state) + if state { + s.Sftp().CancelAll() + } } func (s *Server) IsRestoring() bool { @@ -157,6 +167,13 @@ func (s *Server) IsRestoring() bool { func (s *Server) SetRestoring(state bool) { s.restoring.Store(state) + if state { + s.Sftp().CancelAll() + } +} + +func (s *Server) IsInProtectedState() bool { + return s.IsInstalling() || s.IsTransferring() || s.IsRestoring() } func (s *Server) IsBackingUp() bool { @@ -188,6 +205,7 @@ func (ip *InstallationProcess) Run() error { if !ip.Server.installing.SwapIf(true) { return errors.New("install: cannot obtain installation lock") } + ip.Server.Sftp().CancelAll() // We now have an exclusive lock on this installation process. Ensure that whenever this // process is finished that the semaphore is released so that other processes and be executed @@ -242,16 +260,9 @@ func (ip *InstallationProcess) writeScriptToDisk() error { // Pulls the docker image to be used for the installation container. func (ip *InstallationProcess) pullInstallationImage() error { - // Get a registry auth configuration from the config. - var registryAuth *config.RegistryConfiguration - for registry, c := range config.Get().Docker.Registries { - if !strings.HasPrefix(ip.Script.ContainerImage, registry) { - continue - } - + registry, registryAuth := config.Get().Docker.RegistryCredentialsForImage(ip.Script.ContainerImage) + if registryAuth != nil { log.WithField("registry", registry).Debug("using authentication for registry") - registryAuth = &c - break } // Get the ImagePullOptions. diff --git a/server/server.go b/server/server.go index 7acc66fa4..57b019a55 100644 --- a/server/server.go +++ b/server/server.go @@ -485,7 +485,7 @@ func (s *Server) cleanupBackupFiles() error { var failedRemovals []string // Common backup file extensions - backupExtensions := []string{".tar.gz", ".tar.zst", ".tar", ".gz", ".zst"} + backupExtensions := []string{".tar.gz", ".tar", ".gz"} // Iterate through all files and find backup files for _, file := range files { diff --git a/server/state_test.go b/server/state_test.go new file mode 100644 index 000000000..cad6e29c0 --- /dev/null +++ b/server/state_test.go @@ -0,0 +1,47 @@ +package server + +import "testing" + +func TestProtectedStateCancelsSftpSessions(t *testing.T) { + tests := []struct { + name string + set func(*Server) + }{ + { + name: "installing", + set: func(s *Server) { + s.SetInstalling(true) + }, + }, + { + name: "transferring", + set: func(s *Server) { + s.SetTransferring(true) + }, + }, + { + name: "restoring", + set: func(s *Server) { + s.SetRestoring(true) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv, err := New(nil) + if err != nil { + t.Fatal(err) + } + + ctx := srv.Sftp().Context("user") + tt.set(srv) + + select { + case <-ctx.Done(): + default: + t.Fatal("expected SFTP session context to be canceled") + } + }) + } +} diff --git a/sftp/handler.go b/sftp/handler.go index 4fae9fbdb..121ac47dd 100644 --- a/sftp/handler.go +++ b/sftp/handler.go @@ -35,6 +35,30 @@ type Handler struct { ro bool } +type quotaWriterAt struct { + io.WriterAt + server *server.Server +} + +func (w quotaWriterAt) WriteAt(p []byte, off int64) (int, error) { + if w.server != nil && w.server.IsInProtectedState() { + return 0, sftp.ErrSSHFxPermissionDenied + } + + n, err := w.WriterAt.WriteAt(p, off) + if filesystem.IsErrorCode(err, filesystem.ErrCodeDiskSpace) { + return n, ErrSSHQuotaExceeded + } + return n, err +} + +func (w quotaWriterAt) Close() error { + if c, ok := w.WriterAt.(io.Closer); ok { + return c.Close() + } + return nil +} + // NewHandler returns a new connection handler for the SFTP server. This allows a given user // to access the underlying filesystem. func NewHandler(sc *ssh.ServerConn, srv *server.Server) (*Handler, error) { @@ -99,7 +123,7 @@ func (h *Handler) Filewrite(request *sftp.Request) (io.WriterAt, error) { l := h.logger.WithField("source", request.Filepath) // If the user doesn't have enough space left on the server it should respond with an // error since we won't be letting them write this file to the disk. - if !h.fs.HasSpaceAvailable(true) { + if !h.fs.HasSpaceAvailable(false) { return nil, ErrSSHQuotaExceeded } @@ -135,7 +159,7 @@ func (h *Handler) Filewrite(request *sftp.Request) (io.WriterAt, error) { event = server.ActivitySftpCreate } h.events.MustLog(event, FileAction{Entity: request.Filepath}) - return f, nil + return quotaWriterAt{WriterAt: f, server: h.server}, nil } // Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading @@ -290,7 +314,7 @@ func (h *Handler) Filelist(request *sftp.Request) (sftp.ListerAt, error) { // Determines if a user has permission to perform a specific action on the SFTP server. These // permissions are defined and returned by the Panel API. func (h *Handler) can(permission string) bool { - if h.server.IsSuspended() { + if h.server.IsSuspended() || h.server.IsInProtectedState() { return false } for _, p := range h.permissions { diff --git a/sftp/handler_test.go b/sftp/handler_test.go new file mode 100644 index 000000000..fa120c3df --- /dev/null +++ b/sftp/handler_test.go @@ -0,0 +1,140 @@ +package sftp + +import ( + "errors" + "io" + "testing" + + pkgsftp "github.com/pkg/sftp" + + "github.com/Rene-Roscher/wings/server" +) + +type writeAtFunc func([]byte, int64) (int, error) + +func (f writeAtFunc) WriteAt(p []byte, off int64) (int, error) { + return f(p, off) +} + +func TestHandlerDeniesAccessWhenServerIsInProtectedState(t *testing.T) { + tests := []struct { + name string + set func(*server.Server) + }{ + { + name: "installing", + set: func(s *server.Server) { + s.SetInstalling(true) + }, + }, + { + name: "transferring", + set: func(s *server.Server) { + s.SetTransferring(true) + }, + }, + { + name: "restoring", + set: func(s *server.Server) { + s.SetRestoring(true) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv, err := server.New(nil) + if err != nil { + t.Fatal(err) + } + tt.set(srv) + + h := Handler{ + server: srv, + permissions: []string{"*"}, + } + + if h.can(PermissionFileCreate) { + t.Fatal("expected SFTP access to be denied") + } + }) + } +} + +func TestWriterDeniesWritesWhenServerEntersProtectedState(t *testing.T) { + tests := []struct { + name string + set func(*server.Server) + }{ + { + name: "installing", + set: func(s *server.Server) { + s.SetInstalling(true) + }, + }, + { + name: "transferring", + set: func(s *server.Server) { + s.SetTransferring(true) + }, + }, + { + name: "restoring", + set: func(s *server.Server) { + s.SetRestoring(true) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv, err := server.New(nil) + if err != nil { + t.Fatal(err) + } + + var called bool + writer := quotaWriterAt{ + server: srv, + WriterAt: writeAtFunc(func(_ []byte, _ int64) (int, error) { + called = true + return 1, nil + }), + } + tt.set(srv) + + n, err := writer.WriteAt([]byte("x"), 0) + if !errors.Is(err, pkgsftp.ErrSSHFxPermissionDenied) { + t.Fatalf("expected permission denied, got %v", err) + } + if n != 0 { + t.Fatalf("expected zero bytes written, got %d", n) + } + if called { + t.Fatal("expected underlying writer not to be called") + } + }) + } +} + +func TestWriterForwardsWritesWhenServerIsAvailable(t *testing.T) { + srv, err := server.New(nil) + if err != nil { + t.Fatal(err) + } + + writer := quotaWriterAt{ + server: srv, + WriterAt: writeAtFunc(func(p []byte, _ int64) (int, error) { + return len(p), io.EOF + }), + } + + n, err := writer.WriteAt([]byte("test"), 0) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected forwarded error, got %v", err) + } + if n != 4 { + t.Fatalf("expected forwarded byte count, got %d", n) + } +} diff --git a/sftp/server.go b/sftp/server.go index 7ecade305..7209161da 100644 --- a/sftp/server.go +++ b/sftp/server.go @@ -12,8 +12,6 @@ import ( "regexp" "strconv" "strings" - "sync" - "time" "emperror.dev/errors" "github.com/apex/log" @@ -31,317 +29,6 @@ import ( // server and sending a flood of usernames. var validUsernameRegexp = regexp.MustCompile(`^(?i)(.+)\.([a-z0-9]{8})$`) -// SmartSecurityProtector - Intelligent, configurable brute force protection -type SmartSecurityProtector struct { - mu sync.RWMutex - attempts map[string][]time.Time // IP -> attempt timestamps - blockedUntil map[string]time.Time // IP -> block expiry - reputation map[string]int // IP -> reputation score (-100 to +100) - blockHistory map[string][]time.Time // IP -> previous block times for escalation - config *config.SftpSecurityConfiguration -} - -// Global smart protector -var smartProtector *SmartSecurityProtector - -// Initialize smart protector with configuration -func initSmartProtector() { - cfg := config.Get().System.Sftp.Security - smartProtector = &SmartSecurityProtector{ - attempts: make(map[string][]time.Time), - blockedUntil: make(map[string]time.Time), - reputation: make(map[string]int), - blockHistory: make(map[string][]time.Time), - config: &cfg, - } - log.WithFields(log.Fields{ - "attempts_per_minute": cfg.Thresholds.AttemptsPerMinute, - "base_block_minutes": cfg.Blocking.BaseBlockMinutes, - "escalation_factor": cfg.Blocking.EscalationFactor, - }).Info("Smart SFTP security protection initialized") -} - -// isBlocked checks if an IP is currently blocked with smart logic -func (sp *SmartSecurityProtector) isBlocked(ip string) bool { - sp.mu.RLock() - defer sp.mu.RUnlock() - - if !sp.config.Enabled { - return false // Protection disabled - } - - // Check active block - if until, exists := sp.blockedUntil[ip]; exists { - if time.Now().Before(until) { - return true // Still blocked - } - // Block expired - apply decay to reputation - if sp.config.Reputation.Enabled { - if currentScore, hasScore := sp.reputation[ip]; hasScore { - newScore := int(float64(currentScore) * sp.config.Blocking.DecayFactor) - sp.reputation[ip] = newScore - log.WithField("ip", ip).WithField("old_score", currentScore).WithField("new_score", newScore).Debug("Applied reputation decay after block expiry") - } - } - } - - // Check reputation-based blocking - if sp.config.Reputation.Enabled { - if score, exists := sp.reputation[ip]; exists && score <= sp.config.Reputation.BlockThreshold { - log.WithField("ip", ip).WithField("reputation_score", score).Info("SMART-SECURITY: IP blocked due to poor reputation") - return true - } - } - - return false -} - -// recordFailedAttempt records a failed attempt with intelligent blocking logic -func (sp *SmartSecurityProtector) recordFailedAttempt(ip string) bool { - sp.mu.Lock() - defer sp.mu.Unlock() - - if !sp.config.Enabled { - return false - } - - now := time.Now() - - // Initialize tracking for IP - if sp.attempts[ip] == nil { - sp.attempts[ip] = make([]time.Time, 0, 50) - } - if sp.blockHistory[ip] == nil { - sp.blockHistory[ip] = make([]time.Time, 0, 10) - } - - // Clean old attempts (keep relevant timeframes) - var recentAttempts []time.Time - for _, attempt := range sp.attempts[ip] { - if now.Sub(attempt) < 24*time.Hour { // Keep 24h history - recentAttempts = append(recentAttempts, attempt) - } - } - - // Add current failed attempt - recentAttempts = append(recentAttempts, now) - sp.attempts[ip] = recentAttempts - - // Update reputation - if sp.config.Reputation.Enabled { - sp.reputation[ip] += sp.config.Reputation.BadBehaviorPenalty - if sp.reputation[ip] < -100 { - sp.reputation[ip] = -100 // Cap at minimum - } - } - - // Count attempts in different timeframes - minuteCount := sp.countAttemptsInWindow(recentAttempts, time.Minute) - hourCount := sp.countAttemptsInWindow(recentAttempts, time.Hour) - dayCount := len(recentAttempts) - - // SMART BLOCKING LOGIC - return sp.evaluateBlocking(ip, minuteCount, hourCount, dayCount, now) -} - -// countAttemptsInWindow counts attempts within a time window -func (sp *SmartSecurityProtector) countAttemptsInWindow(attempts []time.Time, window time.Duration) int { - now := time.Now() - count := 0 - for _, attempt := range attempts { - if now.Sub(attempt) <= window { - count++ - } - } - return count -} - -// evaluateBlocking implements smart blocking with escalation -func (sp *SmartSecurityProtector) evaluateBlocking(ip string, minuteCount, hourCount, dayCount int, now time.Time) bool { - // Smart threshold evaluation - if minuteCount >= sp.config.Thresholds.AttemptsPerMinute { - // Calculate smart block duration based on history - blockDuration := sp.calculateSmartBlockDuration(ip, "minute", minuteCount) - sp.blockedUntil[ip] = now.Add(blockDuration) - sp.blockHistory[ip] = append(sp.blockHistory[ip], now) - - log.WithFields(log.Fields{ - "ip": ip, - "minute_attempts": minuteCount, - "block_duration": blockDuration.String(), - "reputation_score": sp.reputation[ip], - "total_blocks": len(sp.blockHistory[ip]), - }).Warn("SMART-SECURITY: IP blocked - smart escalation applied") - return true - } - - if hourCount >= sp.config.Thresholds.AttemptsPerHour { - blockDuration := sp.calculateSmartBlockDuration(ip, "hour", hourCount) - sp.blockedUntil[ip] = now.Add(blockDuration) - sp.blockHistory[ip] = append(sp.blockHistory[ip], now) - - log.WithFields(log.Fields{ - "ip": ip, - "hour_attempts": hourCount, - "block_duration": blockDuration.String(), - "reputation_score": sp.reputation[ip], - }).Error("SMART-SECURITY: IP blocked for sustained attack pattern") - return true - } - - if dayCount >= sp.config.Thresholds.AttemptsPerDay { - blockDuration := sp.calculateSmartBlockDuration(ip, "day", dayCount) - sp.blockedUntil[ip] = now.Add(blockDuration) - sp.blockHistory[ip] = append(sp.blockHistory[ip], now) - - log.WithFields(log.Fields{ - "ip": ip, - "day_attempts": dayCount, - "block_duration": blockDuration.String(), - "reputation_score": sp.reputation[ip], - }).Error("SMART-SECURITY: IP blocked for persistent attack behavior") - return true - } - - return false -} - -// calculateSmartBlockDuration calculates intelligent block duration with escalation -func (sp *SmartSecurityProtector) calculateSmartBlockDuration(ip, trigger string, attemptCount int) time.Duration { - baseDuration := time.Duration(sp.config.Blocking.BaseBlockMinutes) * time.Minute - - // Factor in previous blocks (escalation) - previousBlocks := len(sp.blockHistory[ip]) - escalationMultiplier := 1.0 - for i := 0; i < previousBlocks; i++ { - escalationMultiplier *= sp.config.Blocking.EscalationFactor - } - - // Factor in severity of current violation - severityMultiplier := 1.0 - switch trigger { - case "minute": - excessAttempts := attemptCount - sp.config.Thresholds.AttemptsPerMinute - severityMultiplier = 1.0 + (float64(excessAttempts) * 0.5) // +50% per excess attempt - case "hour": - excessAttempts := attemptCount - sp.config.Thresholds.AttemptsPerHour - severityMultiplier = 2.0 + (float64(excessAttempts) * 0.3) // Base 2x + 30% per excess - case "day": - excessAttempts := attemptCount - sp.config.Thresholds.AttemptsPerDay - severityMultiplier = 4.0 + (float64(excessAttempts) * 0.2) // Base 4x + 20% per excess - } - - // Calculate final duration - finalDuration := time.Duration(float64(baseDuration) * escalationMultiplier * severityMultiplier) - - // Cap at maximum - maxDuration := time.Duration(sp.config.Blocking.MaxBlockHours) * time.Hour - if finalDuration > maxDuration { - finalDuration = maxDuration - } - - return finalDuration -} - -// recordSuccessfulAuth records successful authentication for reputation bonus -func (sp *SmartSecurityProtector) recordSuccessfulAuth(ip string) { - if !sp.config.Enabled || !sp.config.Reputation.Enabled { - return - } - - sp.mu.Lock() - defer sp.mu.Unlock() - - // Improve reputation for successful auth - sp.reputation[ip] += sp.config.Reputation.GoodBehaviorBonus - if sp.reputation[ip] > 100 { - sp.reputation[ip] = 100 // Cap at maximum - } - - log.WithField("ip", ip).WithField("new_reputation", sp.reputation[ip]).Debug("Reputation improved for successful authentication") -} - -// smartCleanup removes old entries with intelligent retention -func (sp *SmartSecurityProtector) smartCleanup() { - sp.mu.Lock() - defer sp.mu.Unlock() - - now := time.Now() - memoryWindow := time.Duration(sp.config.Reputation.MemoryDays) * 24 * time.Hour - - // Clean old attempts (keep reputation memory window) - for ip, attempts := range sp.attempts { - var keep []time.Time - for _, attempt := range attempts { - if now.Sub(attempt) < memoryWindow { - keep = append(keep, attempt) - } - } - if len(keep) == 0 { - delete(sp.attempts, ip) - // Also clean reputation if no recent activity - if _, hasReputation := sp.reputation[ip]; hasReputation { - log.WithField("ip", ip).Debug("Cleared reputation for inactive IP") - delete(sp.reputation, ip) - } - } else { - sp.attempts[ip] = keep - } - } - - // Clean old block history - for ip, blocks := range sp.blockHistory { - var keep []time.Time - for _, block := range blocks { - if now.Sub(block) < memoryWindow { - keep = append(keep, block) - } - } - if len(keep) == 0 { - delete(sp.blockHistory, ip) - } else { - sp.blockHistory[ip] = keep - } - } - - // Clean expired blocks and apply reputation decay - for ip, until := range sp.blockedUntil { - if now.After(until) { - log.WithField("ip", ip).WithField("reputation", sp.reputation[ip]).Info("SMART-SECURITY: IP unblocked - reputation decay applied") - delete(sp.blockedUntil, ip) - } - } - - // Log cleanup stats - totalTracked := len(sp.attempts) - totalBlocked := len(sp.blockedUntil) - if totalTracked > 0 || totalBlocked > 0 { - log.WithFields(log.Fields{ - "tracked_ips": totalTracked, - "blocked_ips": totalBlocked, - "memory_window": memoryWindow.String(), - }).Debug("Smart security cleanup completed") - } -} - -// Initialize smart protection with cleanup routine -func init() { - go func() { - // Wait for config to be loaded - time.Sleep(1 * time.Second) - initSmartProtector() - - // Start cleanup routine - ticker := time.NewTicker(15 * time.Minute) // More frequent cleanup - defer ticker.Stop() - for range ticker.C { - if smartProtector != nil { - smartProtector.smartCleanup() - } - } - }() -} - //goland:noinspection GoNameStartsWithPackageName type SFTPServer struct { manager *server.Manager @@ -419,35 +106,8 @@ func (c *SFTPServer) Run() error { if conn, _ := listener.Accept(); conn != nil { go func(conn net.Conn) { defer conn.Close() - - // CRITICAL: Extract client IP for brute force protection - clientAddr := conn.RemoteAddr().String() - clientIP := clientAddr - if host, _, err := net.SplitHostPort(clientAddr); err == nil { - clientIP = host - } - - // SMART-SECURITY: Check if IP is blocked before processing - if smartProtector != nil && smartProtector.isBlocked(clientIP) { - log.WithField("ip", clientIP).Warn("SMART-SECURITY: Rejecting connection from blocked IP") - return // Drop connection immediately - } - if err := c.AcceptInbound(conn, conf); err != nil { - // SMART-SECURITY: Handle authentication results - if smartProtector != nil { - if _, isInvalidCreds := err.(*remote.SftpInvalidCredentialsError); isInvalidCreds { - isBlocked := smartProtector.recordFailedAttempt(clientIP) - if isBlocked { - log.WithField("ip", clientIP).Error("SMART-SECURITY: IP blocked using intelligent escalation") - } - } else { - // Record successful auth for reputation bonus - smartProtector.recordSuccessfulAuth(clientIP) - } - } - - log.WithField("error", err).WithField("ip", clientAddr).Error("sftp: failed to accept inbound connection") + log.WithField("error", err).WithField("ip", conn.RemoteAddr().String()).Error("sftp: failed to accept inbound connection") } }(conn) } @@ -563,22 +223,7 @@ func (c *SFTPServer) makeCredentialsRequest(conn ssh.ConnMetadata, t remote.Sftp logger := log.WithFields(log.Fields{"subsystem": "sftp", "method": request.Type, "username": request.User, "ip": request.IP}) logger.Debug("validating credentials for SFTP connection") - // SECURITY: Enhanced username validation with suspicious pattern detection if !validUsernameRegexp.MatchString(request.User) { - // Check for common attack patterns - suspiciousPatterns := []string{"root", "admin", "administrator", "user", "test", "guest", "ftp", "ssh"} - for _, pattern := range suspiciousPatterns { - if strings.EqualFold(request.User, pattern) { - logger.WithField("attack_pattern", "common_username").Warn("SECURITY: Brute force attack detected - common username attempted") - break - } - } - - // Log suspicious usernames for monitoring - if len(request.User) < 3 || len(request.User) > 50 { - logger.WithField("attack_pattern", "unusual_length").Warn("SECURITY: Suspicious username length detected") - } - logger.Warn("failed to validate user credentials (invalid format)") return nil, &remote.SftpInvalidCredentialsError{} } diff --git a/sftp/utils.go b/sftp/utils.go index 88295016c..343bb0e52 100644 --- a/sftp/utils.go +++ b/sftp/utils.go @@ -3,6 +3,7 @@ package sftp import ( "io" "os" + "reflect" ) const ( @@ -30,6 +31,23 @@ func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) { type fxErr uint32 +func (e fxErr) As(target interface{}) bool { + // pkg/sftp checks errors against its private fxerr type before writing status packets. + v := reflect.ValueOf(target) + if v.Kind() != reflect.Ptr || v.IsNil() { + return false + } + + elem := v.Elem() + t := elem.Type() + if elem.Kind() != reflect.Uint32 || t.PkgPath() != "github.com/pkg/sftp" || t.Name() != "fxerr" { + return false + } + + elem.SetUint(uint64(e)) + return true +} + func (e fxErr) Error() string { switch e { case ErrSSHQuotaExceeded: