Skip to content

Instantly share code, notes, and snippets.

@LiutongZhou
Created February 4, 2026 03:32
Show Gist options
  • Select an option

  • Save LiutongZhou/29bcfe414479d95e033a4f129b178fd3 to your computer and use it in GitHub Desktop.

Select an option

Save LiutongZhou/29bcfe414479d95e033a4f129b178fd3 to your computer and use it in GitHub Desktop.
MoE Parallel with Token Dropping in JAX
"""Mixture of Experts (MoE) Layer with token dropping
Using Ragged All-to-All Communication and Ragged Dot in JAX.
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Self

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jaxtyping import Array, Bool, Float, Int

__author__ = "Liutong Zhou"
def calculate_offsets(sizes: Int[Array, "n"]) -> Int[Array, "n"]:
    """Compute each chunk's starting offset from a vector of chunk sizes.

    Parameters
    ----------
    sizes : Int[Array, "n"]
        An array of chunk sizes.

    Returns
    -------
    offsets : Int[Array, "n"]
        Exclusive prefix sum of ``sizes``: ``offsets[i]`` is the index at which
        chunk ``i`` begins in a flattened concatenation of all chunks.

    Examples
    --------
    >>> sizes = jnp.array([2, 3, 1])
    >>> calculate_offsets(sizes)
    Array([0, 2, 5], dtype=int32)
    """
    chunk_sizes = jnp.asarray(sizes, dtype=jnp.int32)
    # Exclusive cumulative sum: a leading zero followed by the running totals of
    # every size except the last (the last chunk's end is never an offset).
    leading_zero = jnp.zeros((1,), dtype=jnp.int32)
    return jnp.concatenate([leading_zero, jnp.cumsum(chunk_sizes[:-1])])
def calculate_drop_token_mask(
    token_expert_ids: Int[Array, "n_tokens"],
    token_expert_probs: Float[Array, "n_tokens"],
    num_experts: int | Int[Array, ""],
    expert_capacity: int | Int[Array, ""],
) -> Bool[Array, "n_tokens"]:
    """Calculate a boolean mask marking which tokens overflow their expert's capacity.

    Parameters
    ----------
    token_expert_ids : Int[Array, "n_tokens"]
        1-D array of expert IDs assigned to each token.
    token_expert_probs : Float[Array, "n_tokens"]
        1-D array of probabilities assigned to each token for its expert.
    num_experts : int
        Total number of experts.
    expert_capacity : int
        Maximum number of tokens each expert can handle.

    Returns
    -------
    drop_mask : Bool[Array, "n_tokens"]
        A boolean mask indicating which tokens should be dropped (True) or kept (False).
    """
    # Order tokens by (expert id ascending, probability descending) so that each
    # expert owns a contiguous segment whose most-confident tokens come first.
    order = jnp.lexsort((-token_expert_probs, token_expert_ids))
    expert_ids_in_order = token_expert_ids[order]
    # Segment start of each expert = exclusive prefix sum of the per-expert load.
    expert_load = jnp.bincount(token_expert_ids, length=num_experts)
    segment_starts = jnp.cumsum(expert_load) - expert_load
    # Absolute position in the sorted array minus the segment start gives the
    # within-expert rank (0 = highest-probability token of that expert).
    within_expert_rank = (
        jnp.arange(expert_ids_in_order.shape[0]) - segment_starts[expert_ids_in_order]
    )
    # Any token ranked at or beyond capacity overflows its expert and is dropped.
    dropped_in_order = within_expert_rank >= expert_capacity
    # Map the verdicts back to the original token order via the inverse permutation.
    inverse_order = jnp.argsort(order)
    return dropped_in_order[inverse_order]
@jax.tree_util.register_dataclass
@dataclass(frozen=True, slots=True)
class MOERaggedDispatcher:
    """Orchestrating the MOE token dispatching and returning across devices using ragged all-to-all communications

    On initialization, this class pre-calculates the array layout required to move variable-sized
    slices of data between devices. It computes where data should be read from
    (input offsets) and where it should be written to (output offsets) for both
    the forward dispatch and the backward return trip.

    All methods must be called inside a collective context (e.g. within a
    ``shard_map``-ped function) where ``axis_name`` is bound, because they use
    ``jax.lax`` collectives over that axis.

    Attributes
    ----------
    send_sizes : Int[Array, "devices"]
        Number of items this device sends to each peer device.
    input_offsets : Int[Array, "devices"]
        Local array offsets to read data from during send.
    recv_sizes : Int[Array, "devices"]
        Number of items this device receives from each peer device.
    recv_offsets : Int[Array, "devices"]
        Local array offsets where received data will be stored.
    fwd_remote_output_offsets : Int[Array, "devices"]
        The offsets on the *receiver* devices where our sent data should be written.
    bwd_remote_output_offsets : Int[Array, "devices"]
        The offsets on the *original sender* devices where the returned results
        should be written.
    """

    # Forward Pass Layout
    send_sizes: Int[Array, "devices"]  # how much to read locally
    input_offsets: Int[Array, "devices"]  # where to read from locally
    recv_sizes: Int[Array, "devices"]  # how much to write locally
    recv_offsets: Int[Array, "devices"]  # where to write locally
    # Remote Writing Instructions
    fwd_remote_output_offsets: Int[Array, "devices"]  # where to write remotely
    # where to write remotely on return trip
    bwd_remote_output_offsets: Int[Array, "devices"]
    axis_name: str | tuple[str, ...] = field(
        default="ep", metadata={"static": True}
    )  # mark as static for JAX jit

    @property
    def num_devices(self) -> int:
        """Number of devices involved in all-to-all communication."""
        # Valid only when the axis name is bound by an enclosing collective
        # context (e.g. shard_map); calling it outside one raises.
        return jax.lax.axis_size(self.axis_name)

    @classmethod
    def from_target_device_ids(
        cls,
        target_device_ids: Int[Array, "n"],
        *,
        axis_name: str | tuple[str, ...],
        mask: Bool[Array, "n"] | None = None,
    ) -> Self:
        """Create a globally consistent communication plan.

        Parameters
        ----------
        target_device_ids : Int[Array, "n"]
            1-D array indicating the target device ID for each item to be sent.
        axis_name : str | tuple[str, ...]
            Device mesh axis name(s) along which to perform all-to-all communication.
        mask : Bool[Array, "n"], optional
            Optional boolean mask indicating which items to send (True) or drop (False).

        Returns
        -------
        MOERaggedDispatcher
            An instance of MOERaggedDispatcher with precomputed communication layout.
        """
        num_devices = jax.lax.axis_size(axis_name)
        # Read offsets are derived from the UNMASKED per-device counts: callers
        # must sort dropped items to the tail of each device group (see
        # `dispatch_forward`), so reading only `send_sizes` items starting at
        # `input_offsets` naturally skips the dropped tail.
        device_send_load = jnp.bincount(target_device_ids, length=num_devices)
        input_offsets = calculate_offsets(device_send_load)
        # Only send this much, drop the masked-out items.
        # With integer weights the masked bincount stays an integer count.
        device_send_load_actual = jnp.bincount(
            target_device_ids,
            weights=mask.astype(target_device_ids.dtype) if mask is not None else None,
            length=num_devices,
        )
        # 1. Exchange sizes: Tell peer devices how much this device is sending;
        # receive how much they are sending to this device
        recv_sizes = jax.lax.all_to_all(
            device_send_load_actual,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        recv_offsets = calculate_offsets(recv_sizes)
        # 2. Exchange Output Offsets:
        # a) Forward: Tell original senders where to write in receivers' buffer.
        fwd_remote_output_offsets = jax.lax.all_to_all(
            recv_offsets,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        # b) Backward: Tell senders (expert devices) where to write back in receivers' (original senders') buffer.
        # When the expert forward is done, they need to return data to the original device. Tell
        # expert devices to write it exactly where tokens were originally read from.
        bwd_remote_output_offsets = jax.lax.all_to_all(
            input_offsets,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        return cls(
            send_sizes=device_send_load_actual,
            input_offsets=input_offsets,
            recv_sizes=recv_sizes,
            recv_offsets=recv_offsets,
            fwd_remote_output_offsets=fwd_remote_output_offsets,
            bwd_remote_output_offsets=bwd_remote_output_offsets,
            axis_name=axis_name,
        )

    def dispatch_forward[T: Array](
        self, data: T, capacity: int | Int[Array, ""]
    ) -> tuple[T, Bool[Array, "n"]]:
        """Send data to expert devices.

        Parameters
        ----------
        data : Array
            The input data (e.g. tokens) sorted by target device. If with masking, data must be
            sorted such that dropped items are at the end within each device group.
        capacity : int
            The total size of the receiver's buffer (must be sufficient for worst case).
            NOTE(review): `capacity` defines the output buffer's shape, so under
            jit it must be a static Python int, not a traced array — the
            `Int[Array, ""]` option in the annotation looks unusable there; confirm.

        Returns
        -------
        received : Array
            The data received from peer devices, packed contiguously in output_buffer.
            Shape (capacity, ...).
        is_valid : Bool[Array, "n"]
            1-D boolean mask indicating which received items are valid in the output.
            received data beyond the actual received size are invalid and should be ignored.
        """
        # pre-allocate an output buffer of static size for receiving data
        output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)
        # each device sends data and returns the received data
        received = jax.lax.ragged_all_to_all(
            data,
            output_buffer,  # for storing what this device will receive after forward dispatch
            self.input_offsets,  # Read from here (local)
            self.send_sizes,  # Read this amount (local)
            self.fwd_remote_output_offsets,  # Write to here (remote)
            self.recv_sizes,  # Write this amount (remote)
            axis_name=self.axis_name,
        )
        # Everything past the total received count is untouched buffer padding (zeros).
        is_valid = jnp.arange(capacity) < jnp.sum(self.recv_sizes)
        return received, is_valid

    def dispatch_backward[T: Array](
        self, data: T, capacity: int | Int[Array, ""]
    ) -> tuple[T, Bool[Array, "n"]]:
        """Return processed data to the original sender devices.

        Parameters
        ----------
        data : Array
            The processed data (must be in the same order as received).
        capacity : int
            The size of the buffer on the original sender (to restore shape).
            NOTE(review): like `dispatch_forward`, this must be static under jit.

        Returns
        -------
        returned_to_original_sender : Array
            The results, placed back into their original slots on the sender.
        is_valid : Bool[Array, "n"]
            1-D boolean mask indicating which returned items are valid in the output.
            returned data beyond the actual returned size are invalid and should be ignored.
        """
        output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)
        # Note: Roles of input/output offsets are effectively swapped for the return trip
        returned_to_original_sender = jax.lax.ragged_all_to_all(
            data,
            output_buffer,
            self.recv_offsets,  # Read from here (local processed data)
            self.recv_sizes,  # Read this amount
            self.bwd_remote_output_offsets,  # Write to here (remote original sender)
            self.send_sizes,  # Write this amount
            axis_name=self.axis_name,
        )
        is_valid = jnp.arange(capacity) < jnp.sum(self.send_sizes)
        return returned_to_original_sender, is_valid
def moe_layer(
    sequences: Float[Array, "batch seq hidden"],
    router_weights: Float[Array, "hidden experts_total"],
    expert_weights: Float[Array, "experts_total hidden hidden"],
    *,
    mesh: Mesh,
    top_k: int,
    expert_axis_name: str | tuple[str, ...] = "ep",
    capacity_factor: float = 1.2,
) -> Float[Array, "batch seq hidden"]:
    """Execute a Shard-Mapped Mixture of Experts layer with Ragged All-to-All Communication.

    Parameters
    ----------
    sequences : Float[Array, "batch seq hidden"]
        Input token sequences.
    router_weights : Float[Array, "hidden experts_total"]
        Weights for the gating network.
    expert_weights : Float[Array, "experts_total hidden hidden"]
        Weights for the experts (sharded).
    mesh : Mesh
        The JAX device mesh. Must contain a "dp" axis and the expert axis.
    top_k : int
        Number of experts to select per token.
    expert_axis_name : str | tuple[str, ...], optional
        Name of the mesh axis for experts, by default "ep".
    capacity_factor : float, optional
        Multiplier for expert capacity, by default 1.2.
        Expert Capacity = (total_tokens * top_k / total_experts) * capacity_factor.

    Returns
    -------
    Float[Array, "batch seq hidden"]
        The processed sequences.
    """

    @jax.shard_map(
        mesh=mesh,
        in_specs=(
            P("dp", expert_axis_name, None),  # sequences[batch@dp, seq@ep, hidden]
            P(None, None),  # router_weights[hidden, experts_total]
            # expert_weights[experts_total@ep, hidden, hidden]
            P(expert_axis_name, None, None),
        ),
        out_specs=P("dp", expert_axis_name, None),
    )
    def _sharded_moe_impl(
        x_shard: Float[Array, "local_batch local_seq hidden"],
        router_weights: Float[Array, "hidden experts_total"],
        local_expert_weights: Float[Array, "local_experts hidden hidden"],
    ) -> Float[Array, "local_batch local_seq hidden"]:
        # x_shard[local_batch{dp}, local_seq{ep}, hidden]
        # Setup and preparations
        num_devices_ep = jax.lax.axis_size(expert_axis_name)
        num_experts_per_device = local_expert_weights.shape[0]
        num_total_experts = router_weights.shape[-1]
        assert num_total_experts == num_devices_ep * num_experts_per_device, (
            f"{num_total_experts=} must equal {num_devices_ep=} * {num_experts_per_device=}"
        )
        # A. Calculating Token Routing
        # 1. Flatten Local Batch: (b, s, hidden) -> (b*s, hidden)
        tokens = x_shard.reshape(-1, x_shard.shape[-1])
        # Average tokens per expert = Total Tokens to route / Total Experts.
        # BUGFIX: use `math.ceil` so the capacity is a *static* Python int.
        # The original `jnp.ceil(...)` produced a traced float array; that value
        # is later used as a buffer shape (`jnp.zeros((capacity, ...))` inside
        # dispatch_forward/dispatch_backward) and in `jnp.arange(capacity)`,
        # both of which require static sizes under jit/shard_map tracing.
        # `tokens.shape[0]` and `num_devices_ep` are Python ints here, so the
        # whole expression is computed at trace time.
        expert_capacity = math.ceil(
            tokens.shape[0] * top_k * num_devices_ep / num_total_experts * capacity_factor
        )
        device_capacity = expert_capacity * num_experts_per_device
        # 2. Routing (Top-K): (b*s, hidden) @ (hidden, experts_total) -> (b*s, experts_total) -> topk -> (b*s, k)
        top_k_logits, top_k_expert_ids = jax.lax.top_k(tokens @ router_weights, k=top_k)
        top_k_expert_probs = jax.nn.softmax(top_k_logits, axis=-1)
        # 3. Expand: (b*s, hidden) -> (b*s*k, hidden)
        expert_ids_flat = top_k_expert_ids.ravel()  # (b*s, k) -> (b*s*k,)
        expert_probs_flat = top_k_expert_probs.ravel()  # (b*s, k) -> (b*s*k,)
        # Repeat interleave each token k times to create dispatch tokens
        tokens_flat = jnp.repeat(tokens, repeats=top_k, axis=0)  # (b*s*k, hidden)
        # Track original token segment IDs to sum k results back later: (b*s*k,)
        token_segment_ids = jnp.repeat(
            jnp.arange(tokens.shape[0]), repeats=top_k, axis=0
        )
        # B. Token Dispatching with Token Dropping on senders' side
        is_token_dropped = calculate_drop_token_mask(
            token_expert_ids=expert_ids_flat,
            token_expert_probs=expert_probs_flat,
            num_experts=num_total_experts,
            expert_capacity=expert_capacity,
        )
        # Prepare data (e.g. tokens_flat) to be dispatched by sorting so that after sorting, data to send are grouped by
        # 1. Target Device (ascending)
        # 2. Is Dropped? Dropped tokens go to the end within each device group (so that they are not sent)
        # 3. Local Expert ID within Device
        target_device_ids, target_local_expert_ids = jnp.divmod(
            expert_ids_flat, num_experts_per_device
        )
        index_to_sorted_for_dispatch = jnp.lexsort(
            (target_local_expert_ids, is_token_dropped, target_device_ids)
        )
        tokens_sorted_for_dispatch = tokens_flat[index_to_sorted_for_dispatch]
        expert_probs_sorted_for_dispatch = expert_probs_flat[
            index_to_sorted_for_dispatch
        ]
        target_local_expert_ids_sorted_for_dispatch = target_local_expert_ids[
            index_to_sorted_for_dispatch
        ]
        token_segment_ids_sorted_for_dispatch = token_segment_ids[
            index_to_sorted_for_dispatch
        ]
        # Initialize Device Dispatcher
        dispatcher = MOERaggedDispatcher.from_target_device_ids(
            target_device_ids, axis_name=expert_axis_name, mask=~is_token_dropped
        )
        # Dispatch Forward: Send tokens to expert devices, capping expert capacity on senders' side
        # In worst case, a receiver device may receive full expert load from all senders, need to provision receiver's capacity for that
        # note: device_capacity * num_devices_ep is constant wrt number of devices. As we increase devices, it does not increase.
        receiver_capacity = device_capacity * num_devices_ep
        recv_tokens, _ = dispatcher.dispatch_forward(
            tokens_sorted_for_dispatch, receiver_capacity
        )
        recv_probs, _ = dispatcher.dispatch_forward(
            expert_probs_sorted_for_dispatch, receiver_capacity
        )
        recv_local_expert_ids, is_valid = dispatcher.dispatch_forward(
            target_local_expert_ids_sorted_for_dispatch, receiver_capacity
        )
        # C. Expert Computation
        # Group tokens by local expert ids (0 to experts_per_device-1) for ragged dot
        # Note: we move invalid tokens (0s from output buffer) to the end so they are ignored in computation
        safe_local_expert_ids = jnp.where(
            is_valid, recv_local_expert_ids, num_experts_per_device
        )
        indices_to_sort_by_local_expert = jnp.argsort(safe_local_expert_ids)
        recv_tokens_by_local_expert = recv_tokens[indices_to_sort_by_local_expert]
        recv_probs_by_local_expert = recv_probs[indices_to_sort_by_local_expert]
        # Calculate batch size per local expert for ragged_dot
        local_expert_token_count = jnp.bincount(
            safe_local_expert_ids,
            # Effectively mask out invalid tokens
            weights=is_valid.astype(safe_local_expert_ids.dtype),
            length=num_experts_per_device,
        )
        # Perform MOE computation: (Batch, Hidden) @ (Hidden, Hidden) for each expert group
        expert_outputs = jax.lax.ragged_dot(
            recv_tokens_by_local_expert,
            local_expert_weights,
            group_sizes=local_expert_token_count,
        )
        # Apply gating weights
        expert_outputs = expert_outputs * recv_probs_by_local_expert[:, None]
        # D. Dispatch Back (Return Trip)
        # Restore Order (Undo local expert sort)
        indices_invert_sort_by_local_expert = jnp.argsort(
            indices_to_sort_by_local_expert
        )
        packed_outputs = expert_outputs[indices_invert_sort_by_local_expert]
        tokens_moe_back, _ = dispatcher.dispatch_backward(
            packed_outputs, capacity=tokens_sorted_for_dispatch.shape[0]
        )
        # Finally. Summing (Segment Sum) expert outputs
        # Sum the K contributions back into the original token slots using the segment IDs.
        # (b*s*k, hidden) -> (b*s, hidden) corresponding to tokens due to usage of segment_ids
        combined_output = jax.ops.segment_sum(
            tokens_moe_back,
            segment_ids=token_segment_ids_sorted_for_dispatch,
            num_segments=tokens.shape[0],
        )
        return combined_output.reshape(x_shard.shape)

    return _sharded_moe_impl(sequences, router_weights, expert_weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment