Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created December 26, 2025 23:39
Show Gist options
  • Select an option

  • Save shunting314/9de35b9845a85eff01c80576098023ec to your computer and use it in GitHub Desktop.
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7900ecb46dc..fb5f6bb8595 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1291,7 +1291,7 @@ class auto_chunker:
output_size_threshold = 1024 * 1024
# Don't chunk from a node if it does not 'amplify' the inputs a lot
- amplify_ratio_threshold = 8
+ amplify_ratio_threshold = 6
num_chunk = (
int(os.environ.get("TORCHINDUCTOR_CHUNKER_NUM_CHUNKS")) # type: ignore[arg-type]
diff --git a/torch/_inductor/fx_passes/auto_chunker/core.py b/torch/_inductor/fx_passes/auto_chunker/core.py
index 647638578a7..739c8416b5a 100644
--- a/torch/_inductor/fx_passes/auto_chunker/core.py
+++ b/torch/_inductor/fx_passes/auto_chunker/core.py
@@ -125,19 +125,24 @@ def find_amplifier_node(graph: Graph) -> Optional[Node]:
amplifier_nodes_ratio = []
+ import torch.distributed as dist
+ # dist.breakpoint() # TODO
for node in graph.nodes:
if use_tangent(node):
# enter backward part of the graph
+ log.debug("Node %s uses tangent", node.format_node())
break
# Only trigger chunking for a small set of nodes like matmul for now
if node.op != "call_function" or node.target not in eligible_amplifier_node:
+ log.debug("Node %s not calling eligible function", node.format_node())
continue
input_size = compute_tensor_size(node.args, node.kwargs)
output_size = compute_tensor_size(node)
if input_size == 0:
+ log.debug("Node %s has 0 input", node.format_node())
continue
ratio = output_size / input_size
@@ -148,6 +153,7 @@ def find_amplifier_node(graph: Graph) -> Optional[Node]:
amplifier_nodes_ratio.append((node, ratio))
elif ratio >= 4 and output_size >= 64_000:
log.debug("Node '%s' get skipped as amplifier_node due to small amplification ratio or size. ratio %s, size %s", node.format_node(), ratio, output_size)
+ log.debug("Node '%s' get skipped as amplifier_node due to small amplification ratio or size. ratio %s, size %s", node.format_node(), ratio, output_size)
amplifier_nodes_ratio = sorted(
amplifier_nodes_ratio, key=lambda x: x[1], reverse=True
diff --git a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
index 0937fd58f77..65f3cf2febe 100644
--- a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
+++ b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
@@ -89,9 +89,9 @@ def propagate_scale_by(nodes_with_chunking_meta: Sequence[Node]) -> None:
def propagate_div(div_node: Node) -> bool:
lhs_node, rhs_node = div_node.args[:2]
assert isinstance(lhs_node, Node)
- assert isinstance(rhs_node, Node)
+ assert isinstance(rhs_node, (Node, int, float))
lhs_scale_by = get_scale_by_from_node(lhs_node)
- rhs_scale_by = get_scale_by_from_node(rhs_node)
+ rhs_scale_by = get_scale_by_from_node(rhs_node) if isinstance(rhs_node, Node) else None
if lhs_scale_by and rhs_scale_by is None:
update_chunking_meta(div_node, scale_by=lhs_scale_by)
return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment