prefill.py
import equinox as eqx
import jax
import jax.numpy as jnp
from einops import rearrange
from jaxtyping import Array, Int

# Project-internal names assumed to be in scope elsewhere in the codebase:
# Chunk, PrefillResults, ForwardPassConfig, ForwardPassMode.

@eqx.filter_jit
def _make_chunks(
    self,
    token_ids: Int[Array, "batch tokens"],
    lengths_without_padding: Int[Array, " batch"] | None,
    chunk_size: int,
) -> Chunk:
    batch_size, sequence_length = token_ids.shape
    if lengths_without_padding is None:
        lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
    # If all sequences fit in a single chunk, use sequence_length as the chunk size
    if sequence_length <= chunk_size:
        chunk_size = sequence_length
        n_chunks = 1
    else:
        n_chunks = (sequence_length + chunk_size - 1) // chunk_size
    padded_length = n_chunks * chunk_size
    token_ids = jnp.pad(token_ids, [(0, 0), (0, padded_length - sequence_length)])
    # Reshape tokens to (num_chunks, batch, chunk_size)
    tokens = rearrange(
        token_ids,
        "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
        chunk_size=chunk_size,
    )
    # Create position indices (num_chunks, batch, chunk_size)
    positions = jnp.arange(padded_length, dtype=jnp.int32)
    positions = jnp.repeat(positions[None, :], batch_size, axis=0)
    indices = rearrange(
        positions,
        "batch (num_chunks chunk_size) -> num_chunks batch chunk_size",
        chunk_size=chunk_size,
    )
    # sequence_ends: for each chunk, how many valid (non-padding) tokens per batch item
    chunk_starts = jnp.arange(n_chunks, dtype=jnp.int32) * chunk_size
    sequence_ends = jnp.clip(
        lengths_without_padding[None, :] - chunk_starts[:, None],
        0,
        chunk_size,
    )
    # is_last_token_inside: whether the last valid token (at index length - 1) is in this chunk
    last_token_idx = lengths_without_padding - 1
    chunk_ends = chunk_starts + chunk_size
    is_last_token_inside = (last_token_idx[None, :] >= chunk_starts[:, None]) & (
        last_token_idx[None, :] < chunk_ends[:, None]
    )
    return Chunk(
        tokens=tokens,
        indices=indices,
        sequence_ends=sequence_ends,
        is_last_token_inside=is_last_token_inside,
    )
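The Chunk and PrefillResults containers referenced in this file are not part of the gist. The sketch below is a hypothetical reconstruction inferred only from how their fields are used above (the leading num_chunks axis matches the scan in _prefill); the real definitions may differ. For intuition on the bookkeeping: with chunk_size=4 and a sequence of 6 valid tokens, sequence_ends comes out as [4, 2] and is_last_token_inside as [False, True].

# Hypothetical reconstruction -- not part of the original gist.
import equinox as eqx
from jaxtyping import Array, Bool, Float, Int

class Chunk(eqx.Module):
    # Leading axis is num_chunks so that jax.lax.scan can step over it.
    tokens: Int[Array, "num_chunks batch chunk_size"]
    indices: Int[Array, "num_chunks batch chunk_size"]
    sequence_ends: Int[Array, "num_chunks batch"]
    is_last_token_inside: Bool[Array, "num_chunks batch"]

class PrefillResults(eqx.Module):
    last_token_logits: Float[Array, "batch vocab"]
    last_token_indices: Int[Array, " batch"]
    state: object  # model-specific KV/attention state pytree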
@eqx.filter_jit
def _prefill(
    self,
    token_ids: Int[Array, "batch tokens"],
    state_capacity: int,
    lengths_without_padding: Int[Array, " batch"] | None = None,
    forward_pass_config: ForwardPassConfig | None = None,
    chunk_size: int = 512,  # vLLM default
) -> PrefillResults:
    batch_size, sequence_length = token_ids.shape
    if lengths_without_padding is None:
        lengths_without_padding = jnp.full((batch_size,), sequence_length, dtype=jnp.int32)
    chunks = self._make_chunks(token_ids, lengths_without_padding, chunk_size)
    num_chunks, _, chunk_size = chunks.tokens.shape
    # The state must be able to hold every (padded) prefill token
    state_capacity = max(state_capacity, num_chunks * chunk_size)
    state = self.model.init_static_state(batch_size, state_capacity)
    logits_like = jnp.zeros((batch_size, self.model.vocab_size), dtype=jnp.float32)

    def apply_chunk(state_and_logits: tuple, chunk: Chunk) -> tuple:
        state, prev_logits = state_and_logits
        decoder_outputs = self.model(
            chunk.tokens,
            chunk.indices,
            state,
            return_updated_state=True,
            lengths_without_padding=chunk.sequence_ends,
            forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
            forward_pass_config=forward_pass_config,
        )
        assert decoder_outputs.updated_state is not None
        # Logits at the last valid position within this chunk, per batch item
        chunk_logits = decoder_outputs.logits[jnp.arange(batch_size), chunk.sequence_ends - 1, :]
        # Keep these logits only for sequences whose final token falls in this chunk
        new_logits = jnp.where(chunk.is_last_token_inside[:, None], chunk_logits, prev_logits)
        return (decoder_outputs.updated_state, new_logits), None

    (final_state, final_logits), _ = jax.lax.scan(apply_chunk, (state, logits_like), chunks)
    return PrefillResults(
        last_token_logits=final_logits,
        last_token_indices=jnp.maximum(lengths_without_padding - 1, 0),
        state=final_state,
    )
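A hedged usage sketch: assuming these methods live on some engine object exposing self.model (the name engine and the token values below are illustrative, not from the gist), prefill would be driven roughly like this.

# Hypothetical usage -- every name not defined in the gist is an assumption.
import jax.numpy as jnp

token_ids = jnp.array(
    [
        [11, 42, 7, 99, 0, 0],    # 4 valid tokens, padded to length 6
        [11, 42, 7, 99, 13, 27],  # 6 valid tokens
    ],
    dtype=jnp.int32,
)
lengths = jnp.array([4, 6], dtype=jnp.int32)

results = engine._prefill(
    token_ids,
    state_capacity=1024,  # room for the prompt plus tokens generated later
    lengths_without_padding=lengths,
    chunk_size=512,
)
# results.last_token_logits has shape (batch, vocab_size): the logits at each
# sequence's final valid token, ready for sampling the first generated token.
# results.state is the filled attention state to pass into decoding.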