- No functorch transforms active — we take the C++ `super().apply()` path
- GradMode is enabled — so `is_executable = True` and the autograd graph is built
- Profiler is off — skip profiler input shape recording
- JIT tracing is not active — `_trace_pre_record`/`_trace_post_record` are noops
- `dummy` has `requires_grad=True` — it was created with `torch.empty((0,), requires_grad=True)` on line 1615 (call site sketched below), so `any_variable_requires_grad` returns `True`
- `_checkpoint_hook` is NOT yet active — this `.apply()` call sits before the `with _checkpoint_hook(new_frame)` block (line 1623), so `SavedTensorDefaultHooks::is_enabled()` returns `False` at this point (the checkpoint's pack/unpack hooks aren't installed yet for this `apply` call)
- `args` contains a mix of tensors and non-tensors (`kwargs` is a dict, `*args` are the original user function args)
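For orientation, a minimal sketch of the call site these assumptions describe. The `from torch.utils.checkpoint import _NoopSaveInputs` import is a private API and an assumption here, and `fn_args`/`fn_kwargs` are made-up stand-ins for the user function's arguments:

```python
import torch
from torch.utils.checkpoint import _NoopSaveInputs  # assumption: private, may change across versions

# Stand-ins for the user's function arguments: a mix of tensors and non-tensors.
fn_args = (torch.randn(3, requires_grad=True), 42)
fn_kwargs = {"scale": 2.0}

# checkpoint.py creates a 0-sized requires_grad dummy so the graph is always built,
# then funnels everything through _NoopSaveInputs.apply (the call analyzed below).
dummy = torch.empty((0,), requires_grad=True)
input_saver = _NoopSaveInputs.apply(dummy, fn_kwargs, *fn_args)

print(input_saver.shape)    # torch.Size([0])
print(input_saver.grad_fn)  # the _NoopSaveInputsBackward node
```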
```python
@classmethod
def apply(cls, *args, **kwargs):
```

Called as `_NoopSaveInputs.apply(dummy, kwargs, *args)`. No `**kwargs` are passed to `.apply()` itself, so `kwargs = {}` in the apply signature (not to be confused with the `kwargs` positional arg).
1a. _is_setup_context_defined(cls.setup_context) → True
`_NoopSaveInputs` overrides `setup_context` (line 801), so it's not the base `_SingleLevelFunction.setup_context`.
1b. bind_default_args(cls.forward, *args, **kwargs) is called.
- This binds the positional args `(dummy, kwargs, *args)` to `forward`'s signature: `def forward(*args)`.
- Result: `args = (dummy, kwargs, arg0, arg1, ...)` (see the sketch below)
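A quick illustration of the effect, using `inspect.signature(...).bind(...)` as a stand-in for the internal `bind_default_args` helper. Since `forward` is `*args`-only and has no defaults, binding just yields the same flat positional tuple:

```python
import inspect

import torch


def forward(*args):  # same shape as _NoopSaveInputs.forward
    ...


dummy = torch.empty((0,), requires_grad=True)
kwargs = {"scale": 2.0}
user_args = (torch.randn(2), 3)

bound = inspect.signature(forward).bind(dummy, kwargs, *user_args)
print(bound.args)  # (dummy, {'scale': 2.0}, tensor(...), 3)
```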
1c. torch._C._are_functorch_transforms_active() → False (assumption 1)
1d. _functorch.utils.unwrap_dead_wrappers(args) — iterates through args, for each tensor checks if it's a dead wrapper (from functorch), and unwraps. In our case nothing changes.
1e. super().apply(*args) — calls _C._FunctionBase.apply, which is THPFunction_apply in C++.
2a. unpack_input<false>(inputs) (line 1335)
Iterates over the Python tuple (dummy, kwargs, arg0, arg1, ...):
- For each element, checks `THPVariable_Check(arg)` (mirrored in the Python sketch after this list):
  - `dummy` → is a tensor, `requires_grad=True` → added to `input_vars`, `needs_input_grad[0] = True`
  - `kwargs` → is a dict, not a tensor → `is_variable_input[1] = false`, `needs_input_grad[1] = False`
  - For each `arg_i`: if it's a tensor, added to `input_vars` with its `requires_grad` flag; otherwise marked as non-variable.
- `is_executable = GradMode::is_enabled() && any_variable_requires_grad(input_vars)` → `True && True = True` (`dummy` requires grad, so this is true).
- `next_edges = collect_next_edges(input_vars)`:
  - For `dummy` (leaf, requires_grad): edge = `(AccumulateGrad for dummy, 0)`
  - For each tensor arg that requires_grad: edge to its grad_fn
  - For each tensor arg that does NOT require_grad: edge to `Edge()` (empty)

Global state read: `GradMode::is_enabled()` (thread-local)
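A pure-Python sketch of this classification (the real logic is C++ in `unpack_input`; the variable names mirror the walkthrough and the example values are made up):

```python
import torch


def classify_inputs(args):
    """Mimic unpack_input: split tensors from non-tensors and compute the flags."""
    input_vars, is_variable_input, needs_input_grad = [], [], []
    for a in args:
        if isinstance(a, torch.Tensor):
            input_vars.append(a)
            is_variable_input.append(True)
            needs_input_grad.append(a.requires_grad)
        else:
            is_variable_input.append(False)
            needs_input_grad.append(False)
    is_executable = torch.is_grad_enabled() and any(v.requires_grad for v in input_vars)
    return input_vars, is_variable_input, needs_input_grad, is_executable


dummy = torch.empty((0,), requires_grad=True)
args = (dummy, {"scale": 2.0}, torch.randn(2), 3)

_, is_var, needs_grad, executable = classify_inputs(args)
print(is_var)      # [True, False, True, False]
print(needs_grad)  # [True, False, False, False]
print(executable)  # True -- GradMode is on and dummy requires grad
```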
2b. RECORD_FUNCTION(...) (line 1341) — records the function name "_NoopSaveInputs" for profiling. Profiler is off, so this is essentially a noop.
2c. Functorch TLS check (line 1346) — functorch_tls is null (no transforms active), skipped.
2d. Create backward node (lines 1357-1367)
```cpp
THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
THPFunction* ctx = (THPFunction*)ctx_obj.get();
auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
ctx->cdata = cdata;
```

- `_backward_cls` was created by `FunctionMeta.__init__` (function.py:340): a dynamically-created class `_NoopSaveInputsBackward` that inherits `BackwardCFunction` with `_forward_cls = _NoopSaveInputs` (see the sketch below).
- Instantiating it creates a `THPFunction` (the Python-side `ctx`).
- A `PyNode` wraps it — this is the C++ autograd `Node` that will be stored in the graph.

Global state updated:

- `at::sequence_number` is incremented (via `peek()` earlier + the PyNode constructor)
- A new `PyNode` is allocated on the heap, connected to the autograd graph
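A small sketch of what `FunctionMeta` sets up, using a toy Function in place of `_NoopSaveInputs`:

```python
import torch


class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


# FunctionMeta attached a dynamically-created backward class; instantiating it
# is what produces the Python-side ctx that the PyNode wraps.
print(MyFn._backward_cls.__name__)              # "MyFnBackward"
print(MyFn._backward_cls._forward_cls is MyFn)  # True
print(issubclass(MyFn._backward_cls, torch.autograd.function.BackwardCFunction))  # True
```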
2e. Wire the graph (lines 1373-1388)
```cpp
cdata->set_next_edges(std::move(input_info.next_edges));
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = std::move(input_info.is_variable_input);
```

The PyNode's next_edges now point to:

- `AccumulateGrad` for `dummy`
- `grad_fn` edges for any `requires_grad=True` tensor args
- Empty edges for non-tensor or non-requires_grad args
Also reads clear_saved_tensors_on_access from the class (default False).
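From Python, this wiring is visible as `next_functions` on the node. A sketch with a toy Function (not `_NoopSaveInputs`), one requires-grad leaf input and one non-requires-grad input:

```python
import torch


class Toy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        return a + b

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, grad_out


leaf = torch.randn(2, requires_grad=True)  # collects an (AccumulateGrad, 0) edge
frozen = torch.randn(2)                    # requires_grad=False -> empty Edge()

out = Toy.apply(leaf, frozen)
print(out.grad_fn.next_functions)
# ((<AccumulateGrad object ...>, 0), (None, 0)) -- the empty edge shows up as None
```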
2f. Call forward (lines 1407-1444)
Guards: AutoGradMode(false), AutoFwGradMode(false) — disables grad mode and forward-grad mode for the duration of forward.
Global state updated: Thread-local GradMode set to false, restored after.
Since setup_context is overridden (overridden_setup_context = True):
- Call `_NoopSaveInputs.forward(*args)` → `return torch.empty((0,))` (line 798)
  - This runs under `no_grad`, so the resulting tensor has no grad_fn.
  - Returns a 0-sized tensor (the "noop" output).
- Call `_NoopSaveInputs.setup_context(ctx, inputs, output)` (line 801):
```python
@staticmethod
def setup_context(ctx, inputs, output):
    tensor_indices, tensors = zip(
        *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)],
        strict=False,
    )
    idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
    args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

    def get_args(saved_tensors):
        ret = [
            saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
            for i, o in enumerate(args)
        ]
        return ret[1:]  # skip the dummy

    ctx.get_args = get_args
    ctx.save_for_backward(*tensors)
```

- `inputs = (dummy, kwargs, arg0, arg1, ...)` as passed to forward
- Separates tensors from non-tensors: `dummy` is a tensor, plus whatever user args are tensors.
- Creates a closure `get_args` that reconstructs the full arg list from `ctx.saved_tensors` (a standalone re-run of this logic follows below).
- `ctx.save_for_backward(*tensors)` — this is the critical call.
This is implemented as THPFunction_save_for_backward in python_function.cpp. It stores the tensors into ctx->to_save. The actual SavedVariable creation happens later in _save_variables (step 5b).
Global state updated: ctx->to_save is set to a tuple of the tensor references.
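A standalone re-run of that index bookkeeping on an example tuple, with no autograd involved (the example values are made up; the `tensors` tuple here plays the role that `ctx.saved_tensors` plays later):

```python
import torch

dummy = torch.empty((0,), requires_grad=True)
t0 = torch.randn(2)
inputs = (dummy, {"scale": 2.0}, t0, 3)  # (dummy, kwargs, *user_args)

tensor_indices, tensors = zip(
    *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
)
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
placeholders = [None if isinstance(o, torch.Tensor) else o for o in inputs]


def get_args(saved_tensors):
    # Put the saved tensors back into their original slots, keep non-tensors as-is.
    ret = [
        saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
        for i, o in enumerate(placeholders)
    ]
    return ret[1:]  # drop the dummy


print(get_args(tensors))  # [{'scale': 2.0}, tensor([...]), 3]
```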
Back in THPFunction_apply, after forward + setup_context return.
4a. ensure_tuple(raw_output) — the forward returned a single tensor torch.empty((0,)), so it wraps it in a 1-tuple and returns unpack_output = True (meaning we'll unwrap it at the end).
4b. cdata->clear_input_metadata() — clears the input metadata on the PyNode (no longer needed after setup).
4c. Record input_info (lines 1161-1167) — since is_executable = True:
`grad_fn->input_info` stores type/device/size of each variable input (for backward shape checking).
4d. _get_tensors_to_save(...) (line 1171)
Since overridden_setup_context = True and is_executable = True:
- Iterates `ctx->to_save` (the tensors from `save_for_backward`).
- Each tensor is added to both `tensors_to_save` (for creating `SavedVariable`s) and `to_save_if_setup_context` (a set of `TensorImpl*` pointers, used later in `_wrap_outputs` to detect if an output is being saved).
5a. _wrap_outputs (python_function.cpp:638 → custom_function.cpp:445)
The output is the single torch.empty((0,)) tensor.
In _process_backward_mode_ad:
- `is_input = false` (the output is a freshly-created tensor, not one of the inputs)
- `is_modified = false` (nothing was marked dirty)
- `is_differentiable = True` (cdata exists, not in the non_differentiable set, float type is differentiable)
- So it takes the `else` branch in `set_history` (line 349):

```cpp
impl::set_gradient_edge(var, {cdata, output_nr});
```

This sets the `grad_fn` of the output tensor to the `PyNode` we created, with `output_nr = 0`.
Global state updated: The output tensor now has grad_fn = PyNode(_NoopSaveInputsBackward). This connects it to the autograd graph.
The output tensor's output_info is also recorded on the ctx.
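The same effect can be seen with a toy Function: the output is created while grad mode is disabled inside `forward`, yet it comes back with a `grad_fn` because `_wrap_outputs` attaches the node afterwards (sketch, not the real `_NoopSaveInputs`):

```python
import torch


class Noop(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = torch.empty((0,))
        assert out.grad_fn is None  # grad mode is off inside forward
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return None


x = torch.randn(3, requires_grad=True)
out = Noop.apply(x)
print(out.grad_fn)  # <NoopBackward ...> -- attached by _wrap_outputs after forward
```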
5b. _save_variables (python_function.cpp:852)
For each tensor in tensors_to_save:
- Checks if it's an output (by comparing TensorImpl pointers against the output set).
  - `dummy` is an input, not the output → `is_output = false`
  - User tensor args are inputs, not the output → `is_output = false`
- Creates `SavedVariable(tensor, is_output=false)`.
For each tensor being saved:

```cpp
SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view)
```

With `is_output = false`, `is_inplace_on_view = false`:
6a. Checks !variable.is_inference() — OK for normal tensors.
6b. Records metadata:
- `saved_version_ = variable._version()` — the version counter snapshot
- `is_leaf_ = variable.is_leaf()`
- `is_output_ = false`
6c. SavedTensorDefaultHooks::is_enabled() → False (assumption 6 — the _checkpoint_hook context manager is not yet active at this point).
So maybe_hooks = nullptr. The hooks branch is not taken.
6d. Since !is_output || is_leaf_:
- Since `is_output = false`, `!is_output = true`, so all saved variables here go through the `saved_original_ = true` path. The tensor is stored directly in `data_`. No copy is made.
Global state updated: self->saved_variables vector now contains SavedVariable objects holding direct references to the tensors.
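"No copy is made" is easy to check from Python with a toy Function that saves an input (hedged sketch; reading `saved_tensors` off the grad_fn is the same access pattern the checkpoint code uses later):

```python
import torch


class SaveInput(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # an input, not the output -> is_output=false path
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


x = torch.randn(4, requires_grad=True)
out = SaveInput.apply(x)

(saved,) = out.grad_fn.saved_tensors
print(saved.data_ptr() == x.data_ptr())  # True -- a direct reference, not a copy
```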
Back in process_outputs:
`unpack_output = True` → unwrap the 1-tuple → return the single output tensor.
Back in THPFunction_apply → returns the output tensor to Python.
Back in Function.apply → returns it to the caller.
The result: new_frame.input_saver is now the torch.empty((0,)) tensor, but with grad_fn set to the _NoopSaveInputsBackward node. This node holds all the original inputs as SavedVariables in its saved_variables list.
| State | Change |
|---|---|
| `at::sequence_number` | Incremented by 1 |
| Autograd graph | New `PyNode(_NoopSaveInputsBackward)` added, with edges to the grad_fns/AccumulateGrads of the input tensors |
| Output tensor's `grad_fn` | Set to the new PyNode |
| `GradMode` (thread-local) | Temporarily set to false during forward, then restored |
| `AutoFwGradMode` (thread-local) | Temporarily set to false during forward, then restored |
| `ctx.saved_variables` | Vector of `SavedVariable` objects holding direct refs to `dummy` + all tensor args |
| `ctx.get_args` | Closure stored on ctx for reconstructing the full arg list |
| `ctx.needs_input_grad` | Tuple of booleans |
| `ctx.is_variable_input` | Vector of booleans |
| `ctx.input_info` | Vector of input metadata (type/device/size) |
| `ctx.output_info` | Vector of output metadata |
The _NoopSaveInputs.apply(...) call happens before _checkpoint_hook(new_frame) is entered (line 1616 vs 1623). This means the saved tensor default hooks are not yet installed. So the tensors saved via ctx.save_for_backward are stored directly (not packed via checkpoint's pack_hook).
Later, when _checkpoint_hook is entered (line 1623), the user's forward function runs. Any save_for_backward calls during that forward will go through the checkpoint pack/unpack hooks. But the inputs saved by _NoopSaveInputs are saved without hooks — they're kept as direct tensor references.
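The timing matters because saved-tensor hooks only apply to `save_for_backward` calls made while the hook context is active. A small illustration using the public `torch.autograd.graph.saved_tensors_hooks` as a stand-in for the checkpoint's `_checkpoint_hook`:

```python
import torch

packed_count = 0


def pack(t):
    global packed_count
    packed_count += 1
    return t


def unpack(t):
    return t


x = torch.randn(2, requires_grad=True)

a = x.sin()  # saves x for backward, but outside the hook context: pack not called
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    b = x.sin()  # saves x inside the context: pack is called

print(packed_count)  # 1
```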
During backward, when unpack_hook is called and needs to recompute, it accesses frame.input_saver.grad_fn (which is the _NoopSaveInputsBackward node), calls ctx.saved_tensors to retrieve the original inputs, and passes them to recompute_fn.
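A hedged, runnable sketch of that retrieval, stripped of the surrounding frame machinery (again assuming the private `_NoopSaveInputs` import; in the real code the unpack hook does this and then hands the result to `recompute_fn`):

```python
import torch
from torch.utils.checkpoint import _NoopSaveInputs  # assumption: private, may change across versions

dummy = torch.empty((0,), requires_grad=True)
fn_kwargs = {"scale": 2.0}
fn_args = (torch.randn(3), 7)

input_saver = _NoopSaveInputs.apply(dummy, fn_kwargs, *fn_args)

ctx = input_saver.grad_fn                   # the _NoopSaveInputsBackward node (acts as the ctx)
restored = ctx.get_args(ctx.saved_tensors)  # [fn_kwargs, arg0, arg1, ...]; the dummy is dropped
print(restored)
```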