@mlazos
Created December 17, 2025 04:31
Here's the repro:
import torch

torch._dynamo.config.capture_scalar_outputs = True

def fn(x, val_tensor):
    val = val_tensor.item()   # creates an unbacked float (fp64)
    scaled = val * 2.0        # fp64 * fp64 = fp64 computation
    return x * scaled         # fp32 * fp64 -> needs downcast

x = torch.ones(4, device='cuda', dtype=torch.float32)
val_tensor = torch.tensor(0.1, device='cuda', dtype=torch.float64)
compiled = torch.compile(fn)
out = compiled(x, val_tensor)
print(f'Output: {out}')
The generated Triton kernel shows the fp64 computation traced in:
# (tmp0, the fp32 load of x, is defined earlier in the kernel)
tmp1 = tl.load(in_ptr1 + (0))          # load fp64 value from tensor
tmp2 = tl.broadcast_to(tmp1, [XBLOCK])
tmp3 = tl.full([1], 2.0, tl.float64)   # fp64 constant
tmp4 = tmp3 * tmp2                     # fp64 * fp64 computation
tmp5 = tmp4.to(tl.float32)             # downcast to fp32
tmp6 = tmp0 * tmp5                     # fp32 * fp32
Run with TORCH_LOGS="+output_code" to see the generated kernel.
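
If the fp64 math inside the kernel is a concern, here is a minimal sketch of one way to sidestep it (an illustration only, not a fix from this gist; fn_tensor is a hypothetical name): skip .item() and keep the value as a tensor, downcasting explicitly once.

import torch

def fn_tensor(x, val_tensor):
    # Do the scalar math as a tensor op and downcast once, so no
    # unbacked fp64 float is ever created via .item().
    scaled = (val_tensor * 2.0).to(torch.float32)
    return x * scaled

x = torch.ones(4, device='cuda', dtype=torch.float32)
val_tensor = torch.tensor(0.1, device='cuda', dtype=torch.float64)
out = torch.compile(fn_tensor)(x, val_tensor)

This avoids capture_scalar_outputs entirely for this pattern; the multiply still happens in fp64 and is downcast once, so the numerics should match the original.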