Full execution trace of _NoopSaveInputs.apply(dummy, kwargs, *args) in PyTorch checkpoint

Assumptions made

  1. No functorch transforms active — we take the C++ super().apply() path
  2. GradMode is enabled — so is_executable = True, the autograd graph is built
  3. Profiler is off — skip profiler input shape recording
  4. JIT tracing is not active, so _trace_pre_record / _trace_post_record are noops
  5. dummy has requires_grad=True — it was created with torch.empty((0,), requires_grad=True) on line 1615, so any_variable_requires_grad returns True
  6. _checkpoint_hook is NOT yet active — this .apply() call is before the with _checkpoint_hook(new_frame) block (line 1623). So SavedTensorDefaultHooks::is_enabled() returns False at this point (the checkpoint's pack/unpack hooks aren't installed yet for this apply call).
  7. args contains a mix of tensors and non-tensors (kwargs is a dict, *args are the original user function args)
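
For concreteness, the traced call can be reproduced standalone (a sketch, not the checkpoint code itself: _NoopSaveInputs is a private helper of torch.utils.checkpoint, and the kwargs/args values below are made up):

import torch
from torch.utils.checkpoint import _NoopSaveInputs  # private helper; import path per the source being traced

dummy = torch.empty((0,), requires_grad=True)        # same as line 1615 (assumption 5)
kwargs = {"flag": True}                              # hypothetical kwargs dict
args = (torch.randn(2, requires_grad=True), 3)       # hypothetical user args: one tensor, one non-tensor
input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

print(input_saver.shape)    # torch.Size([0]) -- the "noop" output
print(input_saver.grad_fn)  # <..._NoopSaveInputsBackward object ...> -- the node that holds the saved inputs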

Step 1: Python Function.apply (torch/autograd/function.py:581)

@classmethod
def apply(cls, *args, **kwargs):

Called as _NoopSaveInputs.apply(dummy, kwargs, *args). No **kwargs are passed to .apply() itself, so kwargs={} in the apply signature (not to be confused with the kwargs positional arg).

1a. _is_setup_context_defined(cls.setup_context) → True

  • _NoopSaveInputs overrides setup_context (line 801), so it's not the base _SingleLevelFunction.setup_context.

1b. bind_default_args(cls.forward, *args, **kwargs) is called.

  • This binds the positional args (dummy, kwargs, *args) to forward's signature: def forward(*args).
  • Result: args = (dummy, kwargs, arg0, arg1, ...)
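
A rough Python equivalent of this binding step (illustrative only; bind_default_args is an internal helper in function.py):

import inspect

def forward(*args):  # same signature shape as _NoopSaveInputs.forward
    pass

# Approximation: bind the positional arguments against forward's signature
# and apply any declared defaults (there are none for a bare *args).
dummy, kwargs, user_args = object(), {"k": 1}, (10, 20)  # stand-in values
bound = inspect.signature(forward).bind(dummy, kwargs, *user_args)
bound.apply_defaults()
print(bound.args)  # (dummy, kwargs, 10, 20) -- the flat tuple handed to super().apply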

1c. torch._C._are_functorch_transforms_active() → False (assumption 1)

1d. _functorch.utils.unwrap_dead_wrappers(args) — iterates through args, for each tensor checks if it's a dead wrapper (from functorch), and unwraps. In our case nothing changes.

1e. super().apply(*args) — calls _C._FunctionBase.apply, which is THPFunction_apply in C++.

Step 2: C++ THPFunction_apply (torch/csrc/autograd/python_function.cpp:1330)

2a. unpack_input<false>(inputs) (line 1335)

Iterates over the Python tuple (dummy, kwargs, arg0, arg1, ...):

  • For each element, checks THPVariable_Check(arg):

    • dummy is a tensor, requires_grad=True → added to input_vars, needs_input_grad[0] = True
    • kwargs is a dict, not a tensor → is_variable_input[1] = false, needs_input_grad[1] = False
    • For each arg_i: if it's a tensor, added to input_vars with its requires_grad flag; otherwise marked as non-variable.
  • is_executable = GradMode::is_enabled() && any_variable_requires_grad(input_vars) → True && True = True

    • dummy requires grad, so this is true.
  • next_edges = collect_next_edges(input_vars):

    • For dummy (leaf, requires_grad): edge = (AccumulateGrad for dummy, 0)
    • For each tensor arg that requires_grad: edge to its grad_fn
    • For each tensor arg that does NOT require_grad: edge to Edge() (empty)

Global state read: GradMode::is_enabled() (thread-local)
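
The same classification can be mirrored in a few lines of Python (a sketch of the logic, not the C++ implementation; the inputs are made-up stand-ins):

import torch

dummy = torch.empty((0,), requires_grad=True)
inputs = (dummy, {"k": 1}, torch.randn(2, requires_grad=True), 3)  # (dummy, kwargs, arg0, arg1)

input_vars = [x for x in inputs if isinstance(x, torch.Tensor)]
is_variable_input = [isinstance(x, torch.Tensor) for x in inputs]
needs_input_grad = [isinstance(x, torch.Tensor) and x.requires_grad for x in inputs]
is_executable = torch.is_grad_enabled() and any(v.requires_grad for v in input_vars)

print(is_variable_input)  # [True, False, True, False]
print(needs_input_grad)   # [True, False, True, False]
print(is_executable)      # True (assumption 2: grad mode is on)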

2b. RECORD_FUNCTION(...) (line 1341) — records the function name "_NoopSaveInputs" for profiling. Profiler is off, so this is essentially a noop.

2c. Functorch TLS check (line 1346) — functorch_tls is null (no transforms active), skipped.

2d. Create backward node (lines 1357-1367)

THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
THPFunction* ctx = (THPFunction*)ctx_obj.get();
auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
ctx->cdata = cdata;
  • _backward_cls was created by FunctionMeta.__init__ (function.py:340): a dynamically-created class _NoopSaveInputsBackward that inherits BackwardCFunction with _forward_cls = _NoopSaveInputs.
  • Instantiating it creates a THPFunction (the Python-side ctx).
  • A PyNode wraps it — this is the C++ autograd Node that will be stored in the graph.

Global state updated:

  • at::sequence_number is incremented by the PyNode constructor (the earlier peek() for RECORD_FUNCTION only reads it)
  • A new PyNode is allocated on the heap, connected to the autograd graph
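
The dynamically-created _backward_cls described above can be observed on any Function subclass; a sketch with a stand-in class (not _NoopSaveInputs itself):

import torch
from torch.autograd.function import BackwardCFunction

class _StandIn(torch.autograd.Function):  # stand-in for _NoopSaveInputs
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, *grads):
        raise AssertionError("not expected to run")

# FunctionMeta has already attached the backward class at class-definition time:
print(_StandIn._backward_cls)                                 # <class '..._StandInBackward'>
print(issubclass(_StandIn._backward_cls, BackwardCFunction))  # True
print(_StandIn._backward_cls._forward_cls is _StandIn)        # True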

2e. Wire the graph (lines 1373-1388)

cdata->set_next_edges(std::move(input_info.next_edges));
ctx->needs_input_grad = input_info.needs_input_grad.release();
ctx->is_variable_input = std::move(input_info.is_variable_input);

The PyNode's next_edges now point to:

  • AccumulateGrad for dummy
  • grad_fn edges for any requires_grad=True tensor args
  • Empty edges for tensor args that do not require grad (non-tensor args contribute no edge, since next_edges is built from input_vars only)

Also reads clear_saved_tensors_on_access from the class (default False).

2f. Call forward (lines 1407-1444)

Guards: AutoGradMode(false) and AutoFwGradMode(false) disable grad mode and forward-grad mode for the duration of forward.

Global state updated: Thread-local GradMode set to false, restored after.

Since setup_context is overridden (overridden_setup_context = True):

  1. Call _NoopSaveInputs.forward(*args) → return torch.empty((0,)) (line 798)

    • This runs under no_grad so the resulting tensor has no grad_fn.
    • Returns a 0-sized tensor (the "noop" output).
  2. Call _NoopSaveInputs.setup_context(ctx, inputs, output) (line 801):

@staticmethod
def setup_context(ctx, inputs, output):
    tensor_indices, tensors = zip(
        *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)], strict=False
    )
    idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
    args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

    def get_args(saved_tensors):
        ret = [
            saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
            for i, o in enumerate(args)
        ]
        return ret[1:]  # skip the dummy

    ctx.get_args = get_args
    ctx.save_for_backward(*tensors)
  • inputs = (dummy, kwargs, arg0, arg1, ...) as passed to forward
  • Separates tensors from non-tensors. dummy is a tensor, plus whatever user args are tensors.
  • Creates a closure get_args that reconstructs the full arg list from ctx.saved_tensors.
  • ctx.save_for_backward(*tensors) — this is the critical call.
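
To make the index bookkeeping concrete, here is the same logic run standalone on made-up inputs (dummy plus one dict, one tensor arg, one int):

import torch

dummy = torch.empty((0,), requires_grad=True)
t0 = torch.randn(2)
inputs = (dummy, {"flag": True}, t0, 3)  # (dummy, kwargs, arg0, arg1)

tensor_indices, tensors = zip(
    *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
)
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}  # {0: 0, 2: 1}
args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

def get_args(saved_tensors):
    ret = [
        saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
        for i, o in enumerate(args)
    ]
    return ret[1:]  # drop the dummy

restored = get_args(list(tensors))
assert restored[0] == {"flag": True}  # kwargs dict passed through untouched
assert restored[1] is t0              # tensor slot refilled from saved_tensors
assert restored[2] == 3               # non-tensor arg passed through untouched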

Step 3: ctx.save_for_backward(*tensors) (C++ side)

FunctionCtx.save_for_backward (torch/autograd/function.py) simply sets ctx.to_save = tensors; the to_save attribute is a getset on THPFunction (python_function.cpp), so the tensors end up in ctx->to_save. The actual SavedVariable creation happens later in _save_variables (step 5b).

Global state updated: ctx->to_save is set to a tuple of the tensor references.
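
A sketch of what the Python-side method amounts to (simplified; the real FunctionCtx.save_for_backward also documents usage constraints):

def save_for_backward(self, *tensors):
    self.to_save = tensors  # `to_save` is a getset on THPFunction, so this lands in ctx->to_save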

Step 4: process_outputs(...) (line 1446)

Back in THPFunction_apply, after forward + setup_context return.

4a. ensure_tuple(raw_output) — the forward returned a single tensor torch.empty((0,)), so it wraps it in a 1-tuple and returns unpack_output = True (meaning we'll unwrap it at the end).

4b. cdata->clear_input_metadata() — clears the input metadata on the PyNode (no longer needed after setup).

4c. Record input_info (lines 1161-1167) — since is_executable = True:

  • grad_fn->input_info stores type/device/size of each variable input (for backward shape checking).

4d. _get_tensors_to_save(...) (line 1171)

Since overridden_setup_context = True and is_executable = True:

  • Iterates ctx->to_save (the tensors from save_for_backward).
  • Each tensor is added to both tensors_to_save (for creating SavedVariables) and to_save_if_setup_context (a set of TensorImpl* pointers, used later in _wrap_outputs to detect if an output is being saved).

Step 5: _wrap_outputs(...) (line 1179) and _save_variables(...) (line 1194)

5a. _wrap_outputs (python_function.cpp:638 → custom_function.cpp:445)

The output is the single torch.empty((0,)) tensor.

In _process_backward_mode_ad:

  • is_input = false (the output is a freshly-created tensor, not one of the inputs)
  • is_modified = false (nothing was marked dirty)
  • is_differentiable = True (cdata exists, not in non_differentiable set, float type is differentiable)
  • So it takes the else branch in set_history (line 349):
    impl::set_gradient_edge(var, {cdata, output_nr});
    This sets the grad_fn of the output tensor to the PyNode we created, with output_nr = 0.

Global state updated: The output tensor now has grad_fn = PyNode(_NoopSaveInputsBackward). This connects it to the autograd graph.

The output tensor's output_info is also recorded on the ctx.
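
A minimal generic custom Function (not _NoopSaveInputs itself) shows the effect described above: forward runs with grad disabled, yet the freshly created output comes back with a grad_fn attached by _wrap_outputs:

import torch

class _Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        assert not torch.is_grad_enabled()  # AutoGradMode(false) during forward
        return torch.empty((0,))            # created under no_grad: no grad_fn yet here

    @staticmethod
    def backward(ctx, grad_out):
        return None

x = torch.empty((0,), requires_grad=True)
out = _Probe.apply(x)
print(type(out.grad_fn).__name__)  # '_ProbeBackward' -- attached via set_gradient_edge
print(out.grad_fn.next_functions)  # ((<AccumulateGrad ...>, 0),)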

5b. _save_variables (python_function.cpp:852)

For each tensor in tensors_to_save:

  • Checks if it's an output (by comparing TensorImpl pointers against the output set).
    • dummy is an input, not the output → is_output = false
    • User tensor args are inputs, not the output → is_output = false
  • Creates SavedVariable(tensor, is_output=false):

Step 6: SavedVariable constructor (saved_variable.cpp:17)

For each tensor being saved:

SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view)

With is_output = false, is_inplace_on_view = false:

6a. Checks !variable.is_inference() — OK for normal tensors.

6b. Records metadata:

  • saved_version_ = variable._version() — the version counter snapshot
  • is_leaf_ = variable.is_leaf()
  • is_output_ = false

6c. SavedTensorDefaultHooks::is_enabled() → False (assumption 6 — the _checkpoint_hook context manager is not yet active at this point).

So maybe_hooks = nullptr. The hooks branch is not taken.

6d. Since !is_output || is_leaf_:

  • Since is_output = false, !is_output = true, so all saved variables here go through the saved_original_ = true path. The tensor is stored directly in data_. No copy is made.

Global state updated: self->saved_variables vector now contains SavedVariable objects holding direct references to the tensors.
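
A small demonstration that saved-tensor hooks only intercept saves made while the context manager is active, which is why the inputs saved here (before _checkpoint_hook is entered) are stored as direct references:

import torch

packed = []

def pack_hook(t):
    packed.append(t.shape)
    return t

def unpack_hook(t):
    return t

a = torch.randn(3, requires_grad=True)
b = a.exp()  # saves a tensor for backward; no hooks active, so it is stored directly
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    c = b.sin()  # saves b for backward; this save goes through pack_hook

print(packed)  # only the save made inside the `with` block was packed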

Step 7: Return value

Back in process_outputs:

  • unpack_output = True → unwrap the 1-tuple → return the single output tensor.

Back in THPFunction_apply → returns the output tensor to Python.

Back in Function.apply → returns it to the caller.

The result: new_frame.input_saver is now the torch.empty((0,)) tensor, but with grad_fn set to the _NoopSaveInputsBackward node. This node holds all the original inputs as SavedVariables in its saved_variables list.
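
This can be checked directly on the returned tensor (same kind of made-up inputs as before; _NoopSaveInputs is a private helper, so treat this as a sketch):

import torch
from torch.utils.checkpoint import _NoopSaveInputs

dummy = torch.empty((0,), requires_grad=True)
t0 = torch.randn(2, requires_grad=True)
input_saver = _NoopSaveInputs.apply(dummy, {"flag": True}, t0, 3)

node = input_saver.grad_fn                # the _NoopSaveInputsBackward node
print(type(node).__name__)                # '_NoopSaveInputsBackward'
print(node.next_functions)                # one (AccumulateGrad, 0) edge per tensor input (dummy and t0)
print(len(node.saved_tensors))            # 2 -- dummy and t0, held as direct references (no hooks)
print(node.get_args(node.saved_tensors))  # [{'flag': True}, t0, 3] -- the original inputs minus dummy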


Summary of global state changes

  • at::sequence_number: incremented by 1
  • Autograd graph: new PyNode(_NoopSaveInputsBackward) added, with edges to the grad_fns/AccumulateGrads of the input tensors
  • Output tensor's grad_fn: set to the new PyNode
  • GradMode (thread-local): temporarily set to false during forward, then restored
  • AutoFwGradMode (thread-local): temporarily set to false during forward, then restored
  • ctx.saved_variables: vector of SavedVariable objects holding direct refs to dummy + all tensor args
  • ctx.get_args: closure stored on ctx for reconstructing the full arg list
  • ctx.needs_input_grad: tuple of booleans
  • ctx.is_variable_input: vector of booleans
  • ctx.input_info: vector of input metadata (type/device/size)
  • ctx.output_info: vector of output metadata

Key insight for checkpoint

The _NoopSaveInputs.apply(...) call happens before _checkpoint_hook(new_frame) is entered (line 1616 vs 1623). This means the saved tensor default hooks are not yet installed. So the tensors saved via ctx.save_for_backward are stored directly (not packed via checkpoint's pack_hook).

Later, when _checkpoint_hook is entered (line 1623), the user's forward function runs. Any save_for_backward calls during that forward will go through the checkpoint pack/unpack hooks. But the inputs saved by _NoopSaveInputs are saved without hooks — they're kept as direct tensor references.

During backward, when unpack_hook is called and needs to recompute, it accesses frame.input_saver.grad_fn (which is the _NoopSaveInputsBackward node), calls ctx.saved_tensors to retrieve the original inputs, and passes them to recompute_fn.
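
A sketch of that retrieval step, based on the description above (the function name here is hypothetical; this is a paraphrase of the unpack path, not the actual checkpoint.py code):

import torch

def _recompute_from_input_saver(frame):
    ctx = frame.input_saver.grad_fn          # the _NoopSaveInputsBackward node
    args = ctx.get_args(ctx.saved_tensors)   # (kwargs, arg0, arg1, ...) with the dummy dropped
    with torch.enable_grad():
        frame.recompute_fn(*args)            # re-runs the user function, re-saving its activations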
