# Benchmark Liger's fused linear cross entropy against eager PyTorch and
# torch.compile's auto_chunker at several chunk counts.
import sys
import os
from torch.nn import functional as F
import torch
from torch import nn
from triton.testing import do_bench


def bench(f, name, warmup=5, profile_mem=False, profile=False):
    # Warm up so lazy initialization and compilation stay out of the timing.
    for _ in range(warmup):
        f()
    if profile_mem:
        torch.cuda.memory._record_memory_history()
        f()
        # View the snapshot at https://pytorch.org/memory_viz
        torch.cuda.memory._dump_snapshot(f"{name}.pickle")
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name}.json")
    torch.cuda.reset_peak_memory_stats()
    ms = do_bench(f)
    print(f"{name}: {ms:.3f}ms")
    print("Peak mem: ", torch.cuda.max_memory_allocated() / 1e9)  # GB
    print()


from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

torch.set_default_device("cuda")

# Tokens (batch * sequence), hidden size, vocabulary size.
BT, C, V = 32768, 768, 128256
model = nn.Linear(C, V, bias=False).bfloat16()
x = torch.randn(BT, C, requires_grad=True, dtype=torch.bfloat16)
T = torch.randint(0, V, (BT,))


def ligerf(m, x, label):
    # Liger's fused linear + cross-entropy kernel.
    x.grad = None
    m.weight.grad = None
    out = LigerFusedLinearCrossEntropyFunction.apply(x, m.weight, label)[0]
    out.backward()
    return out


def torchf(m, x, label):
    # Eager baseline: materializes the full [BT, V] logits tensor.
    x.grad = None
    m.weight.grad = None
    loss = F.cross_entropy(m(x), label)
    loss.backward()
    return loss


opt_torchf = torch.compile(torchf, options={"auto_chunker.enable": False})

# Sanity-check that all variants agree numerically before benchmarking.
expected = torchf(model, x, T).float()
assert torch.allclose(expected, ligerf(model, x, T).float(), atol=1e-2, rtol=1e-2)
assert torch.allclose(expected, opt_torchf(model, x, T).float(), atol=1e-2, rtol=1e-2)

bench(lambda: ligerf(model, x, T), "liger")
bench(lambda: torchf(model, x, T), "torch")
bench(lambda: opt_torchf(model, x, T), "compile_no_chunking")

# Sweep chunk counts 4, 8, 16, 32, 64.
for log_nchunk in range(2, 7):
    torch._dynamo.reset()
    nchunk = 2 ** log_nchunk
    # why num_chunk can not be ov
    autochunker_torchf = torch.compile(
        torchf, options={"auto_chunker.enable": True, "auto_chunker.num_chunk": nchunk}
    )
    assert torch.allclose(expected, autochunker_torchf(model, x, T).float(), atol=1e-2, rtol=1e-2)
    bench(lambda: autochunker_torchf(model, x, T), f"compile_{nchunk}_chunks")

print("bye")