Skip to content

Instantly share code, notes, and snippets.

@guilhermeleobas
Created November 7, 2025 19:03
Show Gist options
  • Select an option

  • Save guilhermeleobas/f629d552ca431b68e3d4ec6d5f9e810f to your computer and use it in GitHub Desktop.

Select an option

Save guilhermeleobas/f629d552ca431b68e3d4ec6d5f9e810f to your computer and use it in GitHub Desktop.
vmap_hessian.py
# Reproducible setup: fixed seed, batch of 2-D inputs, and a 6-hidden-layer MLP
# mapping (x, y) -> (u, v, p).
_ = torch.manual_seed(0)
device = "cpu"
D1 = 2  # input dims: x, y
D2 = 3  # output dims: u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False

# Build the layer list programmatically; module creation order (and therefore
# parameter initialization under the fixed seed) matches the literal version.
_layers = [nn.Linear(D1, 512), nn.ReLU()]
for _ in range(5):
    _layers.append(nn.Linear(512, 512))
    _layers.append(nn.ReLU())
_layers.append(nn.Linear(512, D2))
model = nn.Sequential(*_layers).to(device)
def predict(x):
    """Run the model on *x*.

    Returns the output twice: jacrev/jacfwd with ``has_aux=True`` expect a
    ``(output, aux)`` pair, and here the auxiliary value is the prediction
    itself so it can be recovered alongside the derivatives.
    """
    y = model(x)
    return y, y
def reference_hessian():
    """Compute batched second derivatives of the model output w.r.t. its input
    using nested ``torch.autograd.grad`` calls (the hand-rolled reference that
    ``functorch_hessian`` is compared against).

    Returns a tuple ``(hessian, pred)`` where ``hessian`` has shape
    ``[B, D2 * D1, D1]`` and ``pred`` is the model output on the cloned input.
    """
    x_ = x.clone().requires_grad_()
    # Grad-output of ones sums each scalar output over the batch.
    # NOTE(review): this yields per-sample gradients only if each sample's
    # output depends solely on its own input row (true for a pointwise MLP
    # applied along the batch dim) — confirm if the model changes.
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    # First order: d pred[:, i] / d x_ for each of the D2 outputs.
    # create_graph=True keeps the graph so we can differentiate again below.
    for i in range(D2):
        jacobian_rows[i] = torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[
            0
        ]
    # Second order: differentiate each Jacobian column again w.r.t. x_.
    for i in range(D2):
        for j in range(D1):
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    # Move the batch dimension first to match the functorch layout.
    return hessian.transpose(0, 1), pred
def functorch_hessian():
    """Compute the batched Hessian with torch.func transforms.

    jacfwd-over-jacrev gives the per-sample Hessian, vmap batches it over the
    leading dimension, and the whole pipeline is run through torch.compile
    with fullgraph=True. Returns ``(hessian, pred)`` where ``hessian`` has
    shape ``[B, D2, D1, D1]``.
    """
    inp = x.clone().requires_grad_()
    per_sample_hessian = jacfwd(
        jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True
    )
    batched = vmap(per_sample_hessian, in_dims=0)
    compiled = torch.compile(batched, fullgraph=True)
    hessian, pred = compiled(inp)  # [B, D2, D1, D1]
    return hessian, pred
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment