Last active
December 27, 2025 16:55
-
-
Save ddh0/d9c31be9a31d55c70868ab9edd877558 to your computer and use it in GitHub Desktop.
Adaptive-P (refer to: https://github.com/ggml-org/llama.cpp/pull/17927 for canonical implementation)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// adaptive-p: pick tokens whose probability sits near a configurable target, averaged over time.
///
/// The sampler reshapes the token probability distribution so that tokens close to a
/// user-chosen probability target are favored.
///
/// Internally it keeps an exponential moving average (EMA) of the *ORIGINAL* (pre-transform)
/// probabilities of the tokens it selected at each sampling step. That EMA is used to derive
/// an adapted target at every step, so that the desired target probability is maintained
/// over time.
///
/// adaptive-p selects a token ID rather than just mutating candidates, so it must be the
/// last sampler in the chain (like mirostat, dist, greedy).
///
/// Only mild truncation ahead of this sampler is recommended; the suggested setup is min-p
/// followed by adaptive-p, with no other active samplers in the chain.
///
/// @param target  select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
/// @param decay   EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 to 0.99)
/// @param seed    RNG seed
///
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927
///
struct llama_sampler_adaptive_p {
    // --- configuration (fixed after construction) ---

    // target probability (0.0 - 1.0; negative = disabled)
    const float target;
    // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
    const float decay;
    // RNG seed
    const uint32_t seed;

    // --- mutable sampling state ---

    // random number generator used when drawing the token
    std::mt19937 rng;
    // EMA numerator: sum(p_i * decay^i)
    float weighted_sum;
    // EMA denominator: sum(decay^i), converges to 1/(1-decay)
    float total_weight;
    // pre-transform probabilities, cached so the EMA can be updated after sampling
    std::vector<float> original_probs;
};
// constants shaping the adaptive probability transformation

// scale of the probability distance: |p - adapted_target| is divided by this value
static constexpr float DISTRIBUTION_WIDTH{0.3f};
// logit assigned to a token whose probability matches the adapted target exactly
static constexpr float PEAK_LOGIT_VALUE{5.0f};
// multiplier on the distance penalty subtracted from the peak logit
static constexpr float SHARPNESS{4.0f};
// reciprocal of DISTRIBUTION_WIDTH, precomputed so the transform multiplies instead of divides
static constexpr float INV_WIDTH{1.0f / DISTRIBUTION_WIDTH};
// Returns the human-readable identifier of this sampler type.
// The sampler argument is unused; the name does not depend on instance state.
static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
    static constexpr char sampler_name[] = "adaptive-p";
    return sampler_name;
}
| static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | |
| auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; | |
| if (ctx->target < 0.0f) { | |
| // at negative target values, adaptive-p is no-op | |
| // we simply sample from the existing distribution | |
| llama_sampler_softmax_impl(cur_p, false); | |
| cur_p->selected = llama_sample_dist(cur_p, ctx->rng); | |
| return; | |
| } | |
| // softmax and store the original probabilities | |
| llama_sampler_softmax_impl(cur_p, false); | |
| ctx->original_probs.resize(cur_p->size); | |
| for (size_t i = 0; i < cur_p->size; ++i) { | |
| ctx->original_probs[i] = cur_p->data[i].p; | |
| } | |
| // compute the adapted target probability for the current sampling step | |
| auto target = std::clamp(ctx->target, 0.0f, 1.0f); | |
| float adapted_target = std::clamp( | |
| ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), | |
| 0.0f, 1.0f | |
| ); | |
| // adaptive probability transform | |
| // | |
| // quadratic near target for fine differentiation, transitioning to linear decay in the | |
| // tails. unbounded negative logits ensure proper suppression of far-from-target tokens | |
| // after the softmax. | |
| // | |
| for (size_t i = 0; i < cur_p->size; ++i) { | |
| float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); | |
| cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); | |
| } | |
| // softmax and sample from the transformed distribution | |
| llama_sampler_softmax_impl(cur_p, false); | |
| const int idx = llama_sample_dist(cur_p, ctx->rng); | |
| cur_p->selected = idx; | |
| // update history with the original probability of the selected token | |
| ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum; | |
| ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; | |
| } | |
| static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { | |
| auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; | |
| ctx->weighted_sum = 0.0f; | |
| ctx->total_weight = 0.0f; | |
| } | |
| static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { | |
| const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; | |
| auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); | |
| auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; | |
| result_ctx->rng = ctx->rng; | |
| result_ctx->weighted_sum = ctx->weighted_sum; | |
| result_ctx->total_weight = ctx->total_weight; | |
| result_ctx->original_probs.reserve(ctx->original_probs.capacity()); | |
| return result; | |
| } | |
| static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { | |
| delete (llama_sampler_adaptive_p *) smpl->ctx; | |
| } | |
| static struct llama_sampler_i llama_sampler_adaptive_p_i = { | |
| /* .name = */ llama_sampler_adaptive_p_name, | |
| /* .accept = */ nullptr, | |
| /* .apply = */ llama_sampler_adaptive_p_apply, | |
| /* .reset = */ llama_sampler_adaptive_p_reset, | |
| /* .clone = */ llama_sampler_adaptive_p_clone, | |
| /* .free = */ llama_sampler_adaptive_p_free, | |
| }; | |
| struct llama_sampler * llama_sampler_init_adaptive_p( | |
| float target, | |
| float decay, | |
| uint32_t seed | |
| ) { | |
| auto seed_cur = get_rng_seed(seed); | |
| return llama_sampler_init( | |
| /* .iface = */ &llama_sampler_adaptive_p_i, | |
| /* .ctx = */ new llama_sampler_adaptive_p { | |
| /* .target = */ target, | |
| /* .decay = */ std::clamp(decay, 0.0f, 0.99f), | |
| /* .seed = */ seed_cur, | |
| /* .rng = */ std::mt19937(seed_cur), | |
| /* .weighted_sum = */ 0.0f, | |
| /* .total_weight = */ 0.0f, | |
| /* .original_probs = */ {}, | |
| } | |
| ); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment