Gist shunting314/9de35b9845a85eff01c80576098023ec: patch against torch/_inductor auto_chunker.
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7900ecb46dc..fb5f6bb8595 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1291,7 +1291,7 @@ class auto_chunker:
     output_size_threshold = 1024 * 1024
 
     # Don't chunk from a node if it does not 'amplify' the inputs a lot
-    amplify_ratio_threshold = 8
+    amplify_ratio_threshold = 6
 
     num_chunk = (
         int(os.environ.get("TORCHINDUCTOR_CHUNKER_NUM_CHUNKS"))  # type: ignore[arg-type]
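A minimal usage sketch for the knob touched above. The environment variable name comes straight from the hunk; the value "4" and the function below are made up for illustration, and this assumes the variable must be set before torch._inductor.config is first imported, since the diff reads it at module import time (the fallback when it is unset is not shown in the hunk).

import os

# Hypothetical example value; the variable name is taken from the diff above.
os.environ["TORCHINDUCTOR_CHUNKER_NUM_CHUNKS"] = "4"

import torch

@torch.compile
def amplifying_op(x, w):
    # A matmul whose output is much larger than its inputs, the kind of
    # "amplifier" node the auto_chunker config above is concerned with.
    return torch.nn.functional.relu(x @ w)

out = amplifying_op(torch.randn(64, 32), torch.randn(32, 4096))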
diff --git a/torch/_inductor/fx_passes/auto_chunker/core.py b/torch/_inductor/fx_passes/auto_chunker/core.py
index 647638578a7..739c8416b5a 100644
--- a/torch/_inductor/fx_passes/auto_chunker/core.py
+++ b/torch/_inductor/fx_passes/auto_chunker/core.py
@@ -125,19 +125,24 @@ def find_amplifier_node(graph: Graph) -> Optional[Node]:
     amplifier_nodes_ratio = []
+    import torch.distributed as dist
+    # dist.breakpoint()  # TODO
     for node in graph.nodes:
         if use_tangent(node):
             # enter backward part of the graph
+            log.debug("Node %s uses tangent", node.format_node())
             break
 
         # Only trigger chunking for a small set of nodes like matmul for now
         if node.op != "call_function" or node.target not in eligible_amplifier_node:
+            log.debug("Node %s not calling eligible function", node.format_node())
             continue
 
         input_size = compute_tensor_size(node.args, node.kwargs)
         output_size = compute_tensor_size(node)
         if input_size == 0:
+            log.debug("Node %s has 0 input", node.format_node())
             continue
 
         ratio = output_size / input_size
@@ -148,6 +153,7 @@ def find_amplifier_node(graph: Graph) -> Optional[Node]:
             amplifier_nodes_ratio.append((node, ratio))
         elif ratio >= 4 and output_size >= 64_000:
             log.debug("Node '%s' get skipped as amplifier_node due to small amplification ratio or size. ratio %s, size %s", node.format_node(), ratio, output_size)
+            log.debug("Node '%s' get skipped as amplifier_node due to small amplification ratio or size. ratio %s, size %s", node.format_node(), ratio, output_size)
 
     amplifier_nodes_ratio = sorted(
         amplifier_nodes_ratio, key=lambda x: x[1], reverse=True
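To make the tuned heuristic concrete, here is a small standalone sketch of the selection rule the hunks above adjust: a node is recorded as a chunking candidate only when its output is large enough and the output/input size ratio reaches the amplify threshold, which this patch lowers from 8 to 6. The helper below is hypothetical and uses plain integers where the real pass walks FX nodes and calls compute_tensor_size.

AMPLIFY_RATIO_THRESHOLD = 6          # patched value (previously 8)
OUTPUT_SIZE_THRESHOLD = 1024 * 1024  # output_size_threshold from config.py above

def is_amplifier(input_size: int, output_size: int) -> bool:
    # Mirrors the guard in find_amplifier_node: skip zero-sized inputs,
    # then require both a large output and a high amplification ratio.
    if input_size == 0:
        return False
    ratio = output_size / input_size
    return ratio >= AMPLIFY_RATIO_THRESHOLD and output_size >= OUTPUT_SIZE_THRESHOLD

print(is_amplifier(input_size=300_000, output_size=4 * 1024 * 1024))  # True: ratio ~14
print(is_amplifier(input_size=1_000_000, output_size=2_000_000))      # False: ratio 2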
diff --git a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
index 0937fd58f77..65f3cf2febe 100644
--- a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
+++ b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py
@@ -89,9 +89,9 @@ def propagate_scale_by(nodes_with_chunking_meta: Sequence[Node]) -> None:
     def propagate_div(div_node: Node) -> bool:
         lhs_node, rhs_node = div_node.args[:2]
         assert isinstance(lhs_node, Node)
-        assert isinstance(rhs_node, Node)
+        assert isinstance(rhs_node, (Node, int, float))
         lhs_scale_by = get_scale_by_from_node(lhs_node)
-        rhs_scale_by = get_scale_by_from_node(rhs_node)
+        rhs_scale_by = get_scale_by_from_node(rhs_node) if isinstance(rhs_node, Node) else None
         if lhs_scale_by and rhs_scale_by is None:
             update_chunking_meta(div_node, scale_by=lhs_scale_by)
             return True
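The relaxed asserts in the last hunk account for division nodes whose right-hand operand is a constant Python scalar rather than another graph node. A minimal sketch of how such a node arises under ordinary torch.fx tracing (the Scale module below is an illustrative stand-in, not code from the patch):

import torch
import torch.fx as fx

class Scale(torch.nn.Module):
    def forward(self, x):
        # The constant divisor is baked into the traced graph as a plain float,
        # so the div node's second argument is not an fx.Node.
        return x / 8.0

gm = fx.symbolic_trace(Scale())
for node in gm.graph.nodes:
    if node.op == "call_function":
        lhs, rhs = node.args[:2]
        print(type(lhs).__name__, type(rhs).__name__)  # Node float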