Last active
February 6, 2026 20:02
-
-
Save viveksck/85670b6b6be54c8429e9ec6f73ea755e to your computer and use it in GitHub Desktop.
ucb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
class UCBTopK:
    """UCB1 multi-armed bandit agent that tracks a running mean per arm.

    Attributes:
        num_arms: Number of arms the agent chooses between.
        counts: Float array of length ``num_arms``; pulls recorded per arm.
        values: Float array of length ``num_arms``; running mean reward per
            arm (0.0 for an arm that has never been updated).
    """

    def __init__(self, num_arms: int) -> None:
        self.num_arms = num_arms
        self.counts = np.zeros(num_arms)
        self.values = np.zeros(num_arms)

    def select_arm(self, t: int) -> int:
        """Return the index of the arm to pull at 1-based round ``t``.

        Rounds ``1..num_arms`` pull each arm once in order. Afterwards the
        arm maximising ``mean + sqrt(2 * ln(t) / count)`` is chosen.

        Raises:
            ValueError: If ``t`` is smaller than 1.
        """
        if t < 1:
            # t - 1 would be a negative index and silently wrap around to
            # an arm at the end of the arrays in update().
            raise ValueError(f"t must be a 1-based round number, got {t}")

        # Initialize: pull each arm once to get baseline data.
        if t <= self.num_arms:
            return t - 1

        # An arm without any recorded pull has an unbounded confidence
        # bound. Pick the first such arm directly instead of dividing by a
        # zero count (which would emit a RuntimeWarning and yield inf).
        unpulled = np.flatnonzero(self.counts == 0)
        if unpulled.size:
            return int(unpulled[0])

        # Every count is positive here, so the division is safe.
        exploration_term = np.sqrt((2 * np.log(t)) / self.counts)
        ucb_values = self.values + exploration_term
        return int(np.argmax(ucb_values))

    def update(self, chosen_arm: int, reward: float) -> None:
        """Record ``reward`` for ``chosen_arm`` and refresh its running mean."""
        self.counts[chosen_arm] += 1
        # Incremental mean update: new mean = old mean + (reward - old mean) / n
        pulls = self.counts[chosen_arm]
        self.values[chosen_arm] += (reward - self.values[chosen_arm]) / pulls
def run_top_k_ucb(num_arms, horizon, k, true_means, rng=None):
    """Run a UCB simulation and return the k arms with the best estimates.

    Each round draws a reward from a normal distribution with unit standard
    deviation centred on the chosen arm's true mean.

    Args:
        num_arms: Number of arms in the bandit.
        horizon: Number of rounds to simulate (rounds are numbered from 1).
        k: How many arms to return; 0 yields an empty array, and values
            larger than ``num_arms`` yield every arm.
        true_means: Sequence holding the true mean reward of each arm; must
            have at least ``num_arms`` entries.
        rng: Optional object with a ``normal(loc=..., scale=...)`` method
            (e.g. ``np.random.default_rng(seed)``) for reproducible runs.
            When omitted, the global ``np.random.normal`` is used.

    Returns:
        Array of arm indices ordered from highest to lowest estimated mean.
        Arms that were never pulled (possible when ``horizon < num_arms``)
        keep an estimate of 0.0 and are ranked accordingly.

    Raises:
        ValueError: If ``k`` is negative or ``true_means`` is too short.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if len(true_means) < num_arms:
        raise ValueError(
            f"true_means has {len(true_means)} entries but num_arms is {num_arms}"
        )

    draw_reward = np.random.normal if rng is None else rng.normal
    agent = UCBTopK(num_arms)

    for t in range(1, horizon + 1):
        arm = agent.select_arm(t)
        # Simulate reward from a Normal distribution centered at the arm's
        # true mean.
        reward = draw_reward(loc=true_means[arm], scale=1.0)
        agent.update(arm, reward)

    # Sort descending first, then take the leading k entries. Slicing with
    # [-k:] would return *all* arms for k == 0 because -0 == 0.
    return np.argsort(agent.values)[::-1][:k]
# --- Execution ---
# Experiment size: number of arms, how many top arms to report, and the
# number of simulated pulls.
n_arms, k_to_find, total_steps = 10, 3, 2000

# True hidden mean of every arm, indexed by arm number
# (arms 9, 8 and 7 are the best, in that order).
hidden_means = [
    1.2, 0.5, 0.8, 1.5, 0.1,
    0.9, 1.1, 2.0, 2.5, 3.0,
]

result = run_top_k_ucb(
    num_arms=n_arms,
    horizon=total_steps,
    k=k_to_find,
    true_means=hidden_means,
)

# Show the ground truth next to what the agent identified.
print(f"True Top {k_to_find} arms: {np.argsort(hidden_means)[-k_to_find:][::-1]}")
print(f"Identified Top {k_to_find} arms: {result}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment