Created November 7, 2025 19:03
-
-
Save guilhermeleobas/f629d552ca431b68e3d4ec6d5f9e810f to your computer and use it in GitHub Desktop.
vmap_hessian.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fixed seed so the sampled inputs and the model's initial weights are reproducible.
_ = torch.manual_seed(0)
device = "cpu"
D1 = 2  # input features: x, y
D2 = 3  # output features: u, v, p
B = 10000  # batch size (leading dimension of x)
x = torch.randn(B, D1).to(device)
run_backward = False
# MLP: D1 -> 512, five 512 -> 512 hidden layers, then 512 -> D2, with a ReLU
# after every layer except the last. The layers are constructed in the same
# order as a hand-written list, so the seeded weight initialization is unchanged.
model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    *(
        layer
        for _ in range(5)
        for layer in (nn.Linear(512, 512), nn.ReLU())
    ),
    nn.Linear(512, D2),
).to(device)
def predict(x):
    """Run ``model`` on ``x`` and return the prediction twice.

    ``jacrev``/``jacfwd`` called with ``has_aux=True`` expect the wrapped
    function to return an ``(output, aux)`` pair. The first element is
    differentiated and the second is passed through untouched. Returning the
    prediction as the aux lets callers get the forward output alongside the
    derivatives.
    """
    y = model(x)
    return (y, y)
def reference_hessian(*, inputs=None, fn=None):
    """Compute per-sample Hessians with plain double-backward autograd.

    This is the plain-autograd baseline. Its values should match
    ``functorch_hessian`` up to a reshape ``[B, D2, D1, D1] <-> [B, D2 * D1, D1]``.

    The batch is handled with a ``ones`` grad_output. Differentiating
    ``sum_b pred[b, i]`` w.r.t. the whole batch yields every sample's own
    gradient, provided ``fn`` has no cross-sample interaction. Applying this
    twice yields per-sample Hessian rows.

    Args:
        inputs: 2-D ``[B, D1]`` input batch. Defaults to the module-level ``x``.
        fn: Callable returning ``(pred, aux)`` with ``pred`` of shape
            ``[B, D2]``. ``aux`` is ignored. Defaults to ``predict``.

    Returns:
        ``(hessian, pred)``. ``hessian`` has shape ``[B, D2 * D1, D1]``, where
        ``hessian[b, i * D1 + j, k] = d^2 pred[b, i] / (dx[b, j] dx[b, k])``.
        ``pred`` is the forward output.
    """
    if inputs is None:
        inputs = x
    if fn is None:
        fn = predict
    x_ = inputs.clone().requires_grad_()
    pred, _ = fn(x_)
    n_in = x_.shape[1]
    n_out = pred.shape[1]
    # Matches pred's dtype/device, so non-float32 inputs also work.
    ones = torch.ones_like(pred[:, 0])
    # create_graph=True keeps the graph so each Jacobian row can be
    # differentiated again. On the second pass it keeps the Hessian itself
    # differentiable for callers.
    jacobian_rows = [
        torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[0]
        for i in range(n_out)
    ]  # n_out tensors, each shaped like x_
    hessian_rows = [
        torch.autograd.grad(jacobian_rows[i][:, j], x_, ones, create_graph=True)[0]
        for i in range(n_out)
        for j in range(n_in)
    ]  # entry i * n_in + j holds d/dx (d pred_i / dx_j)
    hessian = torch.stack(hessian_rows)  # [n_out * n_in, B, n_in]
    return hessian.transpose(0, 1), pred
def functorch_hessian():
    """Compute per-sample Hessians of ``predict`` with torch.func transforms.

    ``jacrev`` builds the per-sample Jacobian of ``predict``. Wrapping it in
    ``jacfwd`` differentiates that Jacobian again (forward-over-reverse).
    ``vmap`` maps the result over the leading batch dimension. The whole
    pipeline is compiled with ``torch.compile(fullgraph=True)``.

    Returns:
        ``(hessian, pred)``. ``hessian`` is ``[B, D2, D1, D1]``. ``pred`` is
        the model output, threaded through both transforms as the aux value.
    """
    inputs = x.clone().requires_grad_()
    # predict returns (out, out): the first element is differentiated and the
    # second rides along as the aux output through both jacrev and jacfwd.
    jacobian_fn = jacrev(predict, argnums=0, has_aux=True)
    hessian_fn = jacfwd(jacobian_fn, argnums=0, has_aux=True)
    batched_fn = vmap(hessian_fn, in_dims=0)
    # NOTE(review): the compiled wrapper is rebuilt on every call. If this is
    # timed repeatedly, confirm whether Dynamo reuses its cache or recompiles.
    compiled_fn = torch.compile(batched_fn, fullgraph=True)
    hess, out = compiled_fn(inputs)  # hess: [B, D2, D1, D1]
    return hess, out
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.