From f007dddc9797f418ef784d18ce67b0f798c9eeb6 Mon Sep 17 00:00:00 2001 From: "Harlow, Jordan" Date: Thu, 12 Feb 2026 13:38:14 -0700 Subject: [PATCH 1/5] task: add patch methods for mkl_random --- mkl_random/__init__.py | 2 + mkl_random/src/_patch.pyx | 275 +++++++++++++++++++++++++++++++++ mkl_random/tests/test_patch.py | 95 ++++++++++++ setup.py | 8 + 4 files changed, 380 insertions(+) create mode 100644 mkl_random/src/_patch.pyx create mode 100644 mkl_random/tests/test_patch.py diff --git a/mkl_random/__init__.py b/mkl_random/__init__.py index 512027b..01ee956 100644 --- a/mkl_random/__init__.py +++ b/mkl_random/__init__.py @@ -42,4 +42,6 @@ test = PytestTester(__name__) del PytestTester +from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random + del _init_helper diff --git a/mkl_random/src/_patch.pyx b/mkl_random/src/_patch.pyx new file mode 100644 index 0000000..6c39ff3 --- /dev/null +++ b/mkl_random/src/_patch.pyx @@ -0,0 +1,275 @@ +# Copyright (c) 2019, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# distutils: language = c +# cython: language_level=3 + +""" +Patch NumPy's `numpy.random` symbols to use mkl_random implementations. + +This is attribute-level monkey patching. It can replace legacy APIs like +`numpy.random.RandomState` and global distribution functions, but it does not +replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully +compatible replacements. +""" + +from threading import local as threading_local +from contextlib import ContextDecorator + +import numpy as _np +from . import mklrand as _mr + + +cdef tuple _DEFAULT_NAMES = ( + # Legacy seeding / state + "seed", + "get_state", + "set_state", + "RandomState", + + # Common global sampling helpers + "random", + "random_sample", + "sample", + "rand", + "randn", + "bytes", + + # Integers + "randint", + + # Common distributions (only patched if present on both sides) + "standard_normal", + "normal", + "uniform", + "exponential", + "gamma", + "beta", + "chisquare", + "f", + "lognormal", + "laplace", + "logistic", + "multivariate_normal", + "poisson", + "power", + "rayleigh", + "triangular", + "vonmises", + "wald", + "weibull", + "zipf", + + # Permutations / choices + "choice", + "permutation", + "shuffle", +) + + +cdef class patch: + cdef bint _is_patched + cdef object _numpy_module + cdef object _originals # dict: name -> original object + cdef object _patched # list of names actually patched + + def __cinit__(self): + self._is_patched = False + self._numpy_module = None + self._originals = {} + self._patched = [] + + def do_patch(self, numpy_module=None, names=None, bint strict=False): + """ + Patch the given numpy module (default: imported numpy) in-place. + + Parameters + ---------- + numpy_module : module, optional + The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`). + names : iterable[str], optional + Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES. + strict : bool + If True, raise if any requested symbol cannot be patched. + """ + if numpy_module is None: + numpy_module = _np + if names is None: + names = _DEFAULT_NAMES + + if not hasattr(numpy_module, "random"): + raise TypeError("Expected a numpy-like module with a `.random` attribute.") + + # If already patched, only allow idempotent re-entry for the same numpy module. + if self._is_patched: + if self._numpy_module is numpy_module: + return + raise RuntimeError("Already patched a different numpy module; call restore() first.") + + np_random = numpy_module.random + + originals = {} + patched = [] + missing = [] + + for name in names: + if not hasattr(np_random, name) or not hasattr(_mr, name): + missing.append(name) + continue + originals[name] = getattr(np_random, name) + setattr(np_random, name, getattr(_mr, name)) + patched.append(name) + + if strict and missing: + # revert partial patch before raising + for n, v in originals.items(): + setattr(np_random, n, v) + raise AttributeError( + "Could not patch these names (missing on numpy.random or mkl_random.mklrand): " + + ", ".join([str(x) for x in missing]) + ) + + self._numpy_module = numpy_module + self._originals = originals + self._patched = patched + self._is_patched = True + + def do_unpatch(self): + """ + Restore the previously patched numpy module. + """ + if not self._is_patched: + return + numpy_module = self._numpy_module + np_random = numpy_module.random + for n, v in self._originals.items(): + setattr(np_random, n, v) + + self._numpy_module = None + self._originals = {} + self._patched = [] + self._is_patched = False + + def is_patched(self): + return self._is_patched + + def patched_names(self): + """ + Returns list of names that were actually patched. + """ + return list(self._patched) + + +_tls = threading_local() + + +def _is_tls_initialized(): + return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) + + +def _initialize_tls(): + _tls.patch = patch() + _tls.initialized = True + + +def monkey_patch(numpy_module=None, names=None, strict=False): + """ + Enables using mkl_random in the given NumPy module by patching `numpy.random`. + + Examples + -------- + >>> import numpy as np + >>> import mkl_random + >>> mkl_random.is_patched() + False + >>> mkl_random.monkey_patch(np) + >>> mkl_random.is_patched() + True + >>> mkl_random.restore() + >>> mkl_random.is_patched() + False + """ + if not _is_tls_initialized(): + _initialize_tls() + _tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict)) + + +def use_in_numpy(numpy_module=None, names=None, strict=False): + """ + Backward-compatible alias for monkey_patch(). + """ + monkey_patch(numpy_module=numpy_module, names=names, strict=strict) + + +def restore(): + """ + Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols. + """ + if not _is_tls_initialized(): + _initialize_tls() + _tls.patch.do_unpatch() + + +def is_patched(): + """ + Returns whether NumPy has been patched with mkl_random. + """ + if not _is_tls_initialized(): + _initialize_tls() + return bool(_tls.patch.is_patched()) + + +def patched_names(): + """ + Returns the names actually patched in `numpy.random`. + """ + if not _is_tls_initialized(): + _initialize_tls() + return _tls.patch.patched_names() + + +class mkl_random(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy's `numpy.random`. + + Examples + -------- + >>> import numpy as np + >>> import mkl_random + >>> with mkl_random.mkl_random(): + ... x = np.random.normal(size=10) + """ + def __init__(self, numpy_module=None, names=None, strict=False): + self._numpy_module = numpy_module + self._names = names + self._strict = strict + + def __enter__(self): + monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict) + return self + + def __exit__(self, *exc): + restore() + return False diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py new file mode 100644 index 0000000..3dabea1 --- /dev/null +++ b/mkl_random/tests/test_patch.py @@ -0,0 +1,95 @@ +import numpy as np +import mkl_random +import pytest + +def test_is_patched(): + """ + Test that is_patched() returns correct status. + """ + assert not mkl_random.is_patched() + mkl_random.monkey_patch(np) + assert mkl_random.is_patched() + mkl_random.restore() + assert not mkl_random.is_patched() + +def test_monkey_patch_and_restore(): + """ + Test that monkey_patch replaces and restore brings back original functions. + """ + # Store original functions + orig_normal = np.random.normal + orig_randint = np.random.randint + orig_RandomState = np.random.RandomState + + try: + mkl_random.monkey_patch(np) + + # Check that functions are now different objects + assert np.random.normal is not orig_normal + assert np.random.randint is not orig_randint + assert np.random.RandomState is not orig_RandomState + + # Check that they are from mkl_random + assert np.random.normal is mkl_random.mklrand.normal + assert np.random.RandomState is mkl_random.mklrand.RandomState + + finally: + mkl_random.restore() + + # Check that original functions are restored + assert mkl_random.is_patched() is False + assert np.random.normal is orig_normal + assert np.random.randint is orig_randint + assert np.random.RandomState is orig_RandomState + +def test_context_manager(): + """ + Test that the context manager patches and automatically restores. + """ + orig_uniform = np.random.uniform + assert not mkl_random.is_patched() + + with mkl_random.mkl_random(np): + assert mkl_random.is_patched() is True + assert np.random.uniform is not orig_uniform + # Smoke test inside context + arr = np.random.uniform(size=10) + assert arr.shape == (10,) + + assert not mkl_random.is_patched() + assert np.random.uniform is orig_uniform + +def test_patched_functions_callable(): + """ + Smoke test to ensure some patched functions can be called without error. + """ + mkl_random.monkey_patch(np) + try: + # These calls should now be routed to mkl_random's implementations + x = np.random.standard_normal(size=100) + assert x.shape == (100,) + + y = np.random.randint(0, 100, size=50) + assert y.shape == (50,) + assert np.all(y >= 0) and np.all(y < 100) + + st = np.random.RandomState(12345) + z = st.rand(10) + assert z.shape == (10,) + + finally: + mkl_random.restore() + +def test_patched_names(): + """ + Test that patched_names() returns a list of patched symbols. + """ + try: + mkl_random.monkey_patch(np) + names = mkl_random.patched_names() + assert isinstance(names, list) + assert len(names) > 0 + assert "normal" in names + assert "RandomState" in names + finally: + mkl_random.restore() diff --git a/setup.py b/setup.py index c47ebfb..70f83ea 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,14 @@ def extensions(): extra_compile_args = eca, define_macros=defs + [("NDEBUG", None)], language="c++" + ), + + Extension( + "mkl_random._patch", + sources=[join("mkl_random", "src", "_patch.pyx")], + include_dirs=[np.get_include()], + define_macros=defs + [("NDEBUG", None)], + language="c", ) ] From 38fe23d85469cd3119b90c616fd95416352924eb Mon Sep 17 00:00:00 2001 From: "Harlow, Jordan" Date: Thu, 5 Mar 2026 05:37:59 -0700 Subject: [PATCH 2/5] fix: patching to match mkl_fft, lint, and review --- .pylintrc | 5 + mkl_random/__init__.py | 9 +- mkl_random/_patch_numpy.py | 280 +++++++++++++++++++++++++++++++++ mkl_random/src/_patch.pyx | 275 -------------------------------- mkl_random/tests/test_patch.py | 125 +++++++++++---- setup.py | 10 +- 6 files changed, 392 insertions(+), 312 deletions(-) create mode 100644 .pylintrc create mode 100644 mkl_random/_patch_numpy.py delete mode 100644 mkl_random/src/_patch.pyx diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..71f9704 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,5 @@ +[MASTER] +extension-pkg-allow-list=numpy,mkl_random.mklrand + +[TYPECHECK] +generated-members=RandomState,min,max diff --git a/mkl_random/__init__.py b/mkl_random/__init__.py index b96f203..d60e5c9 100644 --- a/mkl_random/__init__.py +++ b/mkl_random/__init__.py @@ -93,9 +93,16 @@ test = PytestTester(__name__) del PytestTester -from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random from mkl_random import interfaces +from ._patch_numpy import ( + is_patched, + mkl_random, + patch_numpy_random, + patched_names, + restore_numpy_random, +) + __all__ = [ "MKLRandomState", "RandomState", diff --git a/mkl_random/_patch_numpy.py b/mkl_random/_patch_numpy.py new file mode 100644 index 0000000..f5db908 --- /dev/null +++ b/mkl_random/_patch_numpy.py @@ -0,0 +1,280 @@ +# Copyright (c) 2019, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Define functions for patching NumPy with MKL-based NumPy interface.""" + +import warnings +from contextlib import ContextDecorator +from threading import Lock, local + +import numpy as _np + +import mkl_random.interfaces.numpy_random as _nrand + +_DEFAULT_NAMES = tuple(_nrand.__all__) + + +class _GlobalPatch: + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._restore_dict = {} + self._patched_functions = tuple(_DEFAULT_NAMES) + self._numpy_module = None + self._requested_names = None + self._active_names = () + self._patched = () + self._tls = local() + + def _normalize_names(self, names): + if names is None: + names = _DEFAULT_NAMES + return tuple(names) + + def _validate_module(self, numpy_module): + if not hasattr(numpy_module, "random"): + raise TypeError( + "Expected a numpy-like module with a `.random` attribute." + ) + + def _register_func(self, name, func): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_random function.") + np_random = self._numpy_module.random + if name not in self._restore_dict: + self._restore_dict[name] = getattr(np_random, name) + setattr(np_random, name, func) + + def _restore_func(self, name, verbose=False): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_random function.") + try: + val = self._restore_dict[name] + except KeyError: + if verbose: + print(f"failed to restore {name}") + return + else: + if verbose: + print(f"found and restoring {name}...") + np_random = self._numpy_module.random + setattr(np_random, name, val) + + def _initialize_patch(self, numpy_module, names, strict): + self._validate_module(numpy_module) + np_random = numpy_module.random + missing = [] + patchable = [] + for name in names: + if name not in self._patched_functions: + missing.append(name) + continue + if not hasattr(np_random, name) or not hasattr(_nrand, name): + missing.append(name) + continue + patchable.append(name) + + if strict and missing: + raise AttributeError( + "Could not patch these names (missing on numpy.random or " + "mkl_random.interfaces.numpy_random): " + + ", ".join(str(x) for x in missing) + ) + + self._numpy_module = numpy_module + self._requested_names = names + self._active_names = tuple(patchable) + self._patched = tuple(patchable) + + def do_patch( + self, + numpy_module=None, + names=None, + strict=False, + verbose=False, + ): + if numpy_module is None: + numpy_module = _np + names = self._normalize_names(names) + strict = bool(strict) + + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if self._patch_count == 0: + self._initialize_patch(numpy_module, names, strict) + if verbose: + print( + "Now patching NumPy random submodule with mkl_random " + "NumPy interface." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_random" + ) + for name in self._active_names: + self._register_func(name, getattr(_nrand, name)) + else: + if self._numpy_module is not numpy_module: + raise RuntimeError( + "Already patched a different numpy module; " + "call restore() first." + ) + if names != self._requested_names: + raise RuntimeError( + "Already patched with a different names set; " + "call restore() first." + ) + self._patch_count += 1 + self._tls.local_count = local_count + 1 + + def do_restore(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + warnings.warn( + "Warning: restore_numpy_random called more times than " + "patch_numpy_random in this thread.", + stacklevel=2, + ) + return + + self._tls.local_count = local_count - 1 + self._patch_count -= 1 + if self._patch_count == 0: + if verbose: + print("Now restoring original NumPy random submodule.") + for name in tuple(self._restore_dict): + self._restore_func(name, verbose=verbose) + self._restore_dict.clear() + self._numpy_module = None + self._requested_names = None + self._active_names = () + self._patched = () + + def is_patched(self): + with self._lock: + return self._patch_count > 0 + + def patched_names(self): + with self._lock: + return list(self._patched) + + +_patch = _GlobalPatch() + + +def patch_numpy_random( + numpy_module=None, + names=None, + strict=False, + verbose=False, +): + """ + Patch NumPy's random submodule with mkl_random's NumPy interface. + + Parameters + ---------- + numpy_module : module, optional + NumPy-like module to patch. Defaults to imported NumPy. + names : iterable[str], optional + Attributes under `numpy_module.random` to patch. + strict : bool, optional + Raise if any requested symbol cannot be patched. + verbose : bool, optional + Print messages when starting the patching process. + + Examples + -------- + >>> import numpy as np + >>> import mkl_random + >>> mkl_random.is_patched() + False + >>> mkl_random.patch_numpy_random(np) + >>> mkl_random.is_patched() + True + >>> mkl_random.restore() + >>> mkl_random.is_patched() + False + """ + _patch.do_patch( + numpy_module=numpy_module, + names=names, + strict=bool(strict), + verbose=bool(verbose), + ) + + +def restore_numpy_random(verbose=False): + """ + Restore NumPy's random submodule to its original implementations. + + Parameters + ---------- + verbose : bool, optional + Print message when starting restoration process. + """ + _patch.do_restore(verbose=bool(verbose)) + + +def is_patched(): + """Return whether NumPy has been patched with mkl_random.""" + return _patch.is_patched() + + +def patched_names(): + """Return names actually patched in `numpy.random`.""" + return _patch.patched_names() + + +class mkl_random(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy random submodule + with MKL-based implementations. + + Examples + -------- + >>> import numpy as np + >>> import mkl_random + >>> with mkl_random.mkl_random(np): + ... x = np.random.normal(size=10) + """ + + def __init__(self, numpy_module=None, names=None, strict=False): + self._numpy_module = numpy_module + self._names = names + self._strict = strict + + def __enter__(self): + patch_numpy_random( + numpy_module=self._numpy_module, + names=self._names, + strict=self._strict, + ) + return self + + def __exit__(self, *exc): + restore_numpy_random() + return False diff --git a/mkl_random/src/_patch.pyx b/mkl_random/src/_patch.pyx deleted file mode 100644 index 6c39ff3..0000000 --- a/mkl_random/src/_patch.pyx +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright (c) 2019, Intel Corporation -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of Intel Corporation nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# distutils: language = c -# cython: language_level=3 - -""" -Patch NumPy's `numpy.random` symbols to use mkl_random implementations. - -This is attribute-level monkey patching. It can replace legacy APIs like -`numpy.random.RandomState` and global distribution functions, but it does not -replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully -compatible replacements. -""" - -from threading import local as threading_local -from contextlib import ContextDecorator - -import numpy as _np -from . import mklrand as _mr - - -cdef tuple _DEFAULT_NAMES = ( - # Legacy seeding / state - "seed", - "get_state", - "set_state", - "RandomState", - - # Common global sampling helpers - "random", - "random_sample", - "sample", - "rand", - "randn", - "bytes", - - # Integers - "randint", - - # Common distributions (only patched if present on both sides) - "standard_normal", - "normal", - "uniform", - "exponential", - "gamma", - "beta", - "chisquare", - "f", - "lognormal", - "laplace", - "logistic", - "multivariate_normal", - "poisson", - "power", - "rayleigh", - "triangular", - "vonmises", - "wald", - "weibull", - "zipf", - - # Permutations / choices - "choice", - "permutation", - "shuffle", -) - - -cdef class patch: - cdef bint _is_patched - cdef object _numpy_module - cdef object _originals # dict: name -> original object - cdef object _patched # list of names actually patched - - def __cinit__(self): - self._is_patched = False - self._numpy_module = None - self._originals = {} - self._patched = [] - - def do_patch(self, numpy_module=None, names=None, bint strict=False): - """ - Patch the given numpy module (default: imported numpy) in-place. - - Parameters - ---------- - numpy_module : module, optional - The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`). - names : iterable[str], optional - Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES. - strict : bool - If True, raise if any requested symbol cannot be patched. - """ - if numpy_module is None: - numpy_module = _np - if names is None: - names = _DEFAULT_NAMES - - if not hasattr(numpy_module, "random"): - raise TypeError("Expected a numpy-like module with a `.random` attribute.") - - # If already patched, only allow idempotent re-entry for the same numpy module. - if self._is_patched: - if self._numpy_module is numpy_module: - return - raise RuntimeError("Already patched a different numpy module; call restore() first.") - - np_random = numpy_module.random - - originals = {} - patched = [] - missing = [] - - for name in names: - if not hasattr(np_random, name) or not hasattr(_mr, name): - missing.append(name) - continue - originals[name] = getattr(np_random, name) - setattr(np_random, name, getattr(_mr, name)) - patched.append(name) - - if strict and missing: - # revert partial patch before raising - for n, v in originals.items(): - setattr(np_random, n, v) - raise AttributeError( - "Could not patch these names (missing on numpy.random or mkl_random.mklrand): " - + ", ".join([str(x) for x in missing]) - ) - - self._numpy_module = numpy_module - self._originals = originals - self._patched = patched - self._is_patched = True - - def do_unpatch(self): - """ - Restore the previously patched numpy module. - """ - if not self._is_patched: - return - numpy_module = self._numpy_module - np_random = numpy_module.random - for n, v in self._originals.items(): - setattr(np_random, n, v) - - self._numpy_module = None - self._originals = {} - self._patched = [] - self._is_patched = False - - def is_patched(self): - return self._is_patched - - def patched_names(self): - """ - Returns list of names that were actually patched. - """ - return list(self._patched) - - -_tls = threading_local() - - -def _is_tls_initialized(): - return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) - - -def _initialize_tls(): - _tls.patch = patch() - _tls.initialized = True - - -def monkey_patch(numpy_module=None, names=None, strict=False): - """ - Enables using mkl_random in the given NumPy module by patching `numpy.random`. - - Examples - -------- - >>> import numpy as np - >>> import mkl_random - >>> mkl_random.is_patched() - False - >>> mkl_random.monkey_patch(np) - >>> mkl_random.is_patched() - True - >>> mkl_random.restore() - >>> mkl_random.is_patched() - False - """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict)) - - -def use_in_numpy(numpy_module=None, names=None, strict=False): - """ - Backward-compatible alias for monkey_patch(). - """ - monkey_patch(numpy_module=numpy_module, names=names, strict=strict) - - -def restore(): - """ - Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols. - """ - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_unpatch() - - -def is_patched(): - """ - Returns whether NumPy has been patched with mkl_random. - """ - if not _is_tls_initialized(): - _initialize_tls() - return bool(_tls.patch.is_patched()) - - -def patched_names(): - """ - Returns the names actually patched in `numpy.random`. - """ - if not _is_tls_initialized(): - _initialize_tls() - return _tls.patch.patched_names() - - -class mkl_random(ContextDecorator): - """ - Context manager and decorator to temporarily patch NumPy's `numpy.random`. - - Examples - -------- - >>> import numpy as np - >>> import mkl_random - >>> with mkl_random.mkl_random(): - ... x = np.random.normal(size=10) - """ - def __init__(self, numpy_module=None, names=None, strict=False): - self._numpy_module = numpy_module - self._names = names - self._strict = strict - - def __enter__(self): - monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict) - return self - - def __exit__(self, *exc): - restore() - return False diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py index 3dabea1..fc2337a 100644 --- a/mkl_random/tests/test_patch.py +++ b/mkl_random/tests/test_patch.py @@ -1,28 +1,52 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + import numpy as np -import mkl_random import pytest +import mkl_random + + def test_is_patched(): - """ - Test that is_patched() returns correct status. - """ + """Test that is_patched() returns correct status.""" assert not mkl_random.is_patched() - mkl_random.monkey_patch(np) + mkl_random.patch_numpy_random(np) assert mkl_random.is_patched() - mkl_random.restore() + mkl_random.restore_numpy_random() assert not mkl_random.is_patched() -def test_monkey_patch_and_restore(): - """ - Test that monkey_patch replaces and restore brings back original functions. - """ + +def test_patch_and_restore(): + """Test patch replacement and restore of original functions.""" # Store original functions orig_normal = np.random.normal orig_randint = np.random.randint orig_RandomState = np.random.RandomState try: - mkl_random.monkey_patch(np) + mkl_random.patch_numpy_random(np) # Check that functions are now different objects assert np.random.normal is not orig_normal @@ -30,11 +54,11 @@ def test_monkey_patch_and_restore(): assert np.random.RandomState is not orig_RandomState # Check that they are from mkl_random - assert np.random.normal is mkl_random.mklrand.normal - assert np.random.RandomState is mkl_random.mklrand.RandomState + assert np.random.normal is mkl_random.normal + assert np.random.RandomState is mkl_random.RandomState finally: - mkl_random.restore() + mkl_random.restore_numpy_random() # Check that original functions are restored assert mkl_random.is_patched() is False @@ -42,10 +66,9 @@ def test_monkey_patch_and_restore(): assert np.random.randint is orig_randint assert np.random.RandomState is orig_RandomState + def test_context_manager(): - """ - Test that the context manager patches and automatically restores. - """ + """Test context manager patching and automatic restoration.""" orig_uniform = np.random.uniform assert not mkl_random.is_patched() @@ -59,11 +82,10 @@ def test_context_manager(): assert not mkl_random.is_patched() assert np.random.uniform is orig_uniform + def test_patched_functions_callable(): - """ - Smoke test to ensure some patched functions can be called without error. - """ - mkl_random.monkey_patch(np) + """Smoke test that patched functions are callable without errors.""" + mkl_random.patch_numpy_random(np) try: # These calls should now be routed to mkl_random's implementations x = np.random.standard_normal(size=100) @@ -78,18 +100,67 @@ def test_patched_functions_callable(): assert z.shape == (10,) finally: - mkl_random.restore() + mkl_random.restore_numpy_random() + def test_patched_names(): - """ - Test that patched_names() returns a list of patched symbols. - """ + """Test that patched_names() returns patched symbol names.""" try: - mkl_random.monkey_patch(np) + mkl_random.patch_numpy_random(np) names = mkl_random.patched_names() assert isinstance(names, list) assert len(names) > 0 assert "normal" in names assert "RandomState" in names finally: - mkl_random.restore() + mkl_random.restore_numpy_random() + + +def test_patch_strict_raises_attribute_error(): + """Test strict mode raises AttributeError for missing patch names.""" + # Attempt to patch a clearly non-existent symbol in strict mode. + with pytest.raises(AttributeError): + mkl_random.patch_numpy_random( + np, + strict=True, + names=["nonexistent_symbol"], + ) + + +def test_patch_redundant_patching(): + orig_normal = np.random.normal + assert not mkl_random.is_patched() + + try: + mkl_random.patch_numpy_random(np) + mkl_random.patch_numpy_random(np) + assert mkl_random.is_patched() + assert np.random.normal is mkl_random.normal + mkl_random.restore_numpy_random() + assert mkl_random.is_patched() + assert np.random.normal is mkl_random.normal + mkl_random.restore_numpy_random() + assert not mkl_random.is_patched() + assert np.random.normal is orig_normal + finally: + while mkl_random.is_patched(): + mkl_random.restore_numpy_random() + + +def test_patch_reentrant(): + orig_uniform = np.random.uniform + assert not mkl_random.is_patched() + + with mkl_random.mkl_random(np): + assert mkl_random.is_patched() + assert np.random.uniform is not orig_uniform + + with mkl_random.mkl_random(np): + assert mkl_random.is_patched() + assert np.random.uniform is not orig_uniform + + assert mkl_random.is_patched() + assert np.random.uniform is not orig_uniform + + assert not mkl_random.is_patched() + assert np.random.uniform is orig_uniform diff --git a/setup.py b/setup.py index 20b4b31..0b90bf9 100644 --- a/setup.py +++ b/setup.py @@ -91,16 +91,8 @@ def extensions(): library_dirs=lib_dirs, extra_compile_args=eca, define_macros=defs + [("NDEBUG", None)], - language="c++" + language="c++", ), - - Extension( - "mkl_random._patch", - sources=[join("mkl_random", "src", "_patch.pyx")], - include_dirs=[np.get_include()], - define_macros=defs + [("NDEBUG", None)], - language="c", - ) ] return exts From 0a94878e2c0d10024569b0cf364255ed3acfa300 Mon Sep 17 00:00:00 2001 From: "Harlow, Jordan" Date: Fri, 6 Mar 2026 07:33:40 -0700 Subject: [PATCH 3/5] fix: testing --- mkl_random/tests/test_patch.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py index fc2337a..b684bba 100644 --- a/mkl_random/tests/test_patch.py +++ b/mkl_random/tests/test_patch.py @@ -27,6 +27,7 @@ import pytest import mkl_random +import mkl_random.interfaces.numpy_random as _nrand def test_is_patched(): @@ -53,9 +54,9 @@ def test_patch_and_restore(): assert np.random.randint is not orig_randint assert np.random.RandomState is not orig_RandomState - # Check that they are from mkl_random - assert np.random.normal is mkl_random.normal - assert np.random.RandomState is mkl_random.RandomState + # Check that they are from mkl_random interface module + assert np.random.normal is _nrand.normal + assert np.random.RandomState is _nrand.RandomState finally: mkl_random.restore_numpy_random() @@ -135,10 +136,10 @@ def test_patch_redundant_patching(): mkl_random.patch_numpy_random(np) mkl_random.patch_numpy_random(np) assert mkl_random.is_patched() - assert np.random.normal is mkl_random.normal + assert np.random.normal is _nrand.normal mkl_random.restore_numpy_random() assert mkl_random.is_patched() - assert np.random.normal is mkl_random.normal + assert np.random.normal is _nrand.normal mkl_random.restore_numpy_random() assert not mkl_random.is_patched() assert np.random.normal is orig_normal From 8a8b942eed5665e1bc890dd3b09e7540331ba19f Mon Sep 17 00:00:00 2001 From: "Harlow, Jordan" Date: Fri, 6 Mar 2026 07:43:49 -0700 Subject: [PATCH 4/5] chore: update CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2973996..09e4f04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [dev] (MM/DD/YYYY) ### Added +* Added `mkl_random` patching for NumPy, with `mkl_random` context manager, `is_patched` query, and `patch_numpy_random` and `restore_numpy_random` calls to replace `numpy.random` calls with calls from `mkl_random.interfaces.numpy_random` [gh-90](https://github.com/IntelPython/mkl_random/pull/90) + * Added `mkl_random.interfaces` with `mkl_random.interfaces.numpy_random` interface, which aliases `mkl_random` functionality to more strictly adhere to NumPy's API (i.e., drops arguments and functions which are not part of standard NumPy) [gh-92](https://github.com/IntelPython/mkl_random/pull/92) ### Removed From f1881c6a71ce3a7e62c05d855b25ae4c808d97ae Mon Sep 17 00:00:00 2001 From: "Harlow, Jordan" Date: Fri, 6 Mar 2026 13:40:19 -0700 Subject: [PATCH 5/5] task: review fixes --- .pylintrc | 5 ----- mkl_random/__init__.py | 4 ++++ mkl_random/tests/test_patch.py | 23 +++++++++++++++++++++++ pyproject.toml | 6 ++++++ 4 files changed, 33 insertions(+), 5 deletions(-) delete mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 71f9704..0000000 --- a/.pylintrc +++ /dev/null @@ -1,5 +0,0 @@ -[MASTER] -extension-pkg-allow-list=numpy,mkl_random.mklrand - -[TYPECHECK] -generated-members=RandomState,min,max diff --git a/mkl_random/__init__.py b/mkl_random/__init__.py index d60e5c9..67238f5 100644 --- a/mkl_random/__init__.py +++ b/mkl_random/__init__.py @@ -155,6 +155,10 @@ "shuffle", "permutation", "interfaces", + "patch_numpy_random", + "restore_numpy_random", + "is_patched", + "patched_names", ] del _init_helper diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py index b684bba..58cf71c 100644 --- a/mkl_random/tests/test_patch.py +++ b/mkl_random/tests/test_patch.py @@ -68,6 +68,29 @@ def test_patch_and_restore(): assert np.random.RandomState is orig_RandomState +def test_patch_with_limited_names(): + """Test patching only selected functions via names keyword.""" + orig_normal = np.random.normal + orig_randint = np.random.randint + assert not mkl_random.is_patched() + + try: + mkl_random.patch_numpy_random(np, names=["normal"]) + assert mkl_random.is_patched() + assert np.random.normal is _nrand.normal + assert np.random.randint is orig_randint + + names = mkl_random.patched_names() + assert "normal" in names + assert "randint" not in names + finally: + mkl_random.restore_numpy_random() + + assert not mkl_random.is_patched() + assert np.random.normal is orig_normal + assert np.random.randint is orig_randint + + def test_context_manager(): """Test context manager patching and automatic restoration.""" orig_uniform = np.random.uniform diff --git a/pyproject.toml b/pyproject.toml index 3352468..6c136fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,12 @@ line_length = 80 multi_line_output = 3 use_parentheses = true +[tool.pylint.main] +extension-pkg-allow-list = ["numpy", "mkl_random.mklrand"] + +[tool.pylint.typecheck] +generated-members = ["RandomState", "min", "max"] + [tool.setuptools] include-package-data = true