Created
February 4, 2026 03:32
-
-
Save LiutongZhou/29bcfe414479d95e033a4f129b178fd3 to your computer and use it in GitHub Desktop.
MOE Parallel with Token Dropping in Jax
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Mixture of Experts (MoE) Layer with token dropping | |
| Using Ragged All-to-All Communication and Ragged Dot in JAX. | |
| """ | |
| __author__ = "Liutong Zhou" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Self | |
| import jax | |
| import jax.numpy as jnp | |
| from jax.sharding import Mesh, PartitionSpec as P | |
| from jaxtyping import Float, Int, Array, Bool | |
def calculate_offsets(sizes: Int[Array, "n"]) -> Int[Array, "n"]:
    """Compute exclusive prefix sums (start offsets) for a list of chunk sizes.

    Parameters
    ----------
    sizes : Int[Array, "n"]
        An array of chunk sizes.

    Returns
    -------
    offsets : Int[Array, "n"]
        Array starting at 0 where ``offsets[i]`` is the starting index of
        chunk i in the flattened concatenation of all chunks.

    Examples
    --------
    >>> sizes = jnp.array([2, 3, 1])
    >>> calculate_offsets(sizes)
    Array([0, 2, 5], dtype=int32)
    """
    sizes32 = jnp.asarray(sizes, dtype=jnp.int32)
    # Exclusive cumulative sum: chunk i starts after the total of chunks 0..i-1,
    # so drop the last size, accumulate, and put a leading zero in front.
    running_totals = jnp.cumsum(sizes32[:-1])
    return jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), running_totals])
def calculate_drop_token_mask(
    token_expert_ids: Int[Array, "n_tokens"],
    token_expert_probs: Float[Array, "n_tokens"],
    num_experts: int | Int[Array, ""],
    expert_capacity: int | Int[Array, ""],
) -> Bool[Array, "n_tokens"]:
    """Mark tokens that overflow their assigned expert's capacity.

    Parameters
    ----------
    token_expert_ids : Int[Array, "n_tokens"]
        1-D array of expert IDs assigned to each token.
    token_expert_probs : Float[Array, "n_tokens"]
        1-D array of probabilities assigned to each token for its expert.
    num_experts : int
        Total number of experts.
    expert_capacity : int
        Maximum number of tokens each expert can handle.

    Returns
    -------
    drop_mask : Bool[Array, "n_tokens"]
        Boolean mask, True where a token should be dropped, False where kept.
    """
    n_tokens = token_expert_ids.shape[0]
    # Order tokens by (expert id ascending, probability descending) so each
    # expert's highest-confidence tokens sit first within its segment; lexsort
    # treats its *last* key as primary.
    sort_permutation = jnp.lexsort((-token_expert_probs, token_expert_ids))
    sorted_expert_ids = token_expert_ids[sort_permutation]
    # Start index of each expert's segment in the sorted order:
    # exclusive cumulative sum of the per-expert token counts.
    per_expert_load = jnp.bincount(token_expert_ids, length=num_experts)
    segment_starts = jnp.cumsum(jnp.pad(per_expert_load[:-1], (1, 0)))
    # Rank within the expert = absolute sorted position - segment start.
    within_expert_rank = jnp.arange(n_tokens) - segment_starts[sorted_expert_ids]
    # A token ranked at or beyond capacity overflows its expert and is dropped.
    overflow_sorted = within_expert_rank >= expert_capacity
    # Invert the permutation to express the mask in the original token order.
    return overflow_sorted[jnp.argsort(sort_permutation)]
@jax.tree_util.register_dataclass
@dataclass(frozen=True, slots=True)
class MOERaggedDispatcher:
    """Orchestrating the MOE token dispatching and returning across devices using ragged all-to-all communications

    On initialization, this class pre-calculates the array layout required to move variable-sized
    slices of data between devices. It computes where data should be read from
    (input offsets) and where it should be written to (output offsets) for both
    the forward dispatch and the backward return trip.

    Attributes
    ----------
    send_sizes : Int[Array, "devices"]
        Number of items this device sends to each peer device.
    input_offsets : Int[Array, "devices"]
        Local array offsets to read data from during send.
    recv_sizes : Int[Array, "devices"]
        Number of items this device receives from each peer device.
    recv_offsets : Int[Array, "devices"]
        Local array offsets where received data will be stored.
    fwd_remote_output_offsets : Int[Array, "devices"]
        The offsets on the *receiver* devices where our sent data should be written.
    bwd_remote_output_offsets : Int[Array, "devices"]
        The offsets on the *original sender* devices where the returned results
        should be written.
    """

    # Forward Pass Layout
    send_sizes: Int[Array, "devices"]  # how much to read locally
    input_offsets: Int[Array, "devices"]  # where to read from locally
    recv_sizes: Int[Array, "devices"]  # how much to write locally
    recv_offsets: Int[Array, "devices"]  # where to write locally
    # Remote Writing Instructions
    fwd_remote_output_offsets: Int[Array, "devices"]  # where to write remotely
    # where to write remotely on return trip
    bwd_remote_output_offsets: Int[Array, "devices"]
    # Static (non-pytree-leaf) field so changing it does not retrace under jit.
    axis_name: str | tuple[str, ...] = field(
        default="ep", metadata={"static": True}
    )  # mark as static for JAX jit

    @property
    def num_devices(self) -> int:
        """Number of devices involved in all-to-all communication."""
        # NOTE(review): jax.lax.axis_size requires an active mapped axis
        # (e.g. inside shard_map/pmap over `axis_name`); calling this property
        # outside such a context will fail.
        return jax.lax.axis_size(self.axis_name)

    @classmethod
    def from_target_device_ids(
        cls,
        target_device_ids: Int[Array, "n"],
        *,
        axis_name: str | tuple[str, ...],
        mask: Bool[Array, "n"] | None = None,
    ) -> Self:
        """Create a globally consistent communication plan.

        Parameters
        ----------
        target_device_ids : Int[Array, "n"]
            1-D array indicating the target device ID for each item to be sent.
        axis_name : str | tuple[str, ...]
            Device mesh axis name(s) along which to perform all-to-all communication.
        mask : Bool[Array, "n"], optional
            Optional boolean mask indicating which items to send (True) or drop (False).

        Returns
        -------
        MOERaggedDispatcher
            An instance of MOERaggedDispatcher with precomputed communication layout.
        """
        num_devices = jax.lax.axis_size(axis_name)
        # Unmasked per-device counts: the local buffer layout reserves room for
        # dropped items too. dispatch_forward's contract requires callers to sort
        # dropped items to the *end* of each device group, so reading only the
        # first `send_sizes` items from each `input_offsets` skips them.
        device_send_load = jnp.bincount(target_device_ids, length=num_devices)
        input_offsets = calculate_offsets(device_send_load)
        # Only send this much, drop the masked-out items
        # (bincount with integer weights counts only the mask=True items).
        device_send_load_actual = jnp.bincount(
            target_device_ids,
            weights=mask.astype(target_device_ids.dtype) if mask is not None else None,
            length=num_devices,
        )
        # 1. Exchange sizes: Tell peer devices how much this device is sending;
        # receive how much they are sending to this device
        # (tiled all_to_all over axis 0 acts as a distributed transpose:
        # result[j] on this device is what device j computed for this device).
        recv_sizes = jax.lax.all_to_all(
            device_send_load_actual,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        recv_offsets = calculate_offsets(recv_sizes)
        # 2. Exchange Output Offsets:
        # a) Forward: Tell original senders where to write in receivers' buffer.
        # After the transpose, entry j here is receiver j's recv_offsets slot
        # reserved for this device's data.
        fwd_remote_output_offsets = jax.lax.all_to_all(
            recv_offsets,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        # b) Backward: Tell senders (expert devices) where to write back in receivers' (original senders') buffer.
        # When the expert forward is done, they need to return data to the original device. Tell
        # expert devices to write it exactly where tokens were originally read from.
        bwd_remote_output_offsets = jax.lax.all_to_all(
            input_offsets,
            axis_name=axis_name,
            split_axis=0,
            concat_axis=0,
            tiled=True,
        )
        return cls(
            send_sizes=device_send_load_actual,
            input_offsets=input_offsets,
            recv_sizes=recv_sizes,
            recv_offsets=recv_offsets,
            fwd_remote_output_offsets=fwd_remote_output_offsets,
            bwd_remote_output_offsets=bwd_remote_output_offsets,
            axis_name=axis_name,
        )

    def dispatch_forward[T: Array](
        self, data: T, capacity: int | Int[Array, ""]
    ) -> tuple[T, Bool[Array, "n"]]:
        """Send data to expert devices.

        Parameters
        ----------
        data : Array
            The input data (e.g. tokens) sorted by target device. If with masking, data must be
            sorted such that dropped items are at the end within each device group.
        capacity : int
            The total size of the receiver's buffer (must be sufficient for worst case).
            Must be a static integer under jit, since it defines a buffer shape.

        Returns
        -------
        received : Array
            The data received from peer devices, packed contiguously in output_buffer.
            Shape (capacity, ...).
        is_valid : Bool[Array, "n"]
            1-D boolean mask indicating which received items are valid in the output.
            received data beyond the actual received size are invalid and should be ignored.
        """
        # pre-allocate an output buffer of static size for receiving data;
        # slots not written by any peer keep their zero fill.
        output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)
        # each device sends data and returns the received data
        received = jax.lax.ragged_all_to_all(
            data,
            output_buffer,  # for storing what this device will receive after forward dispatch
            self.input_offsets,  # Read from here (local)
            self.send_sizes,  # Read this amount (local)
            self.fwd_remote_output_offsets,  # Write to here (remote)
            self.recv_sizes,  # Write this amount (remote)
            axis_name=self.axis_name,
        )
        # Received items are packed contiguously from index 0, so validity is a
        # simple prefix check against the total received count.
        is_valid = jnp.arange(capacity) < jnp.sum(self.recv_sizes)
        return received, is_valid

    def dispatch_backward[T: Array](
        self, data: T, capacity: int | Int[Array, ""]
    ) -> tuple[T, Bool[Array, "n"]]:
        """Return processed data to the original sender devices.

        Parameters
        ----------
        data : Array
            The processed data (must be in the same order as received).
        capacity : int
            The size of the buffer on the original sender (to restore shape).
            Must be a static integer under jit, since it defines a buffer shape.

        Returns
        -------
        returned_to_original_sender : Array
            The results, placed back into their original slots on the sender.
            Slots of items that were never sent (dropped) remain zero.
        is_valid : Bool[Array, "n"]
            1-D boolean mask indicating which returned items are valid in the output.
            returned data beyond the actual returned size are invalid and should be ignored.
        """
        output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)
        # Note: Roles of input/output offsets are effectively swapped for the return trip
        returned_to_original_sender = jax.lax.ragged_all_to_all(
            data,
            output_buffer,
            self.recv_offsets,  # Read from here (local processed data)
            self.recv_sizes,  # Read this amount
            self.bwd_remote_output_offsets,  # Write to here (remote original sender)
            self.send_sizes,  # Write this amount
            axis_name=self.axis_name,
        )
        # NOTE(review): valid slots are not contiguous on the return trip when
        # items were dropped (writes land at per-device input_offsets); this
        # prefix-style mask is a total-count bound, not a per-slot validity map.
        is_valid = jnp.arange(capacity) < jnp.sum(self.send_sizes)
        return returned_to_original_sender, is_valid
def moe_layer(
    sequences: Float[Array, "batch seq hidden"],
    router_weights: Float[Array, "hidden experts_total"],
    expert_weights: Float[Array, "experts_total hidden hidden"],
    *,
    mesh: Mesh,
    top_k: int,
    expert_axis_name: str | tuple[str, ...] = "ep",
    capacity_factor: float = 1.2,
) -> Float[Array, "batch seq hidden"]:
    """Execute a Shard-Mapped Mixture of Experts layer with Ragged All-to-All Communication.

    Parameters
    ----------
    sequences : Float[Array, "batch seq hidden"]
        Input token sequences.
    router_weights : Float[Array, "hidden experts_total"]
        Weights for the gating network.
    expert_weights : Float[Array, "experts_total hidden hidden"]
        Weights for the experts (sharded).
    mesh : Mesh
        The JAX device mesh. Must contain a "dp" axis and the expert axis.
    top_k : int
        Number of experts to select per token.
    expert_axis_name : str | tuple[str, ...], optional
        Name of the mesh axis for experts, by default "ep".
    capacity_factor : float, optional
        Multiplier for expert capacity, by default 1.2.
        Expert Capacity = (total_tokens * top_k / total_experts) * capacity_factor.

    Returns
    -------
    Float[Array, "batch seq hidden"]
        The processed sequences. Tokens dropped by capacity limits contribute
        zeros for their dropped expert slots.
    """
    import math  # local import: only needed for static capacity arithmetic

    @jax.shard_map(
        mesh=mesh,
        in_specs=(
            P("dp", expert_axis_name, None),  # sequences[batch@dp, seq@ep, hidden]
            P(None, None),  # router_weights[hidden, experts_total]
            # expert_weights[experts_total@ep, hidden, hidden]
            P(expert_axis_name, None, None),
        ),
        out_specs=P("dp", expert_axis_name, None),
    )
    def _sharded_moe_impl(
        x_shard: Float[Array, "local_batch local_seq hidden"],
        router_weights: Float[Array, "hidden experts_total"],
        local_expert_weights: Float[Array, "local_experts hidden hidden"],
    ) -> Float[Array, "local_batch local_seq hidden"]:
        # x_shard[local_batch{dp}, local_seq{ep}, hidden]
        # Setup and preparations
        num_devices_ep = jax.lax.axis_size(expert_axis_name)
        num_experts_per_device = local_expert_weights.shape[0]
        num_total_experts = router_weights.shape[-1]
        assert num_total_experts == num_devices_ep * num_experts_per_device, (
            f"{num_total_experts=} must equal {num_devices_ep=} * {num_experts_per_device=}"
        )
        # A. Calculating Token Routing
        # 1. Flatten Local Batch: (b, s, hidden) -> (b*s, hidden)
        tokens = x_shard.reshape(-1, x_shard.shape[-1])
        # Average tokens per expert = Total Tokens to route / Total Experts.
        # FIX: compute capacities with math.ceil over plain-Python operands so
        # they are static ints. The previous jnp.ceil produced a *float*
        # jax.Array, which cannot be used as a buffer shape below
        # (jnp.zeros((capacity, ...)) / jnp.arange(capacity) in the dispatcher
        # require static integer dimensions and reject float arrays).
        expert_capacity = math.ceil(
            tokens.shape[0]
            * top_k
            * num_devices_ep
            / num_total_experts
            * capacity_factor
        )
        device_capacity = expert_capacity * num_experts_per_device
        # 2. Routing (Top-K): (b*s, hidden) @ (hidden, experts_total) -> (b*s, experts_total) -> topk -> (b*s, k)
        top_k_logits, top_k_expert_ids = jax.lax.top_k(tokens @ router_weights, k=top_k)
        top_k_expert_probs = jax.nn.softmax(top_k_logits, axis=-1)
        # 3. Expand: (b*s, hidden) -> (b*s*k, hidden)
        expert_ids_flat = top_k_expert_ids.ravel()  # (b*s, k) -> (b*s*k,)
        expert_probs_flat = top_k_expert_probs.ravel()  # (b*s, k) -> (b*s*k,)
        # Repeat interleave each token k times to create dispatch tokens
        tokens_flat = jnp.repeat(tokens, repeats=top_k, axis=0)  # (b*s*k, hidden)
        # Track original token segment IDs to sum k results back later: (b*s*k,)
        token_segment_ids = jnp.repeat(
            jnp.arange(tokens.shape[0]), repeats=top_k, axis=0
        )
        # B. Token Dispatching with Token Dropping on senders' side
        is_token_dropped = calculate_drop_token_mask(
            token_expert_ids=expert_ids_flat,
            token_expert_probs=expert_probs_flat,
            num_experts=num_total_experts,
            expert_capacity=expert_capacity,
        )
        # Prepare data (e.g. tokens_flat) to be dispatched by sorting so that after sorting, data to send are grouped by
        # 1. Target Device (ascending)
        # 2. Is Dropped? Dropped tokens go to the end within each device group (so that they are not sent)
        # 3. Local Expert ID within Device
        target_device_ids, target_local_expert_ids = jnp.divmod(
            expert_ids_flat, num_experts_per_device
        )
        index_to_sorted_for_dispatch = jnp.lexsort(
            (target_local_expert_ids, is_token_dropped, target_device_ids)
        )
        tokens_sorted_for_dispatch = tokens_flat[index_to_sorted_for_dispatch]
        expert_probs_sorted_for_dispatch = expert_probs_flat[
            index_to_sorted_for_dispatch
        ]
        target_local_expert_ids_sorted_for_dispatch = target_local_expert_ids[
            index_to_sorted_for_dispatch
        ]
        token_segment_ids_sorted_for_dispatch = token_segment_ids[
            index_to_sorted_for_dispatch
        ]
        # Initialize Device Dispatcher
        dispatcher = MOERaggedDispatcher.from_target_device_ids(
            target_device_ids, axis_name=expert_axis_name, mask=~is_token_dropped
        )
        # Dispatch Forward: Send tokens to expert devices, capping expert capacity on senders' side
        # In worst case, a receiver device may receive full expert load from all senders, need to provision receiver's capacity for that
        # note: device_capacity * num_devices_ep is constant wrt number of devices. As we increase devices, it does not increase.
        receiver_capacity = device_capacity * num_devices_ep
        recv_tokens, _ = dispatcher.dispatch_forward(
            tokens_sorted_for_dispatch, receiver_capacity
        )
        recv_probs, _ = dispatcher.dispatch_forward(
            expert_probs_sorted_for_dispatch, receiver_capacity
        )
        recv_local_expert_ids, is_valid = dispatcher.dispatch_forward(
            target_local_expert_ids_sorted_for_dispatch, receiver_capacity
        )
        # C. Expert Computation
        # Group tokens by local expert ids (0 to experts_per_device-1) for ragged dot
        # Note: we move invalid tokens (0s from output buffer) to the end so they are ignored in computation
        safe_local_expert_ids = jnp.where(
            is_valid, recv_local_expert_ids, num_experts_per_device
        )
        indices_to_sort_by_local_expert = jnp.argsort(safe_local_expert_ids)
        recv_tokens_by_local_expert = recv_tokens[indices_to_sort_by_local_expert]
        recv_probs_by_local_expert = recv_probs[indices_to_sort_by_local_expert]
        # Calculate batch size per local expert for ragged_dot
        local_expert_token_count = jnp.bincount(
            safe_local_expert_ids,
            # Effectively mask out invalid tokens
            weights=is_valid.astype(safe_local_expert_ids.dtype),
            length=num_experts_per_device,
        )
        # Perform MOE computation: (Batch, Hidden) @ (Hidden, Hidden) for each expert group
        expert_outputs = jax.lax.ragged_dot(
            recv_tokens_by_local_expert,
            local_expert_weights,
            group_sizes=local_expert_token_count,
        )
        # Apply gating weights (invalid rows have prob 0 from the zero-filled buffer)
        expert_outputs = expert_outputs * recv_probs_by_local_expert[:, None]
        # D. Dispatch Back (Return Trip)
        # Restore Order (Undo local expert sort)
        indices_invert_sort_by_local_expert = jnp.argsort(
            indices_to_sort_by_local_expert
        )
        packed_outputs = expert_outputs[indices_invert_sort_by_local_expert]
        tokens_moe_back, _ = dispatcher.dispatch_backward(
            packed_outputs, capacity=tokens_sorted_for_dispatch.shape[0]
        )
        # Finally. Summing (Segment Sum) expert outputs
        # Sum the K contributions back into the original token slots using the segment IDs.
        # (b*s*k, hidden) -> (b*s, hidden) corresponding to tokens due to usage of segment_ids
        # Dropped tokens were never returned, so their slots are zero and add nothing.
        combined_output = jax.ops.segment_sum(
            tokens_moe_back,
            segment_ids=token_segment_ids_sorted_for_dispatch,
            num_segments=tokens.shape[0],
        )
        return combined_output.reshape(x_shard.shape)

    return _sharded_moe_impl(sequences, router_weights, expert_weights)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment