@knyazer, created February 2, 2026
prefill.py
import equinox as eqx
import jax
import jax.numpy as jnp
from einops import rearrange
from jaxtyping import Array, Int

# Chunk, PrefillResults, ForwardPassConfig, ForwardPassMode, and self.model
# are project-local and not defined in this gist.


@eqx.filter_jit
def _make_chunks(
    self,
    token_ids: Int[Array, "batch tokens"],
    lengths_without_padding: Int[Array, " batch"] | None,
    chunk_size: int,
) -> Chunk:
    batch_size, sequence_length = token_ids.shape
    if lengths_without_padding is None:
        lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
    # If all sequences fit in a single chunk, use sequence_length as the chunk size.
    if sequence_length <= chunk_size:
        chunk_size = sequence_length
        n_chunks = 1
    else:
        n_chunks = (sequence_length + chunk_size - 1) // chunk_size
    padded_length = n_chunks * chunk_size
    token_ids = jnp.pad(token_ids, [(0, 0), (0, padded_length - sequence_length)])
    # Reshape tokens to (num_chunks, batch, chunk_size).
    tokens = rearrange(
        token_ids,
        "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
        chunk_size=chunk_size,
    )
    # Create position indices of shape (num_chunks, batch, chunk_size).
    positions = jnp.arange(padded_length, dtype=jnp.int32)
    positions = jnp.repeat(positions[None, :], batch_size, axis=0)
    indices = rearrange(
        positions,
        "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
        chunk_size=chunk_size,
    )
    # sequence_ends: for each chunk, how many valid tokens per batch item.
    chunk_starts = jnp.arange(n_chunks, dtype=jnp.int32) * chunk_size
    sequence_ends = jnp.clip(
        lengths_without_padding[None, :] - chunk_starts[:, None],
        0,
        chunk_size,
    )
    # is_last_token_inside: whether the last valid token (at index length - 1)
    # falls within this chunk.
    last_token_idx = lengths_without_padding - 1
    chunk_ends = chunk_starts + chunk_size
    is_last_token_inside = (last_token_idx[None, :] >= chunk_starts[:, None]) & (
        last_token_idx[None, :] < chunk_ends[:, None]
    )
    return Chunk(
        tokens=tokens,
        indices=indices,
        sequence_ends=sequence_ends,
        is_last_token_inside=is_last_token_inside,
    )
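
For intuition, here is a minimal standalone sketch of the same chunking arithmetic. The values (a batch of 2 with lengths 5 and 2, chunk_size 3) are made-up example inputs, not from the gist; the two computed arrays mirror `sequence_ends` and `is_last_token_inside` above.

import jax.numpy as jnp

lengths = jnp.array([5, 2])                       # valid tokens per batch item
chunk_size, n_chunks = 3, 2                       # padded_length = 6
chunk_starts = jnp.arange(n_chunks) * chunk_size  # [0, 3]
sequence_ends = jnp.clip(lengths[None, :] - chunk_starts[:, None], 0, chunk_size)
# sequence_ends == [[3, 2],   chunk 0: item 0 has 3 valid tokens, item 1 has 2
#                   [2, 0]]   chunk 1: item 0 has 2 valid tokens, item 1 has none
last = lengths - 1
chunk_ends = chunk_starts + chunk_size
inside = (last[None, :] >= chunk_starts[:, None]) & (last[None, :] < chunk_ends[:, None])
# inside == [[False, True],   item 1's last token (index 1) lies in chunk 0
#            [True, False]]   item 0's last token (index 4) lies in chunk 1

Note that a fully exhausted sequence gets `sequence_ends == 0` in later chunks; `_prefill` below relies on `is_last_token_inside` being False there, so the out-of-range `sequence_ends - 1` index is masked out and never used.
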
@eqx.filter_jit
def _prefill(
    self,
    token_ids: Int[Array, "batch tokens"],
    state_capacity: int,
    lengths_without_padding: Int[Array, " batch"] | None = None,
    forward_pass_config: ForwardPassConfig | None = None,
    chunk_size: int = 512,  # vLLM default
) -> PrefillResults:
    batch_size, sequence_length = token_ids.shape
    if lengths_without_padding is None:
        lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
    chunks = self._make_chunks(token_ids, lengths_without_padding, chunk_size)
    num_chunks, _, chunk_size = chunks.tokens.shape
    # Make sure the state cache can hold the full padded prompt.
    state_capacity = max(state_capacity, num_chunks * chunk_size)
    state = self.model.init_static_state(batch_size, state_capacity)
    logits_like = jnp.zeros((batch_size, self.model.vocab_size), dtype=jnp.float32)

    def apply_chunk(state_and_logits: tuple, chunk: Chunk) -> tuple:
        state, prev_logits = state_and_logits
        decoder_outputs = self.model(
            chunk.tokens,
            chunk.indices,
            state,
            return_updated_state=True,
            lengths_without_padding=chunk.sequence_ends,
            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
            forward_pass_config=forward_pass_config,
        )
        assert decoder_outputs.updated_state is not None
        # Logits at the last valid position of this chunk, per batch item.
        chunk_logits = decoder_outputs.logits[jnp.arange(batch_size), chunk.sequence_ends - 1, :]
        # Keep them only for sequences whose final token lives in this chunk.
        new_logits = jnp.where(chunk.is_last_token_inside[:, None], chunk_logits, prev_logits)
        return (decoder_outputs.updated_state, new_logits), None

    (final_state, final_logits), _ = jax.lax.scan(apply_chunk, (state, logits_like), chunks)
    return PrefillResults(
        last_token_logits=final_logits,
        last_token_indices=jnp.maximum(lengths_without_padding - 1, 0),
        state=final_state,
    )
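
The `jax.lax.scan` + `jnp.where` pattern in `apply_chunk` is what lets sequences of different lengths share a single prefill pass: each sequence's final logits are captured in whichever chunk its last token lands, then carried through later chunks unchanged. A self-contained sketch of just that accumulation, with random toy arrays standing in for the decoder's per-chunk last-position logits (the `inside` mask reuses the toy example above):

import jax
import jax.numpy as jnp

batch, vocab, n_chunks = 2, 4, 2
key = jax.random.PRNGKey(0)
# Toy stand-in for the logits gathered at sequence_ends - 1 in each chunk,
# shape (n_chunks, batch, vocab).
chunk_logits = jax.random.normal(key, (n_chunks, batch, vocab))
# Item 1's last token is in chunk 0; item 0's last token is in chunk 1.
inside = jnp.array([[False, True], [True, False]])

def step(carry, xs):
    logits, mask = xs
    # Overwrite the carry only where this chunk holds the sequence's last token.
    return jnp.where(mask[:, None], logits, carry), None

final, _ = jax.lax.scan(step, jnp.zeros((batch, vocab)), (chunk_logits, inside))
# final[0] == chunk_logits[1, 0] and final[1] == chunk_logits[0, 1]

Because the mask is resolved with `jnp.where` rather than Python control flow, the whole loop stays traceable under `eqx.filter_jit`, and the scan carry (state plus current best logits) has a fixed shape across chunks, which is exactly what `jax.lax.scan` requires.
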