Skip to content

Instantly share code, notes, and snippets.

@HDCharles
Created February 13, 2026 22:13
Show Gist options
  • Select an option

  • Save HDCharles/2e07779d85a0e5436ffd660f79945a58 to your computer and use it in GitHub Desktop.

Select an option

Save HDCharles/2e07779d85a0e5436ffd660f79945a58 to your computer and use it in GitHub Desktop.
def _dist_comms_impl(
    keys,
    key_to_rank,
    get_data_fn=lambda key: None,
    comm_fn=lambda data, target_rank: None,
    store_data_fn=lambda key, data: None,
    should_store_data_fn=lambda target_rank: False,
    context_fn=None,
):
    """Fetch, communicate, and optionally store data for each key in order.

    For every key in ``keys`` this looks up its target rank in
    ``key_to_rank``, enters ``context_fn(key)``, obtains the payload with
    ``get_data_fn(key)`` and hands it to ``comm_fn(data, target_rank)``.
    The handle(s) returned by ``comm_fn`` are collected. Whenever
    ``should_store_data_fn(target_rank)`` is truthy, every collected handle
    is waited on (via ``wait_for_comms``) before ``store_data_fn(key, data)``
    is called. Handles still outstanding after the last key are waited on
    before returning.

    Args:
        keys: Iterable of keys, processed in iteration order.
        key_to_rank: Mapping from each key to the rank its data targets.
        get_data_fn: ``get_data_fn(key) -> data``; produces the payload.
        comm_fn: ``comm_fn(data, target_rank)``; launches the communication
            and returns a handle, a list of handles, or ``None`` (nothing to
            wait on). Presumably an async collective such as
            ``torch.distributed.reduce(..., async_op=True)`` -- TODO confirm
            against callers.
        store_data_fn: ``store_data_fn(key, data)``; consumes the payload
            after all in-flight communication has been waited on.
        should_store_data_fn: ``should_store_data_fn(target_rank) -> bool``;
            whether this process stores the result for that key (presumably
            "is this process the target rank?" -- verify against callers).
        context_fn: ``context_fn(key)`` returning a context manager that wraps
            the per-key fetch/communicate/store. Defaults to a no-op context.

    Returns:
        None.
    """
    if context_fn is None:
        import contextlib  # local import keeps the default self-contained

        # nullcontext(key) is a no-op context manager; the key is ignored
        # because the ``with`` below never binds the enter result.
        context_fn = contextlib.nullcontext

    pending_comms = []
    for key in keys:
        target_rank = key_to_rank[key]
        with context_fn(key):
            data = get_data_fn(key)
            comm = comm_fn(data, target_rank)
            # comm_fn may return a single handle or a list of handles; None
            # (e.g. from the default no-op comm_fn) means nothing to wait on.
            handles = comm if isinstance(comm, list) else [comm]
            pending_comms.extend(h for h in handles if h is not None)
            if should_store_data_fn(target_rank):
                # Wait before storing: presumably ``data`` is only final once
                # the (possibly in-place) communication has completed.
                wait_for_comms(pending_comms)
                # These handles are done; don't re-wait them on later keys.
                pending_comms.clear()
                store_data_fn(key, data)
    # Drain any communication still in flight before returning.
    wait_for_comms(pending_comms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment