Skip to content

Instantly share code, notes, and snippets.

@jingwangsg
Last active October 8, 2024 09:59
Show Gist options
  • Select an option

  • Save jingwangsg/694d3a3a078632e70a020a6e2aa20832 to your computer and use it in GitHub Desktop.

Select an option

Save jingwangsg/694d3a3a078632e70a020a6e2aa20832 to your computer and use it in GitHub Desktop.
ratio sampler
import torch
import random
class RatioSampler(torch.utils.data.Sampler):
    """Draw a fixed fraction of a dataset per epoch, cycling through it.

    Each call to ``__iter__`` returns ``int(len(dataset) * ratio)`` indices
    (at least one when the dataset is non-empty, zero when it is empty).
    Successive epochs continue from where the previous one stopped in the
    current index permutation. So with ``ratio < 1``, consecutive epochs walk
    through different parts of the dataset until it is exhausted. At that
    point the permutation is reshuffled (if ``shuffle``) and traversal
    restarts from the beginning. An epoch that crosses the end of the
    permutation wraps around, so with ``ratio > 1`` an epoch contains
    repeated indices.

    Args:
        dataset: Object supporting ``len()``; only its length is used.
        ratio: Fraction of the dataset to draw per epoch.
        shuffle: Shuffle the index permutation initially and on every
            wraparound.
        seed: Optional seed for a private ``random.Random`` so the order is
            reproducible without touching global RNG state. If ``None``
            (default), the global ``random`` module is used.
    """

    def __init__(self, dataset, ratio: float = 1.0, shuffle: bool = True, seed=None):
        self.dataset_size = len(dataset)
        self.ratio = ratio
        self.shuffle = shuffle
        # Store ``None`` rather than the ``random`` module itself so the
        # sampler stays picklable / deep-copyable.
        self._rng = None if seed is None else random.Random(seed)
        # Samples per epoch: at least one when there is data. An empty
        # dataset must yield nothing; forcing 1 here would make __iter__
        # spin forever trying to draw from zero indices.
        if self.dataset_size == 0:
            self.num_samples_per_episode = 0
        else:
            self.num_samples_per_episode = max(1, int(self.dataset_size * self.ratio))
        self.indices = list(range(self.dataset_size))
        if self.shuffle:
            self._shuffle_indices()
        # Position in ``self.indices`` where the next epoch starts.
        self.current_index = 0

    def _shuffle_indices(self):
        """Shuffle ``self.indices`` in place with the configured RNG."""
        rng = random if self._rng is None else self._rng
        rng.shuffle(self.indices)

    def __iter__(self):
        """Return an iterator over the next epoch's indices.

        The indices are computed eagerly, so calling ``iter()`` advances the
        sampler's position immediately, whether or not the result is
        consumed.
        """
        episode_indices = []
        while len(episode_indices) < self.num_samples_per_episode:
            num_needed = self.num_samples_per_episode - len(episode_indices)
            # Take what we need, but never past the end of the permutation.
            stop = min(self.dataset_size, self.current_index + num_needed)
            episode_indices.extend(self.indices[self.current_index:stop])
            self.current_index = stop
            # Permutation exhausted: restart (reshuffled if requested).
            if self.current_index >= self.dataset_size:
                self.current_index = 0
                if self.shuffle:
                    self._shuffle_indices()
        return iter(episode_indices)

    def __len__(self) -> int:
        """Return the number of indices produced per epoch."""
        return self.num_samples_per_episode
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment