eigh_py: reject output-object deferral in the correctness check#161
Open
robobryce wants to merge 1 commit into
Open
eigh_py: reject output-object deferral in the correctness check#161robobryce wants to merge 1 commit into
robobryce wants to merge 1 commit into
Conversation
A submission can return cheap placeholder tensors from custom_kernel (timed) and defer the real eigendecomposition into check_implementation (untimed): the checker promotes the outputs to FP64 with value.double(), and if the output is a torch.Tensor subclass with __torch_function__ — or even a plain tensor with a per-instance .double()/.detach() override — that promotion runs the genuine solve outside the timed region. Confirmed live on the eigh B200 leaderboard (fabricated ~17 us, all tests passing). Two changes to reference.py: 1. Require an EXACT plain torch.Tensor output (type(value) is torch.Tensor), not merely isinstance() — the latter admits any subclass. 2. Promote to FP64 through an override-proof path (_as_plain_fp64): strip to a base-class view and call the UNBOUND torch.Tensor.detach / .as_subclass / .double, never the bound value.detach()/.double(). Bound calls dispatch through the object and a plain tensor still carries a per-instance __dict__, so an instance-attribute override would otherwise fire; unbound calls go through the type and bypass it. (The bound .detach() was the gap a padded 'plain-tensor .detach override' variant used to survive an earlier draft of this fix; the unbound form closes it.) Now the residual math always runs on data the submission cannot re-point, so the real work cannot be deferred out of the timed region.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The benchmark times
custom_kernelbut the correctness check runs outside the timed region. A submission can exploit that split: return cheap placeholder tensors fromcustom_kernel(fast, timed) and do the real eigendecomposition lazily insidecheck_implementationwhen it promotes the outputs to FP64.check_implementationpromotes withvalue.double(). That dispatches through the output object, so:torch.Tensorsubclass with__torch_function__can intercept.double()and run the genuine solve there, andtorch.Tensorcarries a per-instance__dict__, so an instance-attribute override of.double()/.detach()does the same without being a subclass.Confirmed live on the B200
eighleaderboard: a__torch_function__deferral was accepted with a fabricated ~17 µs time, all tests passing.Fix (reference.py)
Require an exact plain tensor. Gate on
type(value) is torch.Tensor, notisinstance(...)— the latter admits any subclass.Promote through an override-proof path (
_as_plain_fp64): strip the output to a base-class view and use the unboundtorch.Tensor.detach/.as_subclass/.double, never the boundvalue.detach()/value.double(). Bound calls dispatch through the object (and hit a per-instance override); unbound calls go through the type and bypass it.The second point matters specifically because a bound
value.detach()is itself an interceptable call — a "plain-tensor.detachoverride" variant survives a fix that strips viavalue.detach(). Using unboundtorch.Tensor.detach(value)closes that.After this, the residual math always runs on data the submission cannot re-point, so real work can't be deferred out of the timed region. Honest kernels (which return plain FP32 tensors) are unaffected.