From 947a11dd2f0914f4e0d64e48f92d10d1acc05870 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 10 Jun 2026 18:48:50 +0000 Subject: [PATCH 1/3] test(foundry): add instantiators unit tests and bring under strict mypy instantiators is already fully annotated, so this adds it to the per-module strict mypy override (a lock-in for future defs) and fills the test gap with tests/test_instantiators.py covering the callback/logger instantiation control flow. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 1 + tests/test_instantiators.py | 81 +++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 tests/test_instantiators.py diff --git a/pyproject.toml b/pyproject.toml index e9ff2c44..daa46b50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -274,6 +274,7 @@ module = [ "foundry.utils.alignment", "foundry.utils.weights", "foundry.utils.rotation_augmentation", + "foundry.utils.instantiators", ] disallow_untyped_defs = true check_untyped_defs = true diff --git a/tests/test_instantiators.py b/tests/test_instantiators.py new file mode 100644 index 00000000..59c378fd --- /dev/null +++ b/tests/test_instantiators.py @@ -0,0 +1,81 @@ +"""Unit tests for foundry.utils.instantiators. + +These helpers turn a hydra config group into a list of instantiated objects. +The contract worth pinning is the control flow, not the object type: a missing / +empty config yields an empty list, each sub-config is instantiated via its +``_target_``, and a sub-config that is not an instantiable ``DictConfig`` (no +``_target_`` key) raises ``InstantiationError``. The functions do not themselves +check that the result is a callback / logger, so the tests use a lightweight +stdlib target (``types.SimpleNamespace``) to exercise that flow directly. +""" + +from types import SimpleNamespace + +import pytest +from omegaconf import OmegaConf + +from foundry.utils.instantiators import ( + InstantiationError, + _can_be_instantiated, + instantiate_callbacks, + instantiate_loggers, +) + +_TARGET = "types.SimpleNamespace" + + +def test_can_be_instantiated_true_with_target(): + assert _can_be_instantiated(OmegaConf.create({"_target_": _TARGET})) is True + + +def test_can_be_instantiated_false_without_target(): + assert _can_be_instantiated(OmegaConf.create({"x": 1})) is False + + +def test_can_be_instantiated_false_for_non_dictconfig(): + """A plain dict is not a DictConfig, so it is not instantiable.""" + assert _can_be_instantiated({"_target_": _TARGET}) is False + + +def test_instantiate_callbacks_none_returns_empty(): + assert instantiate_callbacks(None) == [] + + +def test_instantiate_callbacks_empty_config_returns_empty(): + assert instantiate_callbacks(OmegaConf.create({})) == [] + + +def test_instantiate_callbacks_builds_each_target_in_order(): + cfg = OmegaConf.create( + { + "first": {"_target_": _TARGET, "x": 1}, + "second": {"_target_": _TARGET, "x": 2}, + } + ) + result = instantiate_callbacks(cfg) + assert result == [SimpleNamespace(x=1), SimpleNamespace(x=2)] + + +def test_instantiate_callbacks_raises_on_missing_target(): + cfg = OmegaConf.create({"bad": {"x": 1}}) + with pytest.raises(InstantiationError): + instantiate_callbacks(cfg) + + +def test_instantiate_loggers_none_returns_empty(): + assert instantiate_loggers(None) == [] + + +def test_instantiate_loggers_builds_target(): + cfg = OmegaConf.create({"logger": {"_target_": _TARGET, "name": "run"}}) + assert instantiate_loggers(cfg) == [SimpleNamespace(name="run")] + + +def test_instantiate_loggers_raises_on_missing_target(): + cfg = OmegaConf.create({"bad": {"name": "run"}}) + with pytest.raises(InstantiationError): + instantiate_loggers(cfg) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) From 198466fd78e69e3f9025a7fdef813ce6050caf2b Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 10 Jun 2026 22:21:34 +0000 Subject: [PATCH 2/3] chore(foundry): bring ddp, squashfs, logging under strict mypy and test logging Adds the three remaining small foundry.utils modules to the per-module strict mypy override. squashfs was already fully annotated (pure lock-in); ddp needed one annotation (RankedLogger.log varargs); logging needed five. New tests/test_logging.py covers the two pure helpers (CachedDataFilter.filter and condense_count_columns_of_grouped_df). Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 3 ++ src/foundry/utils/ddp.py | 2 +- src/foundry/utils/logging.py | 12 +++--- tests/test_logging.py | 76 ++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 7 deletions(-) create mode 100644 tests/test_logging.py diff --git a/pyproject.toml b/pyproject.toml index daa46b50..0d7fd757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -275,6 +275,9 @@ module = [ "foundry.utils.weights", "foundry.utils.rotation_augmentation", "foundry.utils.instantiators", + "foundry.utils.ddp", + "foundry.utils.squashfs", + "foundry.utils.logging", ] disallow_untyped_defs = true check_untyped_defs = true diff --git a/src/foundry/utils/ddp.py b/src/foundry/utils/ddp.py index fef47bd3..0dcbe16f 100644 --- a/src/foundry/utils/ddp.py +++ b/src/foundry/utils/ddp.py @@ -79,7 +79,7 @@ def __init__( self.rank_zero_only = rank_zero_only def log( # type: ignore[override] # deliberately extends LoggerAdapter.log with a `rank` parameter - self, level: int, msg: str, rank: int | None = None, *args, **kwargs + self, level: int, msg: str, rank: int | None = None, *args: Any, **kwargs: Any ) -> None: """ Delegate a log call to the underlying logger, after prefixing its message with the rank diff --git a/src/foundry/utils/logging.py b/src/foundry/utils/logging.py index 95b580e3..e6a8b3bc 100755 --- a/src/foundry/utils/logging.py +++ b/src/foundry/utils/logging.py @@ -3,7 +3,7 @@ from contextlib import contextmanager import pandas as pd -from beartype.typing import Any +from beartype.typing import Any, Iterator from lightning.fabric.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf from rich.console import Console @@ -20,14 +20,14 @@ class CachedDataFilter(logging.Filter): """Filter to suppress atomworks cached data logging messages.""" - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: # Filter out "Cached data not found" messages if "Cached data not found" in record.getMessage(): return False return True -def silence_warnings(): +def silence_warnings() -> None: """Silence common warnings that appear during foundry execution.""" warnings.filterwarnings( "ignore", message="All-NaN slice encountered", category=RuntimeWarning @@ -67,7 +67,7 @@ def silence_warnings(): @contextmanager -def suppress_warnings(is_inference: bool = False): +def suppress_warnings(is_inference: bool = False) -> Iterator[None]: """Context manager to suppress specific warnings within its scope. Args: @@ -178,7 +178,7 @@ def print_model_parameters(model: nn.Module, name: str = "") -> None: def log_hyperparameters_with_all_loggers( trainer: Any, cfg: dict | DictConfig, model: Any -): +) -> None: """Logs hyperparameters using all loggers in the trainer. Args: @@ -260,7 +260,7 @@ def table_from_df(df: pd.DataFrame, title: str) -> Table: return table -def safe_print(obj: Any, console_width=100, logger: Any | None = None) -> None: +def safe_print(obj: Any, console_width: int = 100, logger: Any | None = None) -> None: """Print a Rich object in a console- and logger-safe manner.""" console = Console(force_terminal=False, color_system=None, width=console_width) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..e81285d4 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,76 @@ +"""Unit tests for the pure helpers in foundry.utils.logging. + +The display/configuration functions in this module are side-effecting glue +(Rich console output, warning filters, logger levels). The two pieces with a +non-obvious, environment-independent contract are pinned here: + +- ``CachedDataFilter`` suppresses a specific atomworks log line by substring. +- ``condense_count_columns_of_grouped_df`` collapses the repeated per-metric + ``count`` columns of a grouped (MultiIndex-column) DataFrame into one ``Count`` + column — but only when the count is identical across metrics in every row, + and only for a MultiIndex frame with both ``count`` and ``mean`` sub-levels. +""" + +import logging + +import pandas as pd +import pytest + +from foundry.utils.logging import ( + CachedDataFilter, + condense_count_columns_of_grouped_df, +) + + +def _record(msg: str) -> logging.LogRecord: + return logging.LogRecord("test", logging.INFO, __file__, 1, msg, None, None) + + +def test_cached_data_filter_suppresses_cached_data_message(): + assert ( + CachedDataFilter().filter(_record("Cached data not found at /tmp/x")) is False + ) + + +def test_cached_data_filter_keeps_unrelated_message(): + assert CachedDataFilter().filter(_record("Loaded 12 structures")) is True + + +def _grouped(rows: list[list[float]]) -> pd.DataFrame: + """Frame with MultiIndex columns (metric, {count,mean}) for two metrics.""" + cols = pd.MultiIndex.from_tuples( + [("a", "count"), ("a", "mean"), ("b", "count"), ("b", "mean")] + ) + return pd.DataFrame(rows, columns=cols) + + +def test_condense_returns_non_multiindex_frame_unchanged(): + df = pd.DataFrame({"x": [1, 2], "y": [3, 4]}) + assert condense_count_columns_of_grouped_df(df) is df + + +def test_condense_collapses_consistent_counts(): + df = _grouped([[5, 1.0, 5, 2.0], [3, 0.5, 3, 1.5]]) + result = condense_count_columns_of_grouped_df(df) + + assert list(result.columns) == ["a (mean)", "b (mean)", "Count"] + assert result["Count"].tolist() == [5, 3] + assert result["a (mean)"].tolist() == [1.0, 0.5] + assert result["b (mean)"].tolist() == [2.0, 1.5] + + +def test_condense_leaves_frame_when_counts_disagree_within_a_row(): + """Row 0's metrics have counts 5 vs 6, so the frame is returned untouched.""" + df = _grouped([[5, 1.0, 6, 2.0]]) + assert condense_count_columns_of_grouped_df(df) is df + + +def test_condense_leaves_frame_without_a_count_sublevel(): + """MultiIndex columns lacking a 'count' level raise KeyError -> returned as-is.""" + cols = pd.MultiIndex.from_tuples([("a", "total"), ("a", "mean")]) + df = pd.DataFrame([[5, 1.0]], columns=cols) + assert condense_count_columns_of_grouped_df(df) is df + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) From 068a07a4fd5c0927477f50ffa6fb48b20836596e Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Fri, 12 Jun 2026 19:43:55 +0000 Subject: [PATCH 3/3] chore(rf3): remove dead-and-broken rf3.data.paired_msa module rf3.data.paired_msa no longer imports against the installed atomworks: MultiInputDatasetWrapper subclasses StructuralDatasetWrapper, which atomworks turned into a deprecated factory function, so subclassing it raises TypeError at import. The module was reachable only through domain_distillation.yaml, which is itself referenced only in commented-out lines of pdb_and_distillation.yaml, and its LoadPairedMSAs class was used nowhere. Remove the module, its orphaned domain_distillation.yaml config, the dangling commented references in pdb_and_distillation.yaml, and the stale pyproject mypy-exemption comment. rf3 now type-checks fully with no in-file suppressions. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- .../datasets/pdb_and_distillation.yaml | 3 - .../datasets/train/domain_distillation.yaml | 50 ---- models/rf3/src/rf3/data/paired_msa.py | 217 ------------------ pyproject.toml | 5 +- 4 files changed, 1 insertion(+), 274 deletions(-) delete mode 100644 models/rf3/configs/datasets/train/domain_distillation.yaml delete mode 100644 models/rf3/src/rf3/data/paired_msa.py diff --git a/models/rf3/configs/datasets/pdb_and_distillation.yaml b/models/rf3/configs/datasets/pdb_and_distillation.yaml index 505513cc..d85c46e3 100644 --- a/models/rf3/configs/datasets/pdb_and_distillation.yaml +++ b/models/rf3/configs/datasets/pdb_and_distillation.yaml @@ -9,7 +9,6 @@ defaults: - monomer_distillation - na_complex_distillation - disorder_distillation - # - domain_distillation # - rna_monomer_distillation - val/af3_validation@val.af3_validation - val/af3_validation@val.quick_af3_validation_with_templating @@ -28,8 +27,6 @@ train: probability: 0.02 disorder_distillation: probability: 0.02 - # multidomain_distillation: - # probability: 0.06 # rna_monomer_distillation: # probability: 0.04 diff --git a/models/rf3/configs/datasets/train/domain_distillation.yaml b/models/rf3/configs/datasets/train/domain_distillation.yaml deleted file mode 100644 index 73924aa6..00000000 --- a/models/rf3/configs/datasets/train/domain_distillation.yaml +++ /dev/null @@ -1,50 +0,0 @@ -# TODO: Inherit from common config with default Transform pipeline - -multidomain_distillation: - dataset: - _target_: rf3.data.paired_msa.MultiInputDatasetWrapper - save_failed_examples_to_dir: null - - # cif parser - cif_parser_args: - #assume_residues_all_resolved: true - cache_dir: null - load_from_cache: false - save_to_cache: false - - # metadata parser - dataset_parser: - _target_: rf3.data.paired_msa.MultidomainDFParser - - # metadata dataset - dataset: - _target_: atomworks.ml.datasets.PandasDataset - name: multidomain_distillation - id_column: example_id - data: /projects/ml/datahub/dfs/domain_domain/domain_domain_dataset.DIGS.parquet - columns_to_load: - - example_id - - pdb_path - - msa_path - transform: - _target_: ${datasets.pipeline_target} - is_inference: False - input_contains_explicit_msa: True - protein_msa_dirs: [] - rna_msa_dirs: [] - n_recycles: ${datasets.n_recycles_train} - crop_size: ${datasets.crop_size} - n_msa: ${datasets.n_msa} - diffusion_batch_size: ${datasets.diffusion_batch_size_train} - max_atoms_in_crop: ${datasets.max_atoms_in_crop} - crop_contiguous_probability: 0.25 - crop_spatial_probability: 0.75 - run_confidence_head: ${datasets.run_confidence_head} - take_first_chiral_subordering: ${datasets.take_first_chiral_subordering} - use_element_for_atom_names_of_atomized_tokens: ${datasets.use_element_for_atom_names_of_atomized_tokens} - mirror_prob: 0.0 - atomization_prob: ${datasets.atomization_prob} - ligand_dropout_prob: 0.0 - p_unconditional: ${datasets.p_unconditional} - p_dropout_atom_level_embeddings: ${datasets.p_dropout_atom_level_embeddings} - add_residue_is_paired_feature: ${datasets.add_residue_is_paired_feature} diff --git a/models/rf3/src/rf3/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py deleted file mode 100644 index 8e7b81c8..00000000 --- a/models/rf3/src/rf3/data/paired_msa.py +++ /dev/null @@ -1,217 +0,0 @@ -# mypy: ignore-errors -# -# This module does not type-check (and does not even import) against the installed -# atomworks: `MultiInputDatasetWrapper` below subclasses -# `atomworks.ml.datasets.StructuralDatasetWrapper`, which atomworks turned into a -# deprecated factory *function* — subclassing it raises `TypeError` at import time. -# Making it type-check requires a real refactor onto the `PandasDataset` API, validated -# on cluster data (see `.ai/roadmap.md`), not type annotations. The suppression lives -# here, in the file, rather than in `pyproject.toml` so it is visible to anyone reviving -# the module: when this file imports and type-checks cleanly again, delete this directive -# to restore mypy coverage (the module stays inside mypy's `files` scope). -import os -import socket -import time -from pathlib import Path -from typing import Any - -import numpy as np -from atomworks.common import exists -from atomworks.enums import ChainType -from atomworks.ml.datasets import StructuralDatasetWrapper, logger -from atomworks.ml.datasets.parsers import ( - MetadataRowParser, - load_example_from_metadata_row, -) -from atomworks.ml.transforms._checks import ( - check_contains_keys, - check_is_instance, - check_nonzero_length, -) -from atomworks.ml.transforms.base import Transform, TransformedDict -from atomworks.ml.transforms.msa._msa_loading_utils import load_msa_data_from_path -from atomworks.ml.utils.rng import capture_rng_states -from biotite.structure import AtomArray, concatenate - - -# input data wrapper that allows multiple input files separated by ':' -# data is loaded as concatentation of all inputs -class MultiInputDatasetWrapper(StructuralDatasetWrapper): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __getitem__(self, idx: int) -> Any: - # Capture example ID & current rng state (for reproducibility & debugging) - if hasattr(self, "idx_to_id"): - # ...if the dataset has a custom idx_to_id method, use it (e.g., for a PandasDataset) - example_id = self.idx_to_id(idx) - else: - # ...otherwise, fallback to a the `id_column` or a string representation of the index - example_id = ( - self.dataset[idx][self.id_column] if self.id_column else f"row_{idx}" - ) - - # Get process id and hostname (for debugging) - logger.debug( - f"({socket.gethostname()}:{os.getpid()}) Processing example ID: {example_id}" - ) - - # Load the row, using the __getitem__ method of the dataset - row = self.dataset[idx] - pdb_path = row["pdb_path"].split(":") - - # Process the row into a transform-ready dictionary with the given CIF and dataset parsers - # We require the "data" dictionary output from `load_example_from_metadata_row` to contain, at a minimum: - # (a) An "id" key, which uniquely identifies the example within the dataframe; and, - # (b) The "path" key, which is the path to the CIF file - _start_parse_time = time.time() - data = None - assert len(pdb_path) <= 2 - - for pdb_i in pdb_path: - row_i = {"example_id": row["example_id"], "path": pdb_i} - data_i = load_example_from_metadata_row( - row_i, self.dataset_parser, cif_parser_args=self.cif_parser_args - ) - - if data is None: - data = data_i - else: - data_i["atom_array"].pn_unit_id = np.full( - len(data_i["atom_array"]), "B_1" - ) # unique pn unit id - data_i["atom_array"].pn_unit_iid = np.full( - len(data_i["atom_array"]), "B_1" - ) # unique pn unit iid - data_i["atom_array"].chain_id = np.full( - len(data_i["atom_array"]), "B" - ) # unique chain id - data_i["atom_array"].chain_iid = np.full( - len(data_i["atom_array"]), "B" - ) # unique chain iid - data["atom_array"] = concatenate( - [data["atom_array"], data_i["atom_array"]] - ) - data["atom_array_stack"] = concatenate( - [data["atom_array_stack"], data_i["atom_array_stack"]] - ) - data["chain_info"]["B"] = data_i["chain_info"]["A"] - - # 'example_id', 'path', 'assembly_id', 'query_pn_unit_iids', - data["path"] = row["pdb_path"] - data["msa_path"] = Path(row["msa_path"]) # save msa - _stop_parse_time = time.time() - - # Manually add timing for cif-parsing - data = TransformedDict(data) - data.__transform_history__.append( - dict( - name="load_example_from_metadata_row", - instance=hex(id(load_example_from_metadata_row)), - start_time=_start_parse_time, - end_time=_stop_parse_time, - processing_time=_stop_parse_time - _start_parse_time, - ) - ) - - # Apply the transformation pipeline to the data - if exists(self.transform): - try: - rng_state_dict = capture_rng_states(include_cuda=False) - data = self.transform(data) - except KeyboardInterrupt as e: - raise e - except Exception as e: - # Log the error and save the failed example to disk (optional) - logger.info(f"Error processing row {idx} ({example_id}): {e}") - - if exists(self.save_failed_examples_to_dir): - save_failed_example_to_disk( - example_id=example_id, - error_msg=e, - rng_state_dict=rng_state_dict, - data={}, # We do not save the data, since it may be large. - fail_dir=self.save_failed_examples_to_dir, - ) - raise e - - return data - - -class MultidomainDFParser(MetadataRowParser): - """Parser for Qian's multidomain data""" - - def __init__( - self, - example_id_colname: str = "example_id", - path_colname: str = "path", - ): - self.example_id_colname = example_id_colname - self.path_colname = path_colname - - def _parse(self, row: dict) -> dict[str, Any]: - query_pn_unit_iids = None - assembly_id = "1" - - return { - "example_id": row[self.example_id_colname], - "path": Path(row[self.path_colname]), - "assembly_id": assembly_id, - "query_pn_unit_iids": query_pn_unit_iids, - "extra_info": row, - } - - -class LoadPairedMSAs(Transform): - """ - LoadPairedMSAs adds paired MSAs from disk, overwriting previously paired MSA data. - """ - - def check_input(self, data: dict[str, Any]): - check_contains_keys(data, ["atom_array", "msa_path"]) - check_is_instance(data, "atom_array", AtomArray) - check_nonzero_length(data, "atom_array") - - def forward(self, data: dict[str, Any]) -> dict[str, Any]: - atom_array = data["atom_array"] - msa_file_path = data["msa_path"] - chain_type = data["chain_info"]["A"]["chain_type"] - max_msa_sequences = 10000 - - msa_data = load_msa_data_from_path( - msa_file_path=msa_file_path, - chain_type=chain_type, - max_msa_sequences=max_msa_sequences, - ) - - # split into chains - start_idx = 0 - allpolymerchains = np.unique( - atom_array.chain_id[ - np.isin(atom_array.chain_type, ChainType.get_polymers()) - ] - ) - - data["polymer_msas_by_chain_id"] = {} # nuke old version - for chain_id in allpolymerchains: - sequence = data["chain_info"][chain_id][ - "processed_entity_non_canonical_sequence" - ] - stop_idx = start_idx + len(sequence) - - data["polymer_msas_by_chain_id"][chain_id] = {} - - # trim all msa info to this chain only - for mkey in msa_data.keys(): - data["polymer_msas_by_chain_id"][chain_id][mkey] = msa_data[mkey][ - ..., start_idx:stop_idx - ] - - # mock msa_is_padded_mask (all 0s) - data["polymer_msas_by_chain_id"][chain_id]["msa_is_padded_mask"] = np.zeros( - data["polymer_msas_by_chain_id"][chain_id]["msa"].shape, dtype=bool - ) - - start_idx = stop_idx - - return data diff --git a/pyproject.toml b/pyproject.toml index 0d7fd757..ee584998 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,10 +258,7 @@ module = [ ignore_errors = true # NOTE: the rf3 enablement ratchet (0014) is fully cleared — there is no rf3 module-level -# mypy exemption here. The one module that cannot type-check, `rf3.data.paired_msa` -# (broken against the installed atomworks; needs a `PandasDataset`-API refactor), carries -# a file-level `# mypy: ignore-errors` directive in the module itself, so the suppression -# is visible where the code is and the module stays inside mypy's `files` scope. +# mypy exemption here, and every rf3 module type-checks with no in-file suppressions. # Per-module strictness ratchet (direction (b)). The global baseline above leaves # disallow_untyped_defs / check_untyped_defs off; fully-annotated modules opt into strict