diff --git a/.gitignore b/.gitignore
index ae99fd0..482b400 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,6 +70,7 @@ instance/
# Sphinx documentation
docs/_build/
+docs/source/api/
# PyBuilder
target/
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 157ea9f..b1f3539 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -112,5 +112,5 @@
# Autoapi
autoapi_dirs = ["../../src"]
autoapi_root = "api"
-autoapi_keep_files = False
+autoapi_keep_files = True
autodoc_typehints = "description"
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 1bd27fa..03a292b 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -2,7 +2,7 @@
:sd_hide_title:
`tdhook`
-=====
+========
.. toctree::
:maxdepth: 1
diff --git a/docs/source/methods.rst b/docs/source/methods.rst
index 4f59a42..7e800ad 100644
--- a/docs/source/methods.rst
+++ b/docs/source/methods.rst
@@ -105,6 +105,23 @@ Methods
+ .. grid-item-card::
+ :link: notebooks/methods/representation-similarity.ipynb
+ :class-card: surface
+ :class-body: surface
+
+ .. raw:: html
+
+
+
+
+
+
+
Representation Similarity
+
Compare latent representations with centered kernel alignment (CKA); additional similarity metrics will be added over time.
+
+
+
.. toctree::
:hidden:
:maxdepth: 2
@@ -114,3 +131,4 @@ Methods
notebooks/methods/linear-probing.ipynb
notebooks/methods/bilinear-probing.ipynb
notebooks/methods/dimension-estimation.ipynb
+ notebooks/methods/representation-similarity.ipynb
diff --git a/docs/source/notebooks/methods/representation-similarity.ipynb b/docs/source/notebooks/methods/representation-similarity.ipynb
new file mode 100644
index 0000000..eb60f09
--- /dev/null
+++ b/docs/source/notebooks/methods/representation-similarity.ipynb
@@ -0,0 +1,197 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Representation Similarity\n",
+ "\n",
+ "This notebook introduces representation similarity methods in `tdhook`.\n",
+ "\n",
+ "It currently covers centered kernel alignment (CKA) via `tdhook.latent.representation_similarity.CkaEstimator`. More similarity methods will be added here as the module grows."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib.util\n",
+ "\n",
+ "DEV = True\n",
+ "\n",
+ "if importlib.util.find_spec(\"google.colab\") is not None:\n",
+ " MODE = \"colab-dev\" if DEV else \"colab\"\n",
+ "else:\n",
+ " MODE = \"local\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if MODE == \"colab\":\n",
+ " %pip install -q tdhook\n",
+ "elif MODE == \"colab-dev\":\n",
+ " !rm -rf tdhook\n",
+ " !git clone https://github.com/Xmaster6y/tdhook -b main\n",
+ " %pip install -q ./tdhook"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from tensordict import TensorDict\n",
+ "\n",
+ "from tdhook.latent.representation_similarity import CkaEstimator"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Synthetic Example\n",
+ "\n",
+ "We build a few pairs of representations with known relationships:\n",
+ "\n",
+ "- `same`: identical representations, so CKA should be close to `1`\n",
+ "- `rotated`: an orthogonal transform of the same representation, which linear CKA should also score near `1`\n",
+ "- `random`: an unrelated representation, which should typically score much lower"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(0)\n",
+ "\n",
+ "x = torch.randn(256, 32)\n",
+ "q, _ = torch.linalg.qr(torch.randn(32, 32))\n",
+ "\n",
+ "examples = {\n",
+ " \"same\": (x, x.clone()),\n",
+ " \"rotated\": (x, x @ q),\n",
+ " \"random\": (x, torch.randn(256, 24)),\n",
+ "}\n",
+ "\n",
+ "estimator = CkaEstimator(kernel=\"linear\")\n",
+ "\n",
+ "\n",
+ "def run_cka(x, y):\n",
+ " td = TensorDict({\"data_a\": x, \"data_b\": y}, batch_size=[])\n",
+ " return estimator(td.clone())[\"cka\"].item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'same': 1.0, 'rotated': 1.0, 'random': 0.10095701366662979}"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "scores = {name: run_cka(a, b) for name, (a, b) in examples.items()}\n",
+ "scores"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Batched Inputs\n",
+ "\n",
+ "Like the dimension-estimation modules, `CkaEstimator` accepts either `(N, D)` or batched `(..., N, D)` inputs and returns one scalar score per batch item."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.9979, 0.9978, 0.9976])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batched_x = torch.randn(3, 128, 16)\n",
+ "batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)\n",
+ "\n",
+ "td = TensorDict({\"data_a\": batched_x, \"data_b\": batched_y}, batch_size=[3])\n",
+ "CkaEstimator()(td.clone())[\"cka\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## API Notes\n",
+ "\n",
+ "- The estimator is named `CkaEstimator` and already exposes a `kernel` argument.\n",
+ "- At the moment only `kernel=\"linear\"` is implemented.\n",
+ "- Degenerate inputs with zero variance return `nan` instead of raising.\n",
+ "\n",
+ "Future methods can extend this notebook with additional sections, comparisons, and visualizations."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/tdhook/attribution/gradient_helpers/helpers.py b/src/tdhook/attribution/gradient_helpers/helpers.py
index 85b18f1..101218f 100644
--- a/src/tdhook/attribution/gradient_helpers/helpers.py
+++ b/src/tdhook/attribution/gradient_helpers/helpers.py
@@ -13,6 +13,8 @@
class Riemann(Enum):
+ """Supported Riemann integration variants for gradient approximation."""
+
left = 1
right = 2
middle = 3
@@ -42,7 +44,7 @@ def approximation_parameters(
return riemann_builders(method=Riemann[method.split("_")[-1]])
if method == "gausslegendre":
return gauss_legendre_builders()
- raise ValueError("Invalid integral approximation method name: {}".format(method))
+ raise ValueError(f"Invalid integral approximation method name: {method}")
def riemann_builders(
diff --git a/src/tdhook/attribution/lrp_helpers/rules.py b/src/tdhook/attribution/lrp_helpers/rules.py
index d75b53e..f53fd2c 100644
--- a/src/tdhook/attribution/lrp_helpers/rules.py
+++ b/src/tdhook/attribution/lrp_helpers/rules.py
@@ -95,6 +95,13 @@ class AbstractFunctionMeta(ABCMeta, FunctionMeta):
class Rule(Function, metaclass=AbstractFunctionMeta):
+ """Base class for LRP rules implemented as custom autograd functions.
+
+ Subclasses override `forward()` and `backward()` to define how relevance is
+ propagated through a wrapped module. Instances are registered onto modules
+ by temporarily replacing `module.forward` with `Function.apply(...)`.
+ """
+
def __init__(self):
self._apply_kwargs = {}
# TODO: Add zero_params argument for all rules
@@ -116,11 +123,13 @@ def unregister(self, module: nn.Module):
@staticmethod
@abstractmethod
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
+ """Run the wrapped module and save any tensors needed for relevance propagation."""
pass
@staticmethod
@abstractmethod
def backward(ctx, *out_relevance):
+ """Propagate output relevance back to the inputs of the wrapped module."""
pass
@@ -454,9 +463,7 @@ def _call(self, name: str, module: nn.Module) -> Rule | None:
def __call__(self, name: str, module: nn.Module) -> Rule | None:
rule = self._rule_mapper(name, module)
- if rule is None:
- return self._call(name, module)
- return rule
+ return self._call(name, module) if rule is None else rule
class EpsilonPlus(BaseRuleMapper):
diff --git a/src/tdhook/latent/__init__.py b/src/tdhook/latent/__init__.py
index e4e2285..65b3313 100644
--- a/src/tdhook/latent/__init__.py
+++ b/src/tdhook/latent/__init__.py
@@ -15,6 +15,7 @@
Probe,
ProbeManager,
)
+from .representation_similarity import CkaEstimator
from .steering_vectors import SteeringVectors, ActivationAddition
__all__ = [
@@ -23,6 +24,7 @@
"ActivationPatching",
"BilinearProbe",
"BilinearProbeManager",
+ "CkaEstimator",
"LinearEstimator",
"LowRankBilinearEstimator",
"LocalKnnDimensionEstimator",
diff --git a/src/tdhook/latent/representation_similarity.py b/src/tdhook/latent/representation_similarity.py
deleted file mode 100644
index b9f8044..0000000
--- a/src/tdhook/latent/representation_similarity.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# CKA: Similarity of Neural Network Representations Revisited
-# Mutual k-Nearest Neighbor Alignment Metric: The Platonic Representation Hypothesis
-# Information Imbalance: A quantitative analysis of semantic information in deep representations of text and images
diff --git a/src/tdhook/latent/representation_similarity/__init__.py b/src/tdhook/latent/representation_similarity/__init__.py
new file mode 100644
index 0000000..f0c18e3
--- /dev/null
+++ b/src/tdhook/latent/representation_similarity/__init__.py
@@ -0,0 +1,7 @@
+"""
+Representation similarity methods.
+"""
+
+from .cka import CkaEstimator
+
+__all__ = ["CkaEstimator"]
diff --git a/src/tdhook/latent/representation_similarity/cka.py b/src/tdhook/latent/representation_similarity/cka.py
new file mode 100644
index 0000000..3991fc5
--- /dev/null
+++ b/src/tdhook/latent/representation_similarity/cka.py
@@ -0,0 +1,95 @@
+from textwrap import indent
+
+import torch
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModuleBase
+
+
+class CkaEstimator(TensorDictModuleBase):
+    """
+    Centered kernel alignment (CKA) between two representations.
+
+    Reads two data tensors from the input TensorDict. Expects `(N, D)` or
+    `(..., N, D)` for both tensors, with shared batch shape and sample count;
+    the feature dimension `D` may differ between the two tensors.
+    Outputs one scalar similarity value per batch item.
+
+    Args:
+        in_key_a: Key of the first representation in the input TensorDict.
+        in_key_b: Key of the second representation in the input TensorDict.
+        out_key: Key under which the CKA scores are written.
+        kernel: Kernel used for CKA. Only ``"linear"`` is implemented.
+        eps: Threshold on the CKA denominator; values less than or equal to
+            it are treated as degenerate and produce ``nan``.
+
+    Raises:
+        NotImplementedError: If ``kernel`` is anything other than ``"linear"``.
+    """
+
+    def __init__(
+        self,
+        in_key_a: str = "data_a",
+        in_key_b: str = "data_b",
+        out_key: str = "cka",
+        kernel: str = "linear",
+        eps: float = 1e-12,
+    ):
+        super().__init__()
+        if kernel != "linear":
+            raise NotImplementedError(f"Unsupported kernel '{kernel}'. Only 'linear' is implemented for now.")
+        self.in_key_a = in_key_a
+        self.in_key_b = in_key_b
+        self.out_key = out_key
+        self.kernel = kernel
+        self.eps = eps
+        self.in_keys = [in_key_a, in_key_b]
+        self.out_keys = [out_key]
+
+    def forward(self, td: TensorDict) -> TensorDict:
+        """Compute CKA for every batch item and store it under ``out_key``.
+
+        Args:
+            td: TensorDict holding both input tensors.
+
+        Returns:
+            The same TensorDict, with a tensor of shape ``x.shape[:-2]``
+            written under ``out_key``.
+
+        Raises:
+            ValueError: If the two inputs are not rank >= 2 or disagree on
+                batch shape, sample count, or device.
+        """
+        x = td[self.in_key_a]
+        y = td[self.in_key_b]
+        _validate_inputs(x, y)
+
+        batch_shape = x.shape[:-2]
+        n = x.shape[-2]
+        # Use an explicit item count instead of ``-1``: ``reshape(-1, ...)`` cannot
+        # infer the leading dimension when the tensor has zero elements because
+        # ``N`` or ``D`` is 0, and would raise instead of producing ``nan`` scores.
+        num_items = batch_shape.numel()
+        flat_x = x.reshape(num_items, n, x.shape[-1])
+        flat_y = y.reshape(num_items, n, y.shape[-1])
+        if num_items == 0:
+            # Nothing to stack; emit an empty result with the regular output dtype.
+            td[self.out_key] = torch.empty(batch_shape, dtype=torch.float32, device=x.device)
+            return td
+        cka_values = [_linear_cka(flat_x[i], flat_y[i], eps=self.eps) for i in range(num_items)]
+        td[self.out_key] = torch.stack(cka_values).reshape(batch_shape)
+        return td
+
+    def __repr__(self):
+        """Return a multi-line summary listing keys, kernel, and ``eps``."""
+        fields = indent(
+            f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nkernel='{self.kernel}',\neps={self.eps}",
+            4 * " ",
+        )
+        return f"{type(self).__name__}(\n{fields})"
+
+
+def _validate_inputs(x: torch.Tensor, y: torch.Tensor) -> None:
+    """Check that two representations can be compared with CKA.
+
+    Raises:
+        ValueError: If either tensor has fewer than two dimensions, or if the
+            batch shapes, sample counts, or devices of the tensors differ.
+    """
+    if min(x.ndim, y.ndim) < 2:
+        raise ValueError("CKA expects tensors with shape (N, D) or (..., N, D)")
+
+    batch_x, batch_y = x.shape[:-2], y.shape[:-2]
+    if batch_x != batch_y:
+        raise ValueError(f"Expected matching batch shapes, got {batch_x} and {batch_y}")
+
+    samples_x, samples_y = x.shape[-2], y.shape[-2]
+    if samples_x != samples_y:
+        raise ValueError(f"Expected matching sample counts, got {samples_x} and {samples_y}")
+
+    device_x, device_y = x.device, y.device
+    if device_x != device_y:
+        raise ValueError(f"Expected both tensors on the same device, got {device_x} and {device_y}")
+
+
+def _linear_cka(x: torch.Tensor, y: torch.Tensor, eps: float) -> torch.Tensor:
+    """Compute linear CKA between two sample-aligned representations.
+
+    The score is ``||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)``, evaluated on
+    inputs whose columns have been mean-centered over the sample axis.
+
+    Args:
+        x: Tensor indexed as ``(samples, features)``.
+        y: Tensor indexed as ``(samples, features)`` with the same sample count.
+        eps: Denominators less than or equal to this value are treated as
+            degenerate.
+
+    Returns:
+        A 0-dim ``float32`` tensor on ``x``'s device holding the CKA score, or
+        ``nan`` when the denominator is non-finite or ``<= eps``, or when the
+        score itself is non-finite.
+    """
+    dtype = torch.promote_types(x.dtype, y.dtype)
+    if not dtype.is_floating_point:
+        # Non-floating inputs (e.g. integers) are computed in float32.
+        dtype = torch.float32
+
+    x = x.to(dtype=dtype)
+    y = y.to(dtype=dtype)
+    x = x - x.mean(dim=0, keepdim=True)
+    y = y - y.mean(dim=0, keepdim=True)
+
+    cross_cov = x.transpose(-1, -2) @ y
+    x_cov = x.transpose(-1, -2) @ x
+    y_cov = y.transpose(-1, -2) @ y
+
+    numerator = torch.sum(cross_cov.square())
+    x_norm = torch.sum(x_cov.square())
+    y_norm = torch.sum(y_cov.square())
+    denominator = torch.sqrt(x_norm * y_norm)
+
+    # The sentinel is float32 like the regular return value, so the output dtype
+    # does not depend on whether a particular item happens to be degenerate.
+    nan = torch.full((), float("nan"), dtype=torch.float32, device=x.device)
+    if not torch.isfinite(denominator) or denominator <= eps:
+        return nan
+
+    value = (numerator / denominator).float()
+    return value if torch.isfinite(value) else nan
diff --git a/tests/latent/test_representation_similarity.py b/tests/latent/test_representation_similarity.py
new file mode 100644
index 0000000..ec1bbb8
--- /dev/null
+++ b/tests/latent/test_representation_similarity.py
@@ -0,0 +1,171 @@
+"""
+Tests for representation similarity estimators.
+"""
+
+import pytest
+import torch
+from tensordict import TensorDict
+
+from tdhook.latent.representation_similarity import CkaEstimator
+
+
+def make_td(x, y, in_key_a="data_a", in_key_b="data_b", batch_size=None):
+    """Wrap two tensors in a TensorDict, inferring the batch size from ``x`` when omitted."""
+    if batch_size is not None:
+        resolved_batch_size = batch_size
+    elif x.ndim == 2:
+        resolved_batch_size = []
+    else:
+        resolved_batch_size = x.shape[:-2]
+    payload = {in_key_a: x, in_key_b: y}
+    return TensorDict(payload, batch_size=resolved_batch_size)
+
+
+def make_random_pair(n=64, d_x=10, d_y=7):
+    """Draw two independent standard-normal matrices of shape ``(n, d_x)`` and ``(n, d_y)``."""
+    first = torch.randn(n, d_x)
+    second = torch.randn(n, d_y)
+    return first, second
+
+
+@pytest.fixture
+def run_estimator():
+    """Seed torch, then provide a helper that runs ``CkaEstimator`` on a pair of tensors."""
+    torch.manual_seed(42)
+
+    def _run(x, y, in_key_a="data_a", in_key_b="data_b", batch_size=None, **estimator_kwargs):
+        # Build the input first, then the estimator, and finally apply it.
+        inputs = make_td(x, y, in_key_a=in_key_a, in_key_b=in_key_b, batch_size=batch_size)
+        estimator = CkaEstimator(in_key_a=in_key_a, in_key_b=in_key_b, **estimator_kwargs)
+        return estimator(inputs)
+
+    return _run
+
+
+class TestCkaEstimator:
+    """Behavioural tests for ``CkaEstimator``."""
+
+    def test_default_keys(self, run_estimator):
+        """Default keys produce a finite, scalar, floating-point score under ``"cka"``."""
+        output = run_estimator(*make_random_pair())
+
+        assert "cka" in output
+        score = output["cka"]
+        assert score.ndim == 0
+        assert score.dtype in (torch.float32, torch.float64)
+        assert torch.isfinite(score)
+
+    def test_custom_keys(self, run_estimator):
+        """Custom input/output keys are honoured and the inputs stay in the result."""
+        keys = {"in_key_a": "linear1", "in_key_b": "linear2", "out_key": "similarity"}
+        first = torch.randn(48, 8)
+        second = torch.randn(48, 6)
+
+        output = run_estimator(first, second, **keys)
+
+        for key in ("linear1", "linear2", "similarity"):
+            assert key in output
+        assert output["similarity"].ndim == 0
+
+    def test_identical_views_are_one(self, run_estimator):
+        """A representation compared with a copy of itself scores 1."""
+        features = torch.randn(128, 16)
+
+        score = run_estimator(features, features.clone())["cka"]
+
+        assert torch.isclose(score, torch.tensor(1.0), atol=1e-6)
+
+    def test_invariant_to_isotropic_scaling(self, run_estimator):
+        """Uniformly rescaling one representation leaves the score at 1."""
+        features = torch.randn(128, 16)
+
+        score = run_estimator(features, 7.5 * features)["cka"]
+
+        assert torch.isclose(score, torch.tensor(1.0), atol=1e-5)
+
+    def test_invariant_to_orthogonal_rotation(self, run_estimator):
+        """Applying an orthogonal matrix to one representation leaves the score at 1."""
+        features = torch.randn(128, 12)
+        rotation = torch.linalg.qr(torch.randn(12, 12))[0]
+
+        score = run_estimator(features, features @ rotation)["cka"]
+
+        assert torch.isclose(score, torch.tensor(1.0), atol=1e-5)
+
+    def test_independent_random_views_have_low_cka(self, run_estimator):
+        """Independent Gaussian representations score well below 1."""
+        first = torch.randn(512, 32)
+        second = torch.randn(512, 24)
+
+        score = run_estimator(first, second)["cka"]
+
+        assert score.item() < 0.2
+
+    @pytest.mark.parametrize(
+        ("x_shape", "y_shape"),
+        [
+            pytest.param((1, 10, 8), (1, 10, 6), id="1x10"),
+            pytest.param((5, 10, 8), (5, 10, 6), id="5x10"),
+            pytest.param((2, 3, 10, 8), (2, 3, 10, 6), id="2x3x10"),
+        ],
+    )
+    def test_batch_shape_preservation(self, run_estimator, x_shape, y_shape):
+        """The output shape equals the leading batch dimensions of the inputs."""
+        first = torch.randn(*x_shape)
+        second = torch.randn(*y_shape)
+        expected_batch = x_shape[:-2]
+
+        output = run_estimator(first, second, batch_size=expected_batch)
+
+        assert output["cka"].shape == expected_batch
+
+    def test_empty_flattened_batch_returns_empty_output(self, run_estimator):
+        """A batch with a zero-sized dimension yields an empty float32 output."""
+        first = torch.randn(2, 0, 10, 8)
+        second = torch.randn(2, 0, 10, 6)
+
+        score = run_estimator(first, second, batch_size=[2, 0])["cka"]
+
+        assert score.shape == (2, 0)
+        assert score.dtype == torch.float32
+        assert score.numel() == 0
+
+    def test_mismatched_sample_counts_raise(self, run_estimator):
+        """Different sample counts are rejected."""
+        with pytest.raises(ValueError, match="matching sample counts"):
+            run_estimator(torch.randn(32, 8), torch.randn(31, 6))
+
+    def test_mismatched_batch_shapes_raise(self, run_estimator):
+        """Different batch shapes are rejected."""
+        with pytest.raises(ValueError, match="matching batch shapes"):
+            run_estimator(torch.randn(2, 3, 16, 8), torch.randn(2, 4, 16, 6), batch_size=[2])
+
+    def test_invalid_rank_raises(self, run_estimator):
+        """One-dimensional inputs are rejected."""
+        expected_message = r"shape \(N, D\) or \(\.\.\., N, D\)"
+        with pytest.raises(ValueError, match=expected_message):
+            run_estimator(torch.randn(32), torch.randn(32))
+
+    def test_mismatched_devices_raise(self, run_estimator):
+        """Inputs living on different devices are rejected."""
+        on_cpu = torch.randn(32, 8)
+        on_meta = torch.randn(32, 6, device="meta")
+
+        with pytest.raises(ValueError, match="same device"):
+            run_estimator(on_cpu, on_meta)
+
+    def test_constant_representation_returns_nan(self, run_estimator):
+        """A constant (zero-variance) representation yields ``nan``."""
+        constant = torch.ones(64, 8)
+        noise = torch.randn(64, 6)
+
+        score = run_estimator(constant, noise)["cka"]
+
+        assert torch.isnan(score)
+
+    def test_integer_inputs_are_promoted_to_float32(self, run_estimator):
+        """Integer inputs are accepted and produce a float32 score."""
+        ints = torch.arange(512, dtype=torch.int64).reshape(128, 4)
+
+        score = run_estimator(ints, ints.clone())["cka"]
+
+        assert score.dtype == torch.float32
+        assert torch.isclose(score, torch.tensor(1.0, dtype=torch.float32), atol=1e-6)
+
+    def test_determinism(self, run_estimator):
+        """Running twice on equal inputs gives equal scores."""
+        first = torch.randn(96, 10)
+        second = torch.randn(96, 7)
+
+        score_one = run_estimator(first.clone(), second.clone())["cka"]
+        score_two = run_estimator(first.clone(), second.clone())["cka"]
+
+        assert torch.allclose(score_one, score_two, equal_nan=True)
+
+    def test_repr(self):
+        """The repr mentions the class name, keys, kernel, and eps."""
+        text = repr(CkaEstimator())
+
+        expected_fragments = (
+            "CkaEstimator",
+            "in_keys=['data_a', 'data_b']",
+            "out_keys=['cka']",
+            "kernel='linear'",
+            "eps=",
+        )
+        for fragment in expected_fragments:
+            assert fragment in text
+
+    def test_unknown_kernel_raises(self):
+        """Requesting an unimplemented kernel fails at construction time."""
+        with pytest.raises(NotImplementedError, match="Only 'linear' is implemented"):
+            CkaEstimator(kernel="rbf")