FractalSig is a generative model designed to replicate the extreme roughness of financial volatility (Hurst exponent $H \approx 0.1$).
Current SOTA models (like SigDiffusions) often rely on polynomial or Fourier inversion of log-signatures. When modeling rough paths ($H < 1/2$), these smooth global bases suffer from the Gibbs phenomenon and wash out high-frequency detail.
We prove that while smooth functions live in Sobolev spaces, rough volatility paths with $H < 1/2$ naturally live in Besov spaces $B^s_{p,q}$, where wavelet expansions are the natural representation.
FractalSig replaces the standard inversion layer with a Fractal Decoder: a neural network that predicts wavelet coefficients (the detail coefficients $c_{j,k}$ of a discrete wavelet transform) rather than raw path values.
Figure 1: Comparison of Fourier Reconstruction (suffering from the Gibbs phenomenon) vs. Wavelet Reconstruction (capturing true roughness).
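The effect in Figure 1 can be reproduced in a few lines of stdlib Python (an illustrative sketch, not code from this repo): the truncated Fourier series of a unit square wave overshoots near the jump no matter how many terms are kept.

```python
import math

def fourier_square_wave(x, n_terms):
    """Truncated Fourier series of a unit square wave (jump at x = 0)."""
    return (4 / math.pi) * sum(
        math.sin((2 * k + 1) * x) / (2 * k + 1) for k in range(n_terms)
    )

# Scan a fine grid just right of the jump: the partial sum overshoots
# the true value 1.0 by roughly 18% (Gibbs), regardless of n_terms.
xs = [i * 1e-4 for i in range(1, 2000)]
overshoot = max(fourier_square_wave(x, 100) for x in xs)
print(f"peak of the partial sum: {overshoot:.3f}")  # ~1.179, true value 1.0
```

Adding more terms narrows the overshoot region but never removes it; a wavelet basis, being localized, has no such global artifact.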
Our hybrid "Soft-Fork" architecture leverages the strengths of both JAX (for high-performance algebra) and PyTorch (for deep learning):
```mermaid
graph TD
    %% Input Layer
    Input([Input: Brownian Motion])

    subgraph JAX_Subsystem [High-Performance Compute - JAX]
        A[Diffusion Model Engine]
        B[Log-Signature Transform]
    end

    subgraph DL_Subsystem [Neural Inference - PyTorch]
        C[Fractal Decoder Network]
        D[Wavelet Coefficient Predictor]
    end

    %% Output Layer
    Output([Output: Rough Volatility Path])

    %% Data Flow and Labels
    Input --> A
    A --> B
    B -- "Latent Representation" --> C
    C --> D
    D -- "Inverse DWT" --> Output

    %% Professional Styling (Borders & Clean Colors)
    style JAX_Subsystem fill:#f8fafc,stroke:#64748b,stroke-width:2px
    style DL_Subsystem fill:#f8fafc,stroke:#64748b,stroke-width:2px
    style Input fill:#1e293b,color:#fff,stroke-width:0px
    style Output fill:#059669,color:#fff,stroke-width:0px
    style A fill:#3b82f6,color:#fff,stroke-width:0px
    style B fill:#ebf2ff,stroke:#3b82f6,stroke-width:2px,color:#1e3a8a
    style C fill:#6366f1,color:#fff,stroke-width:0px
    style D fill:#eef2ff,stroke:#6366f1,stroke-width:2px,color:#312e81
```
- JAX Diffusion: Learns the geometry of the path in the Log-Signature space.
- PyTorch Decoder: A "Supervised Hallucination" module that translates geometric signatures into microscopic roughness.
Deep Dive into the Architecture
Our architecture separates the generative process into two distinct mathematical regimes, using the best framework for each task.
1. The JAX Engine: Global Geometry & Algebra
The Diffusion Model (built in JAX/Diffrax) is responsible for learning the Macro-Structure of financial paths.
- Why JAX? Calculating signatures involves complex tensor algebra and recursive integrals. JAX's `vmap` and `jit` capabilities let us compute these geometric invariants orders of magnitude faster than standard eager execution.
- The Output: It generates a Log-Signature $\mathbf{s} \in \mathbb{R}^d$. This vector acts as a "smooth summary" of the path, capturing:
  - Drift & Convexity: the fundamental trend.
  - Area Integrals: the "loopiness" or interaction between dimensions.
- Note: it lacks high-frequency information due to the truncation depth $N$.
Mathematical Bridge: The signature provides a coordinate system for the space of paths, serving as the sufficient statistic for the path's geometry.
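To make this "coordinate system" concrete, here is a minimal stdlib sketch of a depth-2 signature for a piecewise-linear path (illustrative only; the repo computes signatures with `iisignature`). Level 2 is accumulated segment by segment via Chen's identity, and its antisymmetric part is the Lévy area mentioned above.

```python
def signature_depth2(path):
    """Depth-2 signature of a piecewise-linear d-dimensional path.

    Level 1 (s1) is the total increment; level 2 (s2) holds the iterated
    integrals, accumulated segment by segment via Chen's identity:
    the cross term s1_old (x) dx plus the linear-segment term dx (x) dx / 2.
    """
    d = len(path[0])
    s1 = [0.0] * d
    s2 = [[0.0] * d for _ in range(d)]
    for p, q in zip(path, path[1:]):
        dx = [qi - pi for pi, qi in zip(p, q)]
        for i in range(d):
            for j in range(d):
                s2[i][j] += s1[i] * dx[j] + dx[i] * dx[j] / 2
        for i in range(d):
            s1[i] += dx[i]
    return s1, s2

# An "L"-shaped path (right 1, then up 1) has total increment (1, 1)
# and Levy area 0.5 (the antisymmetric part of level 2).
s1, s2 = signature_depth2([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
levy_area = (s2[0][1] - s2[1][0]) / 2
print(s1, levy_area)  # [1.0, 1.0] 0.5
```

Note how two paths with the same endpoints but different shapes differ in their level-2 terms: this is exactly the geometric information the diffusion model learns.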
2. The PyTorch Decoder: Local Regularity & Texture
The Fractal Decoder (built in PyTorch) is responsible for "Supervised Hallucination" of the Micro-Structure.
- The Challenge: A truncated signature is mathematically smooth ($C^\infty$). To recover financial roughness ($H \approx 0.1$), we must inject energy back into the high frequencies.
- The Solution: Instead of predicting raw points $X_t$, the network predicts Wavelet Coefficients in a Besov space $B^s_{p,q}$.
The Workflow:
1. Latent Projection: take the smooth geometric summary $\mathbf{s}$.
2. Fractal Expansion: expand it into a multi-scale coefficient tree.
3. Synthesis: the Inverse Discrete Wavelet Transform (IDWT) converts these coefficients into a physical path:

$$X_t = \sum_{j,k} c_{j,k} \psi_{j,k}(t)$$
Why it works: The network learns that a specific geometric configuration (e.g., a "sharp downturn" in the signature space) correlates with a specific burst of high-frequency coefficients, effectively restoring the "rough" texture that Fourier methods typically delete.
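The synthesis step can be sketched with the simplest wavelet, Haar (an illustrative stdlib example; the wavelet family the repo actually uses is not specified here). The inverse transform below is exactly the sum $X_t = \sum_{j,k} c_{j,k} \psi_{j,k}(t)$, evaluated level by level.

```python
import math

def haar_dwt(signal):
    """Forward Haar DWT. Returns the single coarsest approximation
    coefficient and the detail bands ordered coarsest -> finest."""
    details = []
    while len(signal) > 1:
        pairs = list(zip(signal[::2], signal[1::2]))
        details.append([(a - b) / math.sqrt(2) for a, b in pairs])
        signal = [(a + b) / math.sqrt(2) for a, b in pairs]
    return signal, details[::-1]

def haar_idwt(approx, details):
    """Inverse Haar DWT: the synthesis X_t = sum_jk c_jk psi_jk(t).
    Each level doubles the resolution of the reconstruction."""
    signal = list(approx)
    for level in details:
        nxt = []
        for s, d in zip(signal, level):
            nxt.append((s + d) / math.sqrt(2))
            nxt.append((s - d) / math.sqrt(2))
        signal = nxt
    return signal

# Round trip: analyze a length-8 path, then resynthesize it exactly.
x = [4.0, 2.0, 5.0, 5.0, 1.0, 0.0, 2.0, 6.0]
approx, details = haar_dwt(x)
recon = haar_idwt(approx, details)
assert all(abs(a - b) < 1e-12 for a, b in zip(recon, x))
```

Zeroing the finest detail band before calling `haar_idwt` yields a visibly smoothed path, which is precisely the information a truncated signature discards and the Fractal Decoder must predict back.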
FractalSig is designed for High-Performance Hybrid Computing. The environment requires careful coordination between JAX and PyTorch for CUDA acceleration.
The most robust way to install is using the provided setup script, which handles the tricky `iisignature` C++ compilation and the JAX CUDA dependencies automatically:
```bash
# Clone the repository
git clone https://github.com/javierdejesusda/FractalSig.git
cd FractalSig

# Run the master setup script
chmod +x setup_wsl.sh
./setup_wsl.sh

# Activate the environment
conda activate fractalsig
```

If you prefer manual control, follow these steps:
1. Create the Conda environment:

   ```bash
   conda env create -f environment.yaml
   conda activate fractalsig
   ```

2. Install iisignature: you must use `--no-build-isolation` to ensure the C++ extensions compile against your environment's numpy headers:

   ```bash
   pip install iisignature==0.24 --no-build-isolation
   ```

3. Install JAX with CUDA support:

   ```bash
   pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
   ```

4. Install the remaining dependencies:

   ```bash
   pip install -r requirements.txt
   ```
FractalSig features a Unified CLI for seamless orchestration.
Ideal for testing and development on consumer hardware (e.g., RTX 4070):

```bash
python main.py +profile=laptop mode=auto
```

Optimized for A100/V100 clusters with full dataset generation:

```bash
python main.py +profile=cluster mode=auto
```

You can run individual steps of the pipeline using the mode argument:
| Mode | Description |
|---|---|
| `auto` | Recommended. Runs the full pipeline intelligently, skipping steps if artifacts exist. |
| `gen_data` | Generates synthetic ground-truth fBM paths (Davies–Harte, configurable $H$). |
| `train_decoder` | Trains the PyTorch Fractal Decoder (Signature -> Wavelet). |
| `train_jax` | Trains the JAX Signature Diffusion model. |
| `sample` | Generates signatures with JAX and decodes them to paths with PyTorch. |
`fractalsig.train_decoder.train(...)` and the underlying `FractalDecoder` now expose:
| Option | Values | Effect |
|---|---|---|
| `val_frac` | float in (0, 1) (default 0.2) | Held-out validation fraction; the best checkpoint is selected by validation loss. |
| `patience` | int (default 20) | Early-stop after this many epochs without val-loss improvement. |
| `loss` | `"mse"` \| `"scale_weighted"` | Plain MSE or per-scale-weighted MSE that emphasizes high-frequency wavelet bands (`fractalsig/losses.py`). |
| `loss_beta` | float (default 1.0) | Strength of high-frequency emphasis when `loss="scale_weighted"`. |
| `arch` (decoder) | `"mlp"` \| `"mlp_attn"` \| `"transformer"` | Pluggable backbone for ablations; the transformer variant is sized to stay within ~4x the MLP parameter budget. |
All datasets are exposed through the `DATASETS` registry in `fractalsig/registries.py` and follow a common `SignalDataset` interface (windowed slices + train/val/test split):
| Name | Domain | Source | Module |
|---|---|---|---|
| `synthetic_fbm` | Reference rough paths | Davies–Harte fBM with configurable $H$ | `fractalsig/datasets/synthetic_fbm.py` |
| `sp500_intraday` | Rough volatility surrogate | Rough Bergomi simulator (Bayer/Friz/Gatheral 2016) | `fractalsig/datasets/sp500_intraday.py` |
| `turbulence_burgers` | Multifractal turbulence | Stochastic Burgers (multifractal) | `fractalsig/datasets/turbulence_burgers.py` |
| `eeg_chbmit` | Biomedical | CHB-MIT scalp-EEG single-channel windows | `fractalsig/datasets/eeg_chbmit.py` |
| `audio_esc50` | Environmental audio | ESC-50 mono at 8 kHz | `fractalsig/datasets/audio_esc50.py` |
The `sp500_intraday` slot is currently a rough-Bergomi simulator rather than empirical SPX intraday data (yfinance/Stooq/FRED were unreachable from the target environment). This is documented in the module docstring; full empirical ingestion remains future work.
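For reference, exact fBM sampling of the kind `synthetic_fbm` provides can be sketched via Cholesky factorization of the fBM covariance (an O(n^3) stdlib illustration; the repo's Davies–Harte generator is the O(n log n) FFT-based equivalent and is what you should use in practice):

```python
import math
import random

def fbm_covariance(n, hurst):
    """fBM covariance C(s, t) = 0.5 * (s^2H + t^2H - |s - t|^2H)
    on the grid t_i = (i + 1) / n."""
    t = [(i + 1) / n for i in range(n)]
    h2 = 2 * hurst
    return [[0.5 * (t[i] ** h2 + t[j] ** h2 - abs(t[i] - t[j]) ** h2)
             for j in range(n)] for i in range(n)]

def cholesky(cov):
    """Plain O(n^3) Cholesky: lower-triangular L with L @ L.T == cov."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L

# Sample one rough path (H = 0.1) by coloring i.i.d. Gaussians with L.
n, H = 64, 0.1
L = cholesky(fbm_covariance(n, H))
rng = random.Random(0)
z = [rng.gauss(0, 1) for _ in range(n)]
path = [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)]
```

Both methods draw from exactly the same Gaussian law; Davies–Harte simply exploits the stationarity of fBM increments to factor a circulant embedding with FFTs.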
Build all caches in one shot with:

```bash
python scripts/download_datasets.py
```

Lint, type-check, and run the smoke tests with:

```bash
ruff check fractalsig tests
mypy fractalsig --ignore-missing-imports
pytest -m smoke -q
```

The smoke suite (`pytest -m smoke`) runs in ~3 s and is the gate enforced by `.github/workflows/ci.yml` on every push/PR to main.
Each method will be evaluated under a single fixed protocol — 9 methods × 5 datasets × 3 seeds — and scored on:
- Roughness recovery: increment-std ratio, DFA Hurst, wavelet Hurst, PSD slope error
- Distributional fidelity: multi-bandwidth MMD on increments, 1-Wasserstein on increments
- Generative quality: discriminative score (TCN classifier), predictive score (LSTM forecaster)
Per-cell results land in results/master_table.csv with bootstrap 95% CIs and Wilcoxon paired tests vs FractalSig.
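As an illustration of the roughness metrics listed above (a sketch only; the repo's implementations in the metrics registry may differ), the Hurst exponent can be estimated from the increment scaling law $E[(X_{t+\tau} - X_t)^2] \propto \tau^{2H}$ by a log-log regression:

```python
import math
import random

def hurst_from_increments(path, max_lag=20):
    """Estimate H from the scaling E[(X_{t+lag} - X_t)^2] ~ lag^(2H):
    fit a least-squares line in log-log space and halve the slope."""
    xs, ys = [], []
    for lag in range(1, max_lag + 1):
        msq = sum((path[i + lag] - path[i]) ** 2
                  for i in range(len(path) - lag)) / (len(path) - lag)
        xs.append(math.log(lag))
        ys.append(math.log(msq))
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
             / sum((x - mx) ** 2 for x in xs))
    return slope / 2

# Sanity check on ordinary Brownian motion, where H = 0.5 exactly:
rng = random.Random(0)
bm, x = [], 0.0
for _ in range(20000):
    x += rng.gauss(0, 1)
    bm.append(x)
print(f"estimated H: {hurst_from_increments(bm):.2f}")  # close to 0.5
```

A generator that truly captures rough volatility should push this estimate down toward $H \approx 0.1$ on its samples, which is what the roughness-recovery metrics quantify.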
```text
fractalsig/
├── fractalsig/               # Core PyTorch library
│   ├── decoder.py            # FractalDecoder + pluggable backbones
│   ├── losses.py             # ScaleWeightedMSE
│   ├── train_decoder.py      # train/val + early stopping
│   ├── seeding.py            # determinism helper
│   ├── registries.py         # DATASETS / BASELINES / METRICS
│   ├── data_gen.py           # legacy fBM generator
│   ├── datasets/             # 5 registered datasets
│   └── runners/              # train_runner + sweep_runner
├── tests/                    # pytest smoke + integration suite
├── SigDiffusions/            # JAX submodule for Signature Diffusion
├── conf/                     # Hydra configuration (profile/{laptop,cluster})
├── data/                     # Generated datasets (.npy / cached)
├── results/                  # Plots and generated visualizations
├── notebooks/                # Jupyter notebooks for analysis
├── scripts/                  # download_datasets.py, generate_figure.py, ...
├── docs/                     # triage report and design notes
├── .github/workflows/ci.yml  # Lint + type + smoke tests
├── pyproject.toml            # ruff/mypy/pytest config; requires Python >=3.11,<3.14
└── main.py                   # Unified CLI Entry Point (legacy auto/gen_data/train_decoder/train_jax/sample)
```
Distributed under the MIT License. See LICENSE for more information.
