diff --git a/.gitignore b/.gitignore index ae99fd0..482b400 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/source/api/ # PyBuilder target/ diff --git a/docs/source/conf.py b/docs/source/conf.py index 157ea9f..b1f3539 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -112,5 +112,5 @@ # Autoapi autoapi_dirs = ["../../src"] autoapi_root = "api" -autoapi_keep_files = False +autoapi_keep_files = True autodoc_typehints = "description" diff --git a/docs/source/index.rst b/docs/source/index.rst index 1bd27fa..03a292b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,7 +2,7 @@ :sd_hide_title: `tdhook` -===== +======== .. toctree:: :maxdepth: 1 diff --git a/docs/source/methods.rst b/docs/source/methods.rst index 4f59a42..7e800ad 100644 --- a/docs/source/methods.rst +++ b/docs/source/methods.rst @@ -105,6 +105,23 @@ Methods + .. grid-item-card:: + :link: notebooks/methods/representation-similarity.ipynb + :class-card: surface + :class-body: surface + + .. raw:: html + +
+
+ +
+
+
Representation Similarity
+

Compare latent representations with CKA and leave room for additional similarity metrics.

+
+
+ .. toctree:: :hidden: :maxdepth: 2 @@ -114,3 +131,4 @@ Methods notebooks/methods/linear-probing.ipynb notebooks/methods/bilinear-probing.ipynb notebooks/methods/dimension-estimation.ipynb + notebooks/methods/representation-similarity.ipynb diff --git a/docs/source/notebooks/methods/representation-similarity.ipynb b/docs/source/notebooks/methods/representation-similarity.ipynb new file mode 100644 index 0000000..eb60f09 --- /dev/null +++ b/docs/source/notebooks/methods/representation-similarity.ipynb @@ -0,0 +1,197 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Representation Similarity\n", + "\n", + "This notebook introduces representation similarity methods in `tdhook`.\n", + "\n", + "It currently starts with centered kernel alignment (CKA) through `tdhook.latent.representation_similarity.CkaEstimator`. More similarity methods can be added here later as the module grows." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import importlib.util\n", + "\n", + "DEV = True\n", + "\n", + "if importlib.util.find_spec(\"google.colab\") is not None:\n", + " MODE = \"colab-dev\" if DEV else \"colab\"\n", + "else:\n", + " MODE = \"local\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "if MODE == \"colab\":\n", + " %pip install -q tdhook\n", + "elif MODE == \"colab-dev\":\n", + " !rm -rf tdhook\n", + " !git clone https://github.com/Xmaster6y/tdhook -b main\n", + " %pip install -q ./tdhook" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from tensordict import TensorDict\n", + "\n", + "from tdhook.latent.representation_similarity import CkaEstimator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Synthetic Example\n", + "\n", + "We build a few pairs of representations with known relationships:\n", + "\n", + "- `same`: identical representations, so CKA should be close to `1`\n", + "- `rotated`: an orthogonal transform of the same representation, which linear CKA should also score near `1`\n", + "- `random`: an unrelated representation, which should typically score much lower" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "\n", + "x = torch.randn(256, 32)\n", + "q, _ = torch.linalg.qr(torch.randn(32, 32))\n", + "\n", + "examples = {\n", + " \"same\": (x, x.clone()),\n", + " \"rotated\": (x, x @ q),\n", + " \"random\": (x, torch.randn(256, 24)),\n", + "}\n", + "\n", + "estimator = CkaEstimator(kernel=\"linear\")\n", + "\n", + "\n", + "def run_cka(x, y):\n", + " td = TensorDict({\"data_a\": x, \"data_b\": y}, batch_size=[])\n", + " return estimator(td.clone())[\"cka\"].item()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'same': 1.0, 'rotated': 1.0, 'random': 0.10095701366662979}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores = {name: run_cka(a, b) for name, (a, b) in examples.items()}\n", + "scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Batched Inputs\n", + "\n", + "Like the dimension-estimation modules, `CkaEstimator` accepts either `(N, D)` or batched `(..., N, D)` inputs and returns one scalar score per batch item." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.9979, 0.9978, 0.9976])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batched_x = torch.randn(3, 128, 16)\n", + "batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)\n", + "\n", + "td = TensorDict({\"data_a\": batched_x, \"data_b\": batched_y}, batch_size=[3])\n", + "CkaEstimator()(td.clone())[\"cka\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API Notes\n", + "\n", + "- The estimator is named `CkaEstimator` and already exposes a `kernel` argument.\n", + "- At the moment only `kernel=\"linear\"` is implemented.\n", + "- Degenerate inputs with zero variance return `nan` instead of raising.\n", + "\n", + "Future methods can extend this notebook with additional sections, comparisons, and visualizations." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/tdhook/attribution/gradient_helpers/helpers.py b/src/tdhook/attribution/gradient_helpers/helpers.py index 85b18f1..101218f 100644 --- a/src/tdhook/attribution/gradient_helpers/helpers.py +++ b/src/tdhook/attribution/gradient_helpers/helpers.py @@ -13,6 +13,8 @@ class Riemann(Enum): + """Supported Riemann integration variants for gradient approximation.""" + left = 1 right = 2 middle = 3 @@ -42,7 +44,7 @@ def approximation_parameters( return riemann_builders(method=Riemann[method.split("_")[-1]]) if method == "gausslegendre": return gauss_legendre_builders() - raise ValueError("Invalid integral approximation method name: {}".format(method)) + raise ValueError(f"Invalid integral approximation method name: {method}") def riemann_builders( diff --git a/src/tdhook/attribution/lrp_helpers/rules.py b/src/tdhook/attribution/lrp_helpers/rules.py index d75b53e..f53fd2c 100644 --- a/src/tdhook/attribution/lrp_helpers/rules.py +++ b/src/tdhook/attribution/lrp_helpers/rules.py @@ -95,6 +95,13 @@ class AbstractFunctionMeta(ABCMeta, FunctionMeta): class Rule(Function, metaclass=AbstractFunctionMeta): + """Base class for LRP rules implemented as custom autograd functions. + + Subclasses override `forward()` and `backward()` to define how relevance is + propagated through a wrapped module. Instances are registered onto modules + by temporarily replacing `module.forward` with `Function.apply(...)`. + """ + def __init__(self): self._apply_kwargs = {} # TODO: Add zero_params argument for all rules @@ -116,11 +123,13 @@ def unregister(self, module: nn.Module): @staticmethod @abstractmethod def forward(ctx, apply_kwargs, module, model_kwargs, *inputs): + """Run the wrapped module and save any tensors needed for relevance propagation.""" pass @staticmethod @abstractmethod def backward(ctx, *out_relevance): + """Propagate output relevance back to the inputs of the wrapped module.""" pass @@ -454,9 +463,7 @@ def _call(self, name: str, module: nn.Module) -> Rule | None: def __call__(self, name: str, module: nn.Module) -> Rule | None: rule = self._rule_mapper(name, module) - if rule is None: - return self._call(name, module) - return rule + return self._call(name, module) if rule is None else rule class EpsilonPlus(BaseRuleMapper): diff --git a/src/tdhook/latent/__init__.py b/src/tdhook/latent/__init__.py index e4e2285..65b3313 100644 --- a/src/tdhook/latent/__init__.py +++ b/src/tdhook/latent/__init__.py @@ -15,6 +15,7 @@ Probe, ProbeManager, ) +from .representation_similarity import CkaEstimator from .steering_vectors import SteeringVectors, ActivationAddition __all__ = [ @@ -23,6 +24,7 @@ "ActivationPatching", "BilinearProbe", "BilinearProbeManager", + "CkaEstimator", "LinearEstimator", "LowRankBilinearEstimator", "LocalKnnDimensionEstimator", diff --git a/src/tdhook/latent/representation_similarity.py b/src/tdhook/latent/representation_similarity.py deleted file mode 100644 index b9f8044..0000000 --- a/src/tdhook/latent/representation_similarity.py +++ /dev/null @@ -1,3 +0,0 @@ -# CKA: Similarity of Neural Network Representations Revisited -# Mutual k-Nearest Neighbor Alignment Metric: The Platonic Representation Hypothesis -# Information Imbalance: A quantitative analysis of semantic information in deep representations of text and images diff --git a/src/tdhook/latent/representation_similarity/__init__.py b/src/tdhook/latent/representation_similarity/__init__.py new file mode 100644 index 0000000..f0c18e3 --- /dev/null +++ b/src/tdhook/latent/representation_similarity/__init__.py @@ -0,0 +1,7 @@ +""" +Representation similarity methods. +""" + +from .cka import CkaEstimator + +__all__ = ["CkaEstimator"] diff --git a/src/tdhook/latent/representation_similarity/cka.py b/src/tdhook/latent/representation_similarity/cka.py new file mode 100644 index 0000000..3991fc5 --- /dev/null +++ b/src/tdhook/latent/representation_similarity/cka.py @@ -0,0 +1,95 @@ +from textwrap import indent + +import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModuleBase + + +class CkaEstimator(TensorDictModuleBase): + """ + Centered kernel alignment (CKA) between two representations. + + Reads two data tensors from the input TensorDict. Expects `(N, D)` or + `(..., N, D)` for both tensors, with shared batch shape and sample count. + Outputs one scalar similarity value per batch item. + """ + + def __init__( + self, + in_key_a: str = "data_a", + in_key_b: str = "data_b", + out_key: str = "cka", + kernel: str = "linear", + eps: float = 1e-12, + ): + super().__init__() + if kernel != "linear": + raise NotImplementedError(f"Unsupported kernel '{kernel}'. Only 'linear' is implemented for now.") + self.in_key_a = in_key_a + self.in_key_b = in_key_b + self.out_key = out_key + self.kernel = kernel + self.eps = eps + self.in_keys = [in_key_a, in_key_b] + self.out_keys = [out_key] + + def forward(self, td: TensorDict) -> TensorDict: + x = td[self.in_key_a] + y = td[self.in_key_b] + _validate_inputs(x, y) + + batch_shape = x.shape[:-2] + n = x.shape[-2] + flat_x = x.reshape(-1, n, x.shape[-1]) + flat_y = y.reshape(-1, n, y.shape[-1]) + if flat_x.shape[0] == 0: + td[self.out_key] = torch.empty(batch_shape, dtype=torch.float32, device=x.device) + return td + cka_values = [_linear_cka(flat_x[i], flat_y[i], eps=self.eps) for i in range(flat_x.shape[0])] + td[self.out_key] = torch.stack(cka_values).reshape(batch_shape) + return td + + def __repr__(self): + fields = indent( + f"in_keys={self.in_keys},\nout_keys={self.out_keys},\nkernel='{self.kernel}',\neps={self.eps}", + 4 * " ", + ) + return f"{type(self).__name__}(\n{fields})" + + +def _validate_inputs(x: torch.Tensor, y: torch.Tensor) -> None: + if x.ndim < 2 or y.ndim < 2: + raise ValueError("CKA expects tensors with shape (N, D) or (..., N, D)") + if x.shape[:-2] != y.shape[:-2]: + raise ValueError(f"Expected matching batch shapes, got {x.shape[:-2]} and {y.shape[:-2]}") + if x.shape[-2] != y.shape[-2]: + raise ValueError(f"Expected matching sample counts, got {x.shape[-2]} and {y.shape[-2]}") + if x.device != y.device: + raise ValueError(f"Expected both tensors on the same device, got {x.device} and {y.device}") + + +def _linear_cka(x: torch.Tensor, y: torch.Tensor, eps: float) -> torch.Tensor: + dtype = torch.promote_types(x.dtype, y.dtype) + if not torch.empty((), dtype=dtype).is_floating_point(): + dtype = torch.float32 + + x = x.to(dtype=dtype) + y = y.to(dtype=dtype) + x = x - x.mean(dim=0, keepdim=True) + y = y - y.mean(dim=0, keepdim=True) + + cross_cov = x.transpose(-1, -2) @ y + x_cov = x.transpose(-1, -2) @ x + y_cov = y.transpose(-1, -2) @ y + + numerator = torch.sum(cross_cov.square()) + x_norm = torch.sum(x_cov.square()) + y_norm = torch.sum(y_cov.square()) + denominator = torch.sqrt(x_norm * y_norm) + + nan = torch.full((), float("nan"), dtype=dtype, device=x.device) + if not torch.isfinite(denominator) or denominator <= eps: + return nan + + value = numerator / denominator + return value.float() if torch.isfinite(value) else nan diff --git a/tests/latent/test_representation_similarity.py b/tests/latent/test_representation_similarity.py new file mode 100644 index 0000000..ec1bbb8 --- /dev/null +++ b/tests/latent/test_representation_similarity.py @@ -0,0 +1,171 @@ +""" +Tests for representation similarity estimators. +""" + +import pytest +import torch +from tensordict import TensorDict + +from tdhook.latent.representation_similarity import CkaEstimator + + +def make_td(x, y, in_key_a="data_a", in_key_b="data_b", batch_size=None): + if batch_size is None: + batch_size = [] if x.ndim == 2 else x.shape[:-2] + return TensorDict({in_key_a: x, in_key_b: y}, batch_size=batch_size) + + +def make_random_pair(n=64, d_x=10, d_y=7): + return torch.randn(n, d_x), torch.randn(n, d_y) + + +@pytest.fixture +def run_estimator(): + torch.manual_seed(42) + + def _run(x, y, in_key_a="data_a", in_key_b="data_b", batch_size=None, **estimator_kwargs): + td = make_td(x, y, in_key_a=in_key_a, in_key_b=in_key_b, batch_size=batch_size) + return CkaEstimator(in_key_a=in_key_a, in_key_b=in_key_b, **estimator_kwargs)(td) + + return _run + + +class TestCkaEstimator: + def test_default_keys(self, run_estimator): + x, y = make_random_pair() + + result = run_estimator(x, y) + + assert "cka" in result + assert result["cka"].ndim == 0 + assert result["cka"].dtype in (torch.float32, torch.float64) + assert torch.isfinite(result["cka"]) + + def test_custom_keys(self, run_estimator): + x = torch.randn(48, 8) + y = torch.randn(48, 6) + + result = run_estimator(x, y, in_key_a="linear1", in_key_b="linear2", out_key="similarity") + + assert "linear1" in result + assert "linear2" in result + assert "similarity" in result + assert result["similarity"].ndim == 0 + + def test_identical_views_are_one(self, run_estimator): + x = torch.randn(128, 16) + + result = run_estimator(x, x.clone()) + + assert torch.isclose(result["cka"], torch.tensor(1.0), atol=1e-6) + + def test_invariant_to_isotropic_scaling(self, run_estimator): + x = torch.randn(128, 16) + y = 7.5 * x + + result = run_estimator(x, y) + + assert torch.isclose(result["cka"], torch.tensor(1.0), atol=1e-5) + + def test_invariant_to_orthogonal_rotation(self, run_estimator): + x = torch.randn(128, 12) + q, _ = torch.linalg.qr(torch.randn(12, 12)) + y = x @ q + + result = run_estimator(x, y) + + assert torch.isclose(result["cka"], torch.tensor(1.0), atol=1e-5) + + def test_independent_random_views_have_low_cka(self, run_estimator): + x = torch.randn(512, 32) + y = torch.randn(512, 24) + + result = run_estimator(x, y) + + assert result["cka"].item() < 0.2 + + @pytest.mark.parametrize( + ("x_shape", "y_shape"), + [ + ((1, 10, 8), (1, 10, 6)), + ((5, 10, 8), (5, 10, 6)), + ((2, 3, 10, 8), (2, 3, 10, 6)), + ], + ids=["1x10", "5x10", "2x3x10"], + ) + def test_batch_shape_preservation(self, run_estimator, x_shape, y_shape): + x = torch.randn(*x_shape) + y = torch.randn(*y_shape) + batch_size = x_shape[:-2] + + result = run_estimator(x, y, batch_size=batch_size) + + assert result["cka"].shape == batch_size + + def test_empty_flattened_batch_returns_empty_output(self, run_estimator): + x = torch.randn(2, 0, 10, 8) + y = torch.randn(2, 0, 10, 6) + + result = run_estimator(x, y, batch_size=[2, 0]) + + assert result["cka"].shape == (2, 0) + assert result["cka"].dtype == torch.float32 + assert result["cka"].numel() == 0 + + def test_mismatched_sample_counts_raise(self, run_estimator): + with pytest.raises(ValueError, match="matching sample counts"): + run_estimator(torch.randn(32, 8), torch.randn(31, 6)) + + def test_mismatched_batch_shapes_raise(self, run_estimator): + with pytest.raises(ValueError, match="matching batch shapes"): + run_estimator(torch.randn(2, 3, 16, 8), torch.randn(2, 4, 16, 6), batch_size=[2]) + + def test_invalid_rank_raises(self, run_estimator): + with pytest.raises(ValueError, match=r"shape \(N, D\) or \(\.\.\., N, D\)"): + run_estimator(torch.randn(32), torch.randn(32)) + + def test_mismatched_devices_raise(self, run_estimator): + x = torch.randn(32, 8) + y = torch.randn(32, 6, device="meta") + + with pytest.raises(ValueError, match="same device"): + run_estimator(x, y) + + def test_constant_representation_returns_nan(self, run_estimator): + x = torch.ones(64, 8) + y = torch.randn(64, 6) + + result = run_estimator(x, y) + + assert torch.isnan(result["cka"]) + + def test_integer_inputs_are_promoted_to_float32(self, run_estimator): + x = torch.arange(512, dtype=torch.int64).reshape(128, 4) + + result = run_estimator(x, x.clone()) + + assert result["cka"].dtype == torch.float32 + assert torch.isclose(result["cka"], torch.tensor(1.0, dtype=torch.float32), atol=1e-6) + + def test_determinism(self, run_estimator): + x = torch.randn(96, 10) + y = torch.randn(96, 7) + + r1 = run_estimator(x.clone(), y.clone())["cka"] + r2 = run_estimator(x.clone(), y.clone())["cka"] + + assert torch.allclose(r1, r2, equal_nan=True) + + def test_repr(self): + est = CkaEstimator() + r = repr(est) + + assert "CkaEstimator" in r + assert "in_keys=['data_a', 'data_b']" in r + assert "out_keys=['cka']" in r + assert "kernel='linear'" in r + assert "eps=" in r + + def test_unknown_kernel_raises(self): + with pytest.raises(NotImplementedError, match="Only 'linear' is implemented"): + CkaEstimator(kernel="rbf")