```python
import torch
from functorch import vmap, grad  # on newer PyTorch: from torch.func import vmap, grad
from torch.autograd import Function

sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute the vjp, that is, grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdim=True)
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdim=True) * v / s
        return t1 + uv, None


@torch.no_grad()
def _find_ts(xs, k):
    """Find a per-row threshold t such that sigmoid(xs + t) sums to k."""
    b, n = xs.shape
    assert 0 < k < n
    # Lo should be small enough that all sigmoids are in their 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    lo = -xs.max(dim=1, keepdim=True).values - 10
    hi = -xs.min(dim=1, keepdim=True).values + 10
    # Binary search for the threshold, all rows in parallel.
    for _ in range(64):
        mid = (hi + lo) / 2
        mask = sigmoid(xs + mid).sum(dim=1) < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)


topk = TopK.apply

xs = torch.randn(2, 3)
ps = topk(xs, 2)
print(xs, ps, ps.sum(dim=1))

from torch.autograd import gradcheck

input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
for k in range(1, 10):
    print(k, gradcheck(topk, (input, k), eps=1e-6, atol=1e-4))
```
Is there any place I can read about the logic behind the implementation?
I believe this post will be helpful. https://math.stackexchange.com/questions/3280757/differentiable-top-k-function
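For anyone skimming, here is my reading of the logic in the code above (a sketch, not the author's own derivation): the forward pass bisects for a per-row threshold t so that the relaxed indicators sum to exactly k, and the backward pass differentiates that constraint implicitly.

```latex
% Forward: find t(x) by bisection such that the soft indicators sum to k:
%   \sum_i \sigma(x_i + t) = k, \qquad p_i = \sigma(x_i + t).
% Backward: implicit differentiation of the constraint.
% With v_i = \sigma'(x_i + t) and s = \sum_i v_i:
\[
  \frac{\partial t}{\partial x_j} = -\frac{v_j}{s},
  \qquad
  \frac{\partial p_i}{\partial x_j}
    = v_i\Bigl(\delta_{ij} + \frac{\partial t}{\partial x_j}\Bigr)
    = \operatorname{diag}(v)_{ij} - \frac{v_i v_j}{s},
\]
% which matches the Jacobian "-vv.T/s + diag(v)" used in backward().
```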
If you are using this Soft TopK function, you may also want to combine it with BCE loss.
I have an updated gist here that does exactly that: https://gist.github.com/thomasahle/c72d11a5bd62f5f6187764f6a9bb4319
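As a rough illustration (the linked gist may handle this differently, e.g. for numerical stability, so treat the multi-hot target construction here as an assumption), combining the soft top-k probabilities with BCE could look like:

```python
import torch.nn.functional as F

# Hypothetical sketch: BCE between the soft top-k probabilities and a
# multi-hot target marking which k of the n items should be selected.
xs = torch.randn(2, 5, requires_grad=True)
targets = torch.tensor([[1., 0., 1., 0., 0.],
                        [0., 1., 0., 0., 1.]])   # exactly k = 2 ones per row
ps = topk(xs, 2)                                 # probabilities in (0, 1), rows sum to 2
loss = F.binary_cross_entropy(ps, targets)
loss.backward()
```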
Can I use this for hard selection?
Tell me more?
Suppose there are n points with features of shape (n, c), and a score function mapping them to scores of shape (n, 1). Then select the top k points based on their scores.
Can I use this for hard selection?
I suppose you can use a Straight-Through Estimator to do that, following the Gumbel-Softmax code in PyTorch:
https://github.com/pytorch/pytorch/blob/v2.9.0/torch/nn/functional.py#L2139
In our case, I think y_soft is the output of the TopK; after using torch.topk to get the corresponding indices and build y_hard, we can use
y_hard - y_soft.detach() + y_soft
to get the hard selection while keeping the gradient.
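To make that concrete, here is a small sketch of the straight-through construction described above (the scatter-based y_hard is my own filling-in, not from the thread):

```python
# Straight-through hard top-k: hard 0/1 mask in the forward pass,
# gradient of the soft top-k in the backward pass.
xs = torch.randn(2, 5, requires_grad=True)
k = 2
y_soft = topk(xs, k)                                      # soft selection probabilities
idx = xs.topk(k, dim=1).indices                           # indices of the k largest scores
y_hard = torch.zeros_like(y_soft).scatter_(1, idx, 1.0)   # multi-hot hard mask
y = y_hard - y_soft.detach() + y_soft                     # forward: y_hard, backward: grad of y_soft
```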
Ah I see, I added some more details in the writeup here: https://thomasahle.com/blog/differentiable_topk.html