Skip to content

Instantly share code, notes, and snippets.

@lakshayg
Last active February 11, 2026 16:07
Show Gist options
  • Select an option

  • Save lakshayg/23b86b37e9df4cbc2e52a879d0d377ee to your computer and use it in GitHub Desktop.

Select an option

Save lakshayg/23b86b37e9df4cbc2e52a879d0d377ee to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "stRssy4MKKuz"
},
"outputs": [],
"source": [
"# Half-precision CUDA operands for the matmul-vs-einsum comparison.\n",
"# A: (1, 32768, 8, 1, 32), B: (8, 192, 32); both contractions below\n",
"# reduce over the trailing dim of size 32.\n",
"import torch\n",
"A = torch.rand(1,32768,8,1,32,dtype=torch.half,device='cuda')\n",
"B = torch.rand(8,192,32,dtype=torch.half,device='cuda')"
]
},
{
"cell_type": "code",
"source": [
"def f1(a, b):\n",
"    \"\"\"Contract a with b over their last dims via batched matmul (b transposed).\"\"\"\n",
"    return a @ b.mT\n",
"\n",
"def f2(a, b):\n",
"    \"\"\"Same contraction expressed as an einsum (implicit output '...ij').\"\"\"\n",
"    return torch.einsum('...ik,...jk', a, b)\n",
"\n",
"# Verify both implementations agree; compare in float64 so half-precision\n",
"# rounding cannot mask a genuine mismatch.\n",
"A_dbl = A.double()\n",
"B_dbl = B.double()\n",
"max_difference = torch.max(torch.abs(f1(A_dbl, B_dbl) - f2(A_dbl, B_dbl)))\n",
"print(f\"{max_difference=}\")\n",
"# Free the float64 copies -- they are large and only needed for this check.\n",
"del A_dbl\n",
"del B_dbl"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DpdxWvV8KUB_",
"outputId": "b380606e-b577-4075-9d7a-b8cb8332bc8f"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"max_difference=tensor(0., device='cuda:0', dtype=torch.float64)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Peak-memory profile of the @/mT formulation (f1).\n",
"# Reset allocator stats first so Peak Usage reflects only this cell.\n",
"torch.cuda.empty_cache()\n",
"torch.cuda.reset_peak_memory_stats()\n",
"print(f1(A, B).shape)\n",
"print(torch.cuda.memory.memory_summary())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "wtZ_2KAzKm5V",
"outputId": "ef01b0a9-1eb4-4dc2-d8f2-8126386e6331"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([1, 32768, 8, 1, 192])\n",
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 24800 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from large pool | 24704 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from small pool | 96 KiB | 0 MiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 24800 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from large pool | 24704 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from small pool | 96 KiB | 0 MiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Requested memory | 24800 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from large pool | 24704 KiB | 3192 MiB | 17080 MiB | 17056 MiB |\n",
"| from small pool | 96 KiB | 0 MiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 3206 MiB | 3206 MiB | 15942 MiB | 12736 MiB |\n",
"| from large pool | 3204 MiB | 3204 MiB | 15940 MiB | 12736 MiB |\n",
"| from small pool | 2 MiB | 2 MiB | 2 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable memory | 14111 KiB | 14111 KiB | 12302 MiB | 12288 MiB |\n",
"| from large pool | 12160 KiB | 12160 KiB | 12299 MiB | 12288 MiB |\n",
"| from small pool | 1951 KiB | 1951 KiB | 2 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 4 | 6 | 24 | 20 |\n",
"| from large pool | 2 | 4 | 10 | 8 |\n",
"| from small pool | 2 | 2 | 14 | 12 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 4 | 6 | 24 | 20 |\n",
"| from large pool | 2 | 4 | 10 | 8 |\n",
"| from small pool | 2 | 2 | 14 | 12 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 5 | 5 | 8 | 3 |\n",
"| from large pool | 4 | 4 | 7 | 3 |\n",
"| from small pool | 1 | 1 | 1 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 3 | 3 | 8 | 5 |\n",
"| from large pool | 1 | 1 | 3 | 2 |\n",
"| from small pool | 2 | 2 | 5 | 3 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize allocations | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
"|===========================================================================|\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Peak-memory profile of the einsum formulation (f2), for comparison\n",
"# against the f1 cell above.\n",
"torch.cuda.empty_cache()\n",
"torch.cuda.reset_peak_memory_stats()\n",
"print(f2(A, B).shape)\n",
"print(torch.cuda.memory.memory_summary())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "m5eheNyWK6Vi",
"outputId": "c216dea7-0dd7-4b22-f843-3f266af4e2fd"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([1, 32768, 8, 1, 192])\n",
"|===========================================================================|\n",
"| PyTorch CUDA memory summary, device ID 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
"|===========================================================================|\n",
"| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocated memory | 24800 KiB | 123104 KiB | 17176 MiB | 17152 MiB |\n",
"| from large pool | 24704 KiB | 123008 KiB | 17176 MiB | 17152 MiB |\n",
"| from small pool | 96 KiB | 96 KiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Active memory | 24800 KiB | 123104 KiB | 17176 MiB | 17152 MiB |\n",
"| from large pool | 24704 KiB | 123008 KiB | 17176 MiB | 17152 MiB |\n",
"| from small pool | 96 KiB | 96 KiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Requested memory | 24800 KiB | 123104 KiB | 17176 MiB | 17152 MiB |\n",
"| from large pool | 24704 KiB | 123008 KiB | 17176 MiB | 17152 MiB |\n",
"| from small pool | 96 KiB | 96 KiB | 0 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved memory | 137216 KiB | 137216 KiB | 16038 MiB | 15904 MiB |\n",
"| from large pool | 135168 KiB | 135168 KiB | 16036 MiB | 15904 MiB |\n",
"| from small pool | 2048 KiB | 2048 KiB | 2 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable memory | 14111 KiB | 14111 KiB | 12302 MiB | 12288 MiB |\n",
"| from large pool | 12160 KiB | 12160 KiB | 12299 MiB | 12288 MiB |\n",
"| from small pool | 1951 KiB | 1951 KiB | 2 MiB | 0 MiB |\n",
"|---------------------------------------------------------------------------|\n",
"| Allocations | 4 | 5 | 25 | 21 |\n",
"| from large pool | 2 | 3 | 11 | 9 |\n",
"| from small pool | 2 | 2 | 14 | 12 |\n",
"|---------------------------------------------------------------------------|\n",
"| Active allocs | 4 | 5 | 25 | 21 |\n",
"| from large pool | 2 | 3 | 11 | 9 |\n",
"| from small pool | 2 | 2 | 14 | 12 |\n",
"|---------------------------------------------------------------------------|\n",
"| GPU reserved segments | 4 | 4 | 9 | 5 |\n",
"| from large pool | 3 | 3 | 8 | 5 |\n",
"| from small pool | 1 | 1 | 1 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Non-releasable allocs | 3 | 3 | 8 | 5 |\n",
"| from large pool | 1 | 1 | 3 | 2 |\n",
"| from small pool | 2 | 2 | 5 | 3 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize allocations | 0 | 0 | 0 | 0 |\n",
"|---------------------------------------------------------------------------|\n",
"| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
"|===========================================================================|\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# CUDA kernels launch asynchronously; without a synchronize, %timeit\n",
"# measures launch overhead (and occasional implicit syncs), not kernel\n",
"# time -- hence the 'slowest run took 742x longer' warning previously.\n",
"# Putting synchronize inside the timed statement gives honest wall time.\n",
"%timeit -n 100 -r 5 f1(A, B); torch.cuda.synchronize()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "almDN_V4K-ep",
"outputId": "1415fceb-9b62-48bb-e200-e618a7f926d6"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The slowest run took 742.44 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"46.6 ms ± 26.6 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Synchronize inside the timed statement (see the f1 timing cell) so the\n",
"# two benchmarks are comparable and measure actual kernel execution.\n",
"%timeit -n 100 -r 5 f2(A, B); torch.cuda.synchronize()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6yOIuwW8LfgI",
"outputId": "be3b80fb-7392-47b5-f16e-d1085cdc43b9"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10.2 ms ± 187 µs per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment