A lightweight library for trying out training data attribution (TDA) methods on PyTorch models. Implements:
- Influence functions with EK-FAC curvature approximation (Grosse et al. 2023)
- SOURCE — segmented unrolled differentiation that handles non-converged and multi-stage training (Bae et al. 2024, arXiv:2405.12186)
- Baselines: gradient similarity, representation similarity, TracIn / GAS, and a thin wrapper around TRAK
Designed to be easy to adapt to your own model and task — typically you only need to subclass
AbstractTask (see examples/ for four worked end-to-end pipelines covering regression, image
classification, GLUE, and GPT-2 language modeling).
- Create a new Conda environment:
  conda create -n simple_influence python=3.10
  conda activate simple_influence
- Install the package and its dependencies (use the pytorch_cpu extra if you don't have a GPU):
  pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'
  PyTorch >=2.0.0 is required (the EK-FAC implementation uses torch.func / functorch).
- Run the test suite to verify the install:
  pytest tests/
  The non-smoke tests run on CPU against synthetic data. To also run the slower GLUE/Wiki smoke tests, use
  pytest -m smoke tests/.
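To confirm the PyTorch >=2.0.0 requirement before running anything, the installed version string can be compared against the minimum. This is only a sketch: the helper name `meets_minimum` is hypothetical (not part of this library), and it assumes version strings shaped like `2.1.0+cu118`.

```python
import re


def meets_minimum(version: str, minimum=(2, 0, 0)) -> bool:
    """Return True if a version string such as '2.1.0+cu118' is >= minimum."""
    # Drop any local build suffix ('+cu118') and keep the leading numeric fields.
    release = version.split("+", 1)[0]
    parts = []
    for field in release.split(".")[:3]:
        match = re.match(r"\d+", field)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '2.1' to three fields before comparing.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts) >= tuple(minimum)
```

With PyTorch installed, `meets_minimum(torch.__version__)` should evaluate to `True` for any 2.x release.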
from src.influence_function import InfluenceFunctionComputer
from examples.mnist.task import ClassificationTask
# `model` is any nn.Module; `train_loader` / `valid_loader` are standard PyTorch DataLoaders.
task = ClassificationTask(device="cuda")
computer = InfluenceFunctionComputer(model=model, task=task)
computer.build_curvature_blocks(train_loader) # one pass to build EK-FAC.
scores = computer.compute_scores_with_loader(valid_loader, train_loader)
# `scores[i, j]` is the influence of training point `j` on validation point `i`.

To swap in a different attribution method, replace InfluenceFunctionComputer with
GradientSimilarityComputer, RepresentationSimilarityComputer, TracinComputer, or
TrakComputer — each implements the same compute_scores_with_loader(test_loader, train_loader)
interface.
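Once a score matrix is available, a common next step is to list the highest-scoring training points for each validation point. The sketch below assumes `scores` has been converted to a nested list (for a PyTorch tensor, `scores.tolist()`); the helper `top_k_train_indices` is hypothetical and not part of this library.

```python
def top_k_train_indices(scores, k=3):
    """For each validation row i, return the k training indices j with the largest scores[i][j]."""
    result = []
    for row in scores:
        # Sort training indices by descending score; ties keep the lower index first.
        order = sorted(range(len(row)), key=lambda j: -row[j])
        result.append(order[:k])
    return result


# Toy 2 x 4 matrix: 2 validation points, 4 training points.
toy_scores = [
    [0.1, -0.5, 0.9, 0.3],
    [-0.2, 0.7, 0.0, 0.4],
]
print(top_k_train_indices(toy_scores, k=2))  # [[2, 3], [1, 3]]
```

Sorting by ascending score instead returns the most negatively scored training points.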
For non-converged or multi-stage training, use SourceComputer. It needs a list of checkpoints
per training segment plus the (averaged) learning rate and total number of gradient updates per
segment:
from src.source import SourceComputer
# Suppose you saved 6 checkpoints during training and want L = 3 segments
# (early / middle / late), with 2 checkpoints per segment.
source = SourceComputer(
    model=model,  # holds the final parameters theta_s
    task=task,
    checkpoints_per_segment=[
        ["ckpts/epoch_2.pt", "ckpts/epoch_4.pt"],  # earliest segment
        ["ckpts/epoch_6.pt", "ckpts/epoch_8.pt"],  # middle segment
        ["ckpts/epoch_10.pt", "ckpts/epoch_12.pt"],  # latest segment
    ],
    iters_per_segment=[K_1, K_2, K_3],  # gradient updates per segment
    lrs_per_segment=[eta_1, eta_2, eta_3],  # averaged learning rate per segment
)
source.build_curvature_blocks(train_loader)
scores = source.compute_scores_with_loader(valid_loader, train_loader)

With L = 1 (a single segment), SOURCE reduces to an EK-FAC influence function with damping
lambda = 1 / (eta * K) — i.e. the damping is derived from the training schedule rather than
hand-tuned. Currently only Linear and Conv2d modules are supported by SourceComputer.
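The per-segment arguments can be derived from a training log. The sketch below assumes a per-step learning-rate list, splits it into roughly equal contiguous segments, and averages the learning rate arithmetically within each segment; whether an arithmetic mean is the right average for a given optimizer and schedule is an assumption. It also evaluates the single-segment damping 1 / (eta * K) quoted above. Both helper names are hypothetical and not part of this library.

```python
def segment_schedule(step_lrs, num_segments):
    """Split a per-step learning-rate list into contiguous segments.

    Returns (iters_per_segment, lrs_per_segment), where each learning rate
    is the arithmetic mean over the steps in that segment (an assumption).
    """
    total = len(step_lrs)
    if not 1 <= num_segments <= total:
        raise ValueError("num_segments must be between 1 and the number of steps")
    # Spread any remainder over the earliest segments so sizes differ by at most 1.
    base, extra = divmod(total, num_segments)
    iters, lrs, start = [], [], 0
    for s in range(num_segments):
        size = base + (1 if s < extra else 0)
        chunk = step_lrs[start:start + size]
        iters.append(size)
        lrs.append(sum(chunk) / size)
        start += size
    return iters, lrs


def implied_damping(eta, num_updates):
    """Damping lambda = 1 / (eta * K) for the single-segment (L = 1) case."""
    return 1.0 / (eta * num_updates)


# Toy schedule: 600 steps at 0.1, then 600 steps at 0.01.
step_lrs = [0.1] * 600 + [0.01] * 600
iters_per_segment, lrs_per_segment = segment_schedule(step_lrs, num_segments=3)
print(iters_per_segment)  # [400, 400, 400]

# Single-segment case: one averaged learning rate over all 1200 updates.
(total_iters,), (mean_lr,) = segment_schedule(step_lrs, num_segments=1)
print(implied_damping(mean_lr, total_iters))
```

The two returned lists line up with the `iters_per_segment` and `lrs_per_segment` arguments in the example above.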
Example scripts use absolute imports (from examples.mnist.pipeline import ...) and also
read/write paths relative to the example directory (files/checkpoints/, files/results/).
The easiest way to run them is to cd into the example directory and point PYTHONPATH at
the project root so the imports resolve:
cd examples/mnist
PYTHONPATH=../.. python -m examples.mnist.train

CPU-only (small enough that GPU has no benefit):
cd examples/regression
PYTHONPATH=../.. python -m examples.regression.train # writes files/checkpoints/
PYTHONPATH=../.. python -m examples.regression.compute_influences # writes files/results/
PYTHONPATH=../.. python -m examples.regression.evaluate.visualize_distribution

cd examples/mnist
PYTHONPATH=../.. python -m examples.mnist.train
PYTHONPATH=../.. python -m examples.mnist.compute_influences # IF + SOURCE
PYTHONPATH=../.. python -m examples.mnist.evaluate.visualize_influences

A small end-to-end smoke test (~1 minute on CPU) that trains briefly and runs both IF and SOURCE on the result:
cd examples/mnist
PYTHONPATH=../.. python smoke_test_source.py

Tested on an A100 80GB; reduce batch size for smaller GPUs.
cd examples/glue
PYTHONPATH=../.. python -m examples.glue.train
PYTHONPATH=../.. python -m examples.glue.compute_influences
PYTHONPATH=../.. python -m examples.glue.evaluate.inspect_influences

Tested on an A100 80GB; reduce batch size for smaller GPUs.
cd examples/wiki
PYTHONPATH=../.. python -m examples.wiki.train
PYTHONPATH=../.. python -m examples.wiki.compute_influences
PYTHONPATH=../.. python -m examples.wiki.evaluate.inspect_influences

- Install the optional development dependencies:
  pip install -e '.[dev]'
- Install the pre-commit hooks:
  pre-commit install
- Run the tests to verify that everything is functioning correctly:
  pytest
This library is intentionally minimal — it's a clean reference implementation aimed at making the methods easy to read and adapt. For a production-grade implementation with query batching, layer-wise and token-wise score breakdowns, multi-GPU support, and many more features, see Kronfluence.
- EK-FAC influence calculations are only compatible with the following modules:
  Linear, Conv2d, LayerNorm, BatchNorm2d, and Embedding. Custom modules that hold parameters directly (rather than wrapping one of the supported layers) won't have EK-FAC statistics collected for them:

  import torch
  import torch.nn as nn

  class CustomLinear(nn.Module):
      def __init__(self, num_inputs, num_outputs):
          super().__init__()
          # Uninitialized (num_inputs, num_outputs) weight held directly on the module.
          self.weight = nn.Parameter(torch.empty(num_inputs, num_outputs))

      def forward(self, inputs):
          return inputs @ self.weight

  Replace these with one of the supported modules; see replace_conv1d_modules in examples/wiki/pipeline.py for a worked example (HuggingFace's Conv1D -> nn.Linear).
- All TDA computers in this library are single-GPU.
- To estimate the true Fisher for EK-FAC, the task's loss function must sample targets from
  the model output (see examples/ for several worked examples).
- SourceComputer currently supports only Linear and Conv2d modules; the LayerNorm / BatchNorm "full" and Embedding "diagonal" Fisher branches don't yet have a closed-form matrix function for the SOURCE preconditioner.
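The true-Fisher requirement in the list above can be illustrated for a classification loss. In this minimal sketch, labels are drawn from the model's own predictive distribution instead of being taken from the dataset, and the cross-entropy is computed against those sampled labels. The function name `sampled_label_loss` is hypothetical; the method name and signature that `AbstractTask` actually expects are not shown here, so treat this as an illustration of the technique and not as this library's API.

```python
import torch
import torch.nn.functional as F


def sampled_label_loss(logits: torch.Tensor) -> torch.Tensor:
    """Summed cross-entropy against labels sampled from softmax(logits)."""
    with torch.no_grad():
        # Sample one label per example from the model's predictive distribution.
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).flatten()
    # Gradients flow through `logits` only; the sampled labels are constants.
    return F.cross_entropy(logits, sampled, reduction="sum")


torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)  # toy batch: 8 examples, 10 classes
loss = sampled_label_loss(logits)
loss.backward()
print(loss.item(), logits.grad.shape)
```

Passing the dataset labels to the same `F.cross_entropy` call would yield the empirical Fisher instead of the true Fisher.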
- 2023/10/14: Initial implementation, supporting four examples: regression, mnist, glue, wiki.
- 2024-06: Added SourceComputer (Bae et al. 2024); standardized docstrings; bug fixes.