Skip to content

Instantly share code, notes, and snippets.

import torch
import json
# Create three empty CUDA graph objects and one graph memory-pool handle.
# NOTE(review): N and mempool are not used in the lines visible here; presumably
# N is a buffer size and mempool is passed as `pool=` to later captures so the
# graphs share one memory pool — confirm against the rest of the snippet.
N = 2 ** 28
glist = []
mempool = torch.cuda.graph_pool_handle()
for _ in range(3):
    # Loop body was un-indented in the original (IndentationError); restored.
    g = torch.cuda.CUDAGraph()
    glist.append(g)
import torch
# Repro: perform a host<->device sync (`.item()`) inside CUDA graph capture.
x = torch.randn(1024, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # `.item()` copies a value back to the host, which synchronizes CPU and GPU.
    # PyTorch's CUDA-graph constraints prohibit such syncs during capture, so
    # this line is expected to raise; presumably that is what this snippet is
    # demonstrating, so it is intentionally left in place.
    v = x[0].item()
    y = x + 1
# Opaque id of the memory pool this graph captured into.
print(g.pool())
import torch
# Capture a single elementwise op into a CUDA graph and print its pool id.
x = torch.randn(1024, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Toggle: printing a CUDA tensor reads its values back to the host (a sync),
    # which would presumably break capture — kept commented out deliberately.
    # print(x)
    y = x + 1
# Opaque id of the memory pool this graph captured into.
print(g.pool())
import torch
import os
# NOTE(review): this excerpt is truncated and its indentation was lost, so the
# extent of f()'s body (and whether the last two lines belong to it) cannot be
# determined from here. `do_profile`, `s` and `p` are not used in the visible
# lines.
def f(do_profile=True):
s = torch.cuda.Stream()
# Record a torch.profiler trace around the loop below.
with torch.profiler.profile() as p:
for _ in range(2):
# Allocates a 10240 x 10240 bfloat16 tensor on the GPU each iteration
# (10240 * 10240 * 2 bytes = 200 MiB per allocation).
m = torch.randn(10240, 10240, dtype=torch.bfloat16, device="cuda")
t = torch.empty(1024, device="cuda")
# Print the device address of the new allocation (presumably to observe how
# the caching allocator places it — assumption, confirm with the author).
print(t.data_ptr())
import sys
import os
from torch.nn import functional as F
import torch
from torch import nn
from triton.testing import do_bench
# NOTE(review): only the signature of bench() is present in this excerpt; its
# body is missing, so its behavior (presumably timing `f` with triton's
# do_bench, which is imported just above) cannot be documented from here.
def bench(f, name, warmup=5, profile_mem=False, profile=False):
diff --git a/run_train.sh b/run_train.sh
index 87558a78..0a256031 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -30,6 +30,6 @@ else
PYTORCH_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
- --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+ --local-ranks-filter ${LOG_RANK} --role rank --tee 0 \
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7900ecb46dc..fb5f6bb8595 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1291,7 +1291,7 @@ class auto_chunker:
output_size_threshold = 1024 * 1024
# Don't chunk from a node if it does not 'amplify' the inputs a lot
- amplify_ratio_threshold = 8
+ amplify_ratio_threshold = 6
"""
Train llama3.1 8B.
"""
import gc
import time
import os
from dataclasses import dataclass
from torch import nn
from torch import Tensor
import torch
# Make tensor factory calls below (randn, arange) allocate on the GPU by default.
torch.set_default_device("cuda")
x = torch.randn(5)
# Index with a tensor: t = [0, 1], so x[0] and x[1] are set to 3.
t = torch.arange(2)
x[t] = 3
x[2] = 8
# NOTE(review): t2 = [0, ..., 7] but x has only 5 elements, so indices 5..7 are
# out of range. On CUDA this typically surfaces as an asynchronous device-side
# assert rather than an immediate host-side IndexError. Presumably this is the
# behavior the snippet is probing — confirm before "fixing" it.
t2 = torch.arange(8)
x[t2] = 9
digraph dot {
subgraph cluster_1 {
label="graph_1" graph[style="dashed"];
"graph_1_node_0"[style="bold" shape="record" label="{KERNEL
| {ID | 0 (topoId: 19) | _Z12short_kernelPfS_\<\<\<1954,256,0\>\>\>}
| {{node handle | func handle} | {0x000000000A802900 | 0x000000000A32E720}}
| {accessPolicyWindow | {base_ptr | num_bytes | hitRatio | hitProp | missProp} | {0x0000000000000000 | 0 | 0.000000 | N | N}}
| {cooperative | 0}
| {priority | 0}
}"];