Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/source/api/

# PyBuilder
target/
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,5 @@
# Autoapi
autoapi_dirs = ["../../src"]
autoapi_root = "api"
autoapi_keep_files = False
autoapi_keep_files = True
autodoc_typehints = "description"
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
:sd_hide_title:

`tdhook`
=====
========

.. toctree::
:maxdepth: 1
Expand Down
18 changes: 18 additions & 0 deletions docs/source/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,23 @@ Methods
</div>
</div>

.. grid-item-card::
:link: notebooks/methods/representation-similarity.ipynb
:class-card: surface
:class-body: surface

.. raw:: html

<div class="d-flex align-items-center">
<div class="d-flex justify-content-center" style="min-width: 50px; margin-right: 20px; height: 100%;">
<i class="fa-solid fa-code-compare fa-2x"></i>
</div>
<div>
<h5 class="card-title">Representation Similarity</h5>
<p class="card-text">Compare latent representations with CKA, with room for additional similarity metrics as the module grows.</p>
</div>
</div>

.. toctree::
:hidden:
:maxdepth: 2
Expand All @@ -114,3 +131,4 @@ Methods
notebooks/methods/linear-probing.ipynb
notebooks/methods/bilinear-probing.ipynb
notebooks/methods/dimension-estimation.ipynb
notebooks/methods/representation-similarity.ipynb
197 changes: 197 additions & 0 deletions docs/source/notebooks/methods/representation-similarity.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Representation Similarity\n",
"\n",
"This notebook introduces representation similarity methods in `tdhook`.\n",
"\n",
"It currently starts with centered kernel alignment (CKA) through `tdhook.latent.representation_similarity.CkaEstimator`. More similarity methods can be added here later as the module grows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import importlib.util\n",
"\n",
"DEV = True\n",
Comment thread
Xmaster6y marked this conversation as resolved.
"\n",
"if importlib.util.find_spec(\"google.colab\") is not None:\n",
" MODE = \"colab-dev\" if DEV else \"colab\"\n",
"else:\n",
" MODE = \"local\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"if MODE == \"colab\":\n",
" %pip install -q tdhook\n",
"elif MODE == \"colab-dev\":\n",
" !rm -rf tdhook\n",
" !git clone https://github.com/Xmaster6y/tdhook -b main\n",
" %pip install -q ./tdhook"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from tensordict import TensorDict\n",
"\n",
"from tdhook.latent.representation_similarity import CkaEstimator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Synthetic Example\n",
"\n",
"We build a few pairs of representations with known relationships:\n",
"\n",
"- `same`: identical representations, so CKA should be close to `1`\n",
"- `rotated`: an orthogonal transform of the same representation, which linear CKA should also score near `1`\n",
"- `random`: an unrelated representation, which should typically score much lower"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(0)\n",
"\n",
"x = torch.randn(256, 32)\n",
"q, _ = torch.linalg.qr(torch.randn(32, 32))\n",
"\n",
"examples = {\n",
" \"same\": (x, x.clone()),\n",
" \"rotated\": (x, x @ q),\n",
" \"random\": (x, torch.randn(256, 24)),\n",
"}\n",
"\n",
"estimator = CkaEstimator(kernel=\"linear\")\n",
"\n",
"\n",
"def run_cka(x, y):\n",
" td = TensorDict({\"data_a\": x, \"data_b\": y}, batch_size=[])\n",
" return estimator(td.clone())[\"cka\"].item()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'same': 1.0, 'rotated': 1.0, 'random': 0.10095701366662979}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scores = {name: run_cka(a, b) for name, (a, b) in examples.items()}\n",
"scores"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Batched Inputs\n",
"\n",
"Like the dimension-estimation modules, `CkaEstimator` accepts either `(N, D)` or batched `(..., N, D)` inputs and returns one scalar score per batch item."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.9979, 0.9978, 0.9976])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batched_x = torch.randn(3, 128, 16)\n",
"batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)\n",
"\n",
"td = TensorDict({\"data_a\": batched_x, \"data_b\": batched_y}, batch_size=[3])\n",
"CkaEstimator()(td.clone())[\"cka\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API Notes\n",
"\n",
"- The estimator is named `CkaEstimator` and already exposes a `kernel` argument.\n",
"- At the moment only `kernel=\"linear\"` is implemented.\n",
"- Degenerate inputs with zero variance return `nan` instead of raising.\n",
"\n",
"Future methods can extend this notebook with additional sections, comparisons, and visualizations."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
4 changes: 3 additions & 1 deletion src/tdhook/attribution/gradient_helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class Riemann(Enum):
"""Supported Riemann integration variants for gradient approximation."""

left = 1
right = 2
middle = 3
Expand Down Expand Up @@ -42,7 +44,7 @@ def approximation_parameters(
return riemann_builders(method=Riemann[method.split("_")[-1]])
if method == "gausslegendre":
return gauss_legendre_builders()
raise ValueError("Invalid integral approximation method name: {}".format(method))
raise ValueError(f"Invalid integral approximation method name: {method}")


def riemann_builders(
Expand Down
13 changes: 10 additions & 3 deletions src/tdhook/attribution/lrp_helpers/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ class AbstractFunctionMeta(ABCMeta, FunctionMeta):


class Rule(Function, metaclass=AbstractFunctionMeta):
"""Base class for LRP rules implemented as custom autograd functions.

Subclasses override `forward()` and `backward()` to define how relevance is
propagated through a wrapped module. Instances are registered onto modules
by temporarily replacing `module.forward` with `Function.apply(...)`.
"""

def __init__(self):
self._apply_kwargs = {}
# TODO: Add zero_params argument for all rules
Expand All @@ -116,11 +123,13 @@ def unregister(self, module: nn.Module):
@staticmethod
@abstractmethod
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
"""Run the wrapped module and save any tensors needed for relevance propagation."""
pass

@staticmethod
@abstractmethod
def backward(ctx, *out_relevance):
"""Propagate output relevance back to the inputs of the wrapped module."""
pass


Expand Down Expand Up @@ -454,9 +463,7 @@ def _call(self, name: str, module: nn.Module) -> Rule | None:

def __call__(self, name: str, module: nn.Module) -> Rule | None:
rule = self._rule_mapper(name, module)
if rule is None:
return self._call(name, module)
return rule
return self._call(name, module) if rule is None else rule


class EpsilonPlus(BaseRuleMapper):
Expand Down
2 changes: 2 additions & 0 deletions src/tdhook/latent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Probe,
ProbeManager,
)
from .representation_similarity import CkaEstimator
from .steering_vectors import SteeringVectors, ActivationAddition

__all__ = [
Expand All @@ -23,6 +24,7 @@
"ActivationPatching",
"BilinearProbe",
"BilinearProbeManager",
"CkaEstimator",
"LinearEstimator",
"LowRankBilinearEstimator",
"LocalKnnDimensionEstimator",
Expand Down
3 changes: 0 additions & 3 deletions src/tdhook/latent/representation_similarity.py

This file was deleted.

7 changes: 7 additions & 0 deletions src/tdhook/latent/representation_similarity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Representation similarity methods.
"""

from .cka import CkaEstimator

__all__ = ["CkaEstimator"]
Loading
Loading