Created
February 13, 2026 22:13
-
-
Save HDCharles/2e07779d85a0e5436ffd660f79945a58 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _dist_comms_impl(
    keys,
    key_to_rank,
    get_data_fn=lambda key: None,
    comm_fn=lambda data, target_rank: None,
    store_data_fn=lambda key, data: None,
    should_store_data_fn=lambda target_rank: False,
    context_fn=None,
):
    """Launch one communication per key and optionally store the result locally.

    For every key in ``keys`` (in order), the payload from ``get_data_fn(key)``
    is passed to ``comm_fn`` together with ``key_to_rank[key]``. The handle(s)
    returned by ``comm_fn`` are collected. When ``should_store_data_fn`` is
    true for the key's target rank, every handle collected so far is waited on
    via ``wait_for_comms`` and then ``store_data_fn(key, data)`` is called.
    Any handles still outstanding are waited on before returning.

    Args:
        keys: Iterable of keys to process.
        key_to_rank: Mapping from each key to its target rank.
        get_data_fn: ``key -> data``; fetches the payload for a key.
        comm_fn: ``(data, target_rank) -> handle | list[handle]``; starts the
            communication for a payload. Presumably returns asynchronous work
            handles that ``wait_for_comms`` knows how to wait on -- confirm.
        store_data_fn: ``(key, data) -> None``; called with the payload after
            pending comms have been waited on.
        should_store_data_fn: ``target_rank -> bool``; decides whether this
            process stores the payload for a key (presumably "am I the target
            rank" -- verify against callers).
        context_fn: ``key -> context manager`` entered around each key's work.
            Defaults to a no-op context.
    """
    import contextlib

    if context_fn is None:
        context_fn = lambda key: contextlib.nullcontext()  # noqa: E731

    pending_comms = []
    for key in keys:
        target_rank = key_to_rank[key]
        with context_fn(key):
            data = get_data_fn(key)
            comm = comm_fn(data, target_rank)
            # comm_fn may return a single handle or a list of handles.
            pending_comms.extend(comm if isinstance(comm, list) else [comm])
            if should_store_data_fn(target_rank):
                # Block on all launched comms before storing ``data``, since
                # the comm may still be operating on it.
                wait_for_comms(pending_comms)
                # Everything just waited on has finished. Drop it so later
                # waits don't re-wait on the entire history (quadratic in
                # the number of keys).
                pending_comms = []
                store_data_fn(key, data)
    # Flush whatever was launched after the last store.
    wait_for_comms(pending_comms)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment