Address Lora shortcomings #28801
export_adapter wrote tensor.DataRaw() for tensor.SizeInBytes() bytes regardless of element type. For tensor(string) parameters this copied the std::string object representation (heap pointers and SSO padding) directly into Parameter.raw_data, leaking runtime addresses (ASLR bypass) and uninitialized bytes, and producing an adapter that cannot be safely loaded (reinterpreting the saved bytes as std::string objects is undefined behavior).
Reject STRING element type with a clear error and defer opening the output file until after validation/serialization so a rejected export does not leave a stray empty file behind.
Test: test_adapter_export_rejects_string_tensors asserts export_adapter raises on tensor(string) parameters and leaves no file on disk.
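The ordering described above (validate and serialize first, open the file last) can be sketched in plain Python. This is only an illustration: export_blob, its byte layout, and the use of str as a stand-in for tensor(string) are assumptions made for the example, not the ONNX Runtime implementation.

```python
import os
import tempfile


def export_blob(path, params):
    """Hypothetical exporter: validate and serialize fully, then touch the disk."""
    chunks = []
    for name, value in params.items():
        # str stands in for a tensor(string) parameter, which must be rejected.
        if isinstance(value, str):
            raise ValueError(f"parameter {name!r}: string tensors cannot be exported")
        chunks.append(name.encode("utf-8") + b"\0" + bytes(value))
    payload = b"".join(chunks)
    # The file is opened only after everything that can fail has succeeded.
    with open(path, "wb") as f:
        f.write(payload)


def demo():
    with tempfile.TemporaryDirectory() as d:
        bad = os.path.join(d, "bad.bin")
        try:
            export_blob(bad, {"w": b"\x01\x02", "s": "text"})
            raised = False
        except ValueError:
            raised = True
        good = os.path.join(d, "good.bin")
        export_blob(good, {"w": b"\x01\x02"})
        # (rejected?, stray file left behind?, size of the valid export)
        return raised, os.path.exists(bad), os.path.getsize(good)
```

Because the rejection happens before open(), the failing call never creates bad.bin, which is the property the regression test checks.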
Pull request overview
This PR hardens LoRA adapter handling across the core C++ implementation and Python bindings, focusing on stronger exception-safety during adapter loading, safer object lifetimes in Python, and safer adapter export behavior.
Changes:
- Refactors LoraAdapter loading/memory-mapping to construct validated state locally before committing it to the object (strong exception guarantee).
- Fixes Python binding lifetime hazards by tying the returned parameters dict to its owning AdapterFormat instance via py::keep_alive.
- Rejects exporting string tensors from Python adapter export, and adds Python regression tests for both the keep-alive behavior and string-tensor rejection.
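The build-locally-then-commit idea behind the strong exception guarantee can be sketched as follows. AdapterState, its text format, and _build_params are hypothetical stand-ins for illustration; they do not mirror the real LoraAdapter code.

```python
class AdapterState:
    """Hypothetical stand-in for an object with load-then-commit semantics."""

    def __init__(self):
        self.buffer = b""
        self.params = {}

    @staticmethod
    def _build_params(buffer):
        # Side-effect free: reads the buffer, returns a new dict, may raise.
        params = {}
        for line in buffer.decode("utf-8").splitlines():
            name, sep, value = line.partition("=")
            if not sep or not name:
                raise ValueError(f"malformed entry: {line!r}")
            params[name] = float(value)
        return params

    def load(self, buffer):
        # All work that can fail happens on locals.
        new_params = self._build_params(buffer)
        # Commit: plain assignments, reached only after validation succeeded.
        self.buffer = buffer
        self.params = new_params
```

If load raises, neither buffer nor params has been touched, so the object still holds whatever the last successful load left there.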
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/core/session/lora_adapters.cc | Refactors Load/MemoryMap to build a fresh params map locally via BuildParamsValues() before committing state. |
| onnxruntime/core/session/lora_adapters.h | Replaces InitializeParamsValues() with side-effect-free BuildParamsValues() to support strong exception guarantees. |
| onnxruntime/python/onnxruntime_pybind_lora.cc | Fixes loaded_adapter_ typo, adds keep_alive for parameters, and rejects STRING tensors during export. |
| onnxruntime/test/python/onnxruntime_test_python.py | Adds regression tests for Python keep-alive and for rejecting string-tensor adapter export without creating a file. |
pybind11's def_property() does not accept keep_alive directly (static_assert fires). Wrap the getter in py::cpp_function so the policy can be attached, restoring the keep-alive behavior intended by 245ffaf. Also compute the serialized adapter span before opening the output file, per PR review: a failure inside FinishWithSpan should not leave a stray empty file behind.
The C-level pybind11 keep_alive on parameters tied the returned dict to the C AdapterFormat, but AdapterFormat.get_parameters() in the Python wrapper rebuilt a fresh dict of OrtValue wrappers and dropped the C dict, losing the link. The user-facing pattern
params = ort.AdapterFormat.read_adapter(p).get_parameters()
would still leave each OrtValue dangling once the temporary AdapterFormat was destroyed. Pin the C AdapterFormat on every OrtValue handed back, so callers who keep any of them keep the backing adapter too.
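A minimal sketch of the pinning idea, using only the standard library. Backing, View, and get_parameters here are hypothetical classes written for the illustration; the real wrapper pins the C AdapterFormat on OrtValue objects, which this example does not use.

```python
import gc
import weakref


class Backing:
    """Hypothetical owner of the adapter memory."""

    def __init__(self, data):
        self.data = data


class View:
    """Hypothetical value that reads from memory owned by someone else."""

    def __init__(self, backing, key):
        self._owner = backing  # the pin: keeps the owner alive
        self._key = key

    def read(self):
        return self._owner.data[self._key]


def get_parameters(backing):
    # Fresh dict on every call; each value carries its own pin.
    return {key: View(backing, key) for key in backing.data}


def demo():
    backing = Backing({"lora_a": [1.0, 2.0], "lora_b": [3.0]})
    probe = weakref.ref(backing)
    one = get_parameters(backing)["lora_a"]  # the dict is dropped right away
    del backing
    gc.collect()
    alive_while_held = probe() is not None
    value = one.read()
    del one
    gc.collect()
    return alive_while_held, value, probe() is None
```

The point of pinning per value rather than per dict is that a caller who keeps a single entry still keeps the backing object alive.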
Also fix the keep-alive regression test, which used a non-existent .parameters attribute on the Python wrapper (would have failed with AttributeError before exercising the UAF), and add a stronger assertion that an individual OrtValue stays valid after the dict is dropped.
ortvalue_from_numpy_with_onnx_type rejected the string numpy array (it validates numpy element size against the std::string object size), so the test errored out before reaching export_adapter. There is no public Python API that constructs a string-typed OrtValue directly; the only working path is to obtain one as a session output. Build a tiny in-memory Constant model whose output is a string tensor and use that OrtValue to drive the export rejection check.
py::keep_alive<0,1> on a returned py::dict raises 'cannot create weak reference to dict' at runtime: pybind11 keep_alive takes a weakref to the nurse (the returned dict, index 0), and Python dicts are not weak-referenceable. The C-level policy was therefore broken in two ways across the two pybind11 versions seen so far: a static_assert in the newer pybind11 (def_property does not accept keep_alive) and a runtime TypeError once routed through cpp_function.
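The underlying Python restriction can be reproduced without pybind11. The helper can_weakref and the WeakableDict subclass below are names introduced only for this illustration.

```python
import weakref


class WeakableDict(dict):
    """A dict subclass; unlike plain dict, its instances support weak references."""


def can_weakref(obj):
    # weakref.ref raises TypeError for objects whose type has no weakref support.
    try:
        weakref.ref(obj)
        return True
    except TypeError:
        return False
```

A plain dict (and a plain list) fails this check, while an instance of a dict subclass passes, which is why any mechanism that relies on weak references cannot be attached to a plain dict return value.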
Lifetime safety for the documented user-facing pattern
params = ort.AdapterFormat.read_adapter(p).get_parameters()
is provided by the Python wrapper AdapterFormat.get_parameters, which already pins the owning C AdapterFormat on every OrtValue it hands back. Document this in the C-level binding and remove the broken keep_alive.
… noexcept
- Add ORT_ENFORCE guards in export_adapter's add_param lambda to verify the OrtValue is a tensor and resides on CPU before calling DataRaw(). Previously, passing a GPU tensor or non-tensor OrtValue would crash with a device-pointer dereference on the host.
- Mark BufferHolder and MemMapHolder constructors and move operations noexcept so the 'commit cannot throw' claim in Load()/MemoryMap() is compiler-verifiable via std::variant's noexcept guarantees.
After read_adapter, calling the parameters setter followed by export_adapter would silently re-export the original read data because export_adapter prioritized loaded_adapter_ over the user-supplied parameters_ dict. Fix: clear loaded_adapter_ in the setter so that both the getter and export_adapter consistently use the user-supplied parameters after an explicit set. Add a regression test for the read -> modify -> export path.
…iter
Move the design explanation into a class-level comment that explains:
- WHY two sources: zero-copy views avoid duplicating adapter weights on memory-constrained devices.
- WHY not a single cached dict: pybind11 instances lack tp_traverse, so self -> dict -> OrtValue (patient) -> self would be an un-collectable reference cycle.
- HOW the setter bridges the two: clearing loaded_adapter_ on explicit set_parameters ensures consistent behavior.
Trim verbose inline comments that repeated the same reasoning.
…tances
An AdapterFormat instance now operates in one of two exclusive modes:
- READ mode (from read_adapter): parameters are read-only zero-copy views. Calling set_parameters raises RuntimeError.
- WRITE mode (default constructor): user sets parameters and exports.
This eliminates the ambiguous two-source conditional in export_adapter and prevents silent data loss where set_parameters would be ignored. Users who need to re-export with different parameters should create a new instance. Update regression test to verify the read-only enforcement.
Eliminate loaded_adapter_ / parameters_ duality entirely. The parameters_ dict is now always the single authoritative source.
On the read path, a heap-allocated LoraAdapter is transferred into a py::capsule. Each OrtValue view in the dict has the capsule pinned as its pybind11 patient. The reference graph is:
PyAdapterFormatReaderWriter -> parameters_ dict -> OrtValues -> capsule -> LoraAdapter
No edge points back to PyAdapterFormatReaderWriter, so there is no reference cycle (pybind11 instances lack tp_traverse). The capsule ref-counts keep the backing memory alive as long as any OrtValue is referenced in Python.
Benefits:
- Zero-copy: no data duplication on memory-constrained devices.
- No reference cycles: capsule is an opaque Python object the GC handles.
- Single source: getter, setter, and export_adapter all operate on parameters_. read -> modify -> export just works.
- Simpler code: removed conditional branches in getter and export_adapter.
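The acyclic reference graph can be modelled in pure Python to show why reference counting alone is enough to reclaim the backing object. All four classes below are hypothetical stand-ins (no pybind11, py::capsule, or OrtValue is involved), and the check relies on CPython's reference counting.

```python
import gc
import weakref


class Backing:
    """Hypothetical stand-in for the heap-allocated adapter."""

    def __init__(self, payload):
        self.payload = payload


class Capsule:
    """Hypothetical opaque owner; plays the role of the capsule."""

    def __init__(self, backing):
        self.backing = backing


class Value:
    """Hypothetical zero-copy view pinned to the capsule."""

    def __init__(self, capsule, name):
        self.capsule = capsule
        self.name = name


class Holder:
    """Hypothetical reader/writer: owns only the dict, nothing points back to it."""

    def __init__(self, backing):
        capsule = Capsule(backing)
        self.parameters = {n: Value(capsule, n) for n in backing.payload}


def demo():
    backing = Backing({"a": b"\x00", "b": b"\x01"})
    probe = weakref.ref(backing)
    holder = Holder(backing)
    del backing
    kept = holder.parameters["a"]
    was_enabled = gc.isenabled()
    gc.disable()  # rule out the cycle collector; only refcounts may free things
    try:
        del holder
        alive_with_one_value = probe() is not None
        del kept
        freed_without_gc = probe() is None
    finally:
        if was_enabled:
            gc.enable()
    return alive_with_one_value, freed_without_gc
```

With the cycle collector disabled, the backing object is still released as soon as the last value goes away, which would not happen if any edge pointed back to the holder.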
tianleiwu left a comment
Overall this is a solid hardening of LoRA adapter handling. The strong-exception-guarantee refactor in LoraAdapter::Load/MemoryMap is correct: BuildParamsValues is genuinely side-effect-free, the commit section runs only after all throwing work, and the noexcept move ctors on BufferHolder/MemMapHolder back the "commit cannot throw" claim. The aliasing reasoning is also right — moving std::vector / MappedMemoryPtr preserves the data pointer, so OrtValues built over the locals stay valid after the move.
On the Python side, the py::capsule + reference_internal pattern correctly pins the owning LoraAdapter onto every returned OrtValue (keep-alive patient), so the dict and any individual value independently keep the backing memory alive with no reference cycle. The export path's tensor / CPU-residency / STRING checks run before the file is opened, so rejected exports leave no stray file, and the new regression tests cover both keep-alive paths and string rejection well.
One minor robustness nit inline. Prior reviewer threads (file-open ordering, CPU-tensor enforcement, test AttributeError, single-source semantics) look addressed on this head.
Construct py::capsule while unique_ptr still owns the LoraAdapter. If capsule allocation throws (e.g. bad_alloc), unique_ptr's destructor still cleans up. Only release() after capsule is successfully constructed.
This pull request significantly improves the safety, correctness, and memory management of LoRA adapter handling in ONNX Runtime, especially around Python bindings and adapter file export/import. The main focus is on ensuring strong exception safety, preventing use-after-free bugs by improving object lifetimes, and rejecting unsupported tensor types during export. Additionally, comprehensive regression tests are added to guard against these issues.
Key changes include:
Exception Safety & Parameter Handling
- LoraAdapter::Load and MemoryMap now provide a strong exception guarantee: all potentially-throwing operations are performed using local variables before committing to the object's state, ensuring no partial updates occur on failure. The new BuildParamsValues method builds the parameter map without side effects, replacing the old InitializeParamsValues.
Python Bindings & Memory Management
- Each OrtValue returned from adapter parameter getters is pinned to its owning C++ adapter object via pybind11's keep_alive mechanism. This prevents use-after-free errors if the parent AdapterFormat object is dropped while references to its parameters remain. The getter now builds the parameter dictionary on demand and avoids reference cycles that would leak memory.
Adapter Export Robustness
Clean-up & Consistency
- The PyAdapterFormatReaderWriter class now populates its internal state only as appropriate for read or write operations, and unnecessary parameter passing was removed.
Regression Tests
These changes collectively make adapter handling safer and more robust, especially when interacting with Python, and add critical safeguards against subtle memory and serialization bugs.