Source code for afnio.utils.data.sampler

import random
from typing import Generic, Iterator, Optional, Sequence, Sized, TypeVar

T_co = TypeVar("T_co", covariant=True)


class Sampler(Generic[T_co]):
    r"""Abstract base from which every sampler derives.

    A concrete sampler must implement :meth:`__iter__`, yielding either single
    dataset indices or lists of indices (batches). It may additionally
    implement :meth:`__len__` to report how many items its iterator produces.

    Both :meth:`__init__` and :meth:`__iter__` raise
    :class:`NotImplementedError` on this base class, so it can be neither
    instantiated nor iterated directly.
    """

    def __iter__(self) -> Iterator[T_co]:
        # Iteration is the one thing every subclass has to supply.
        raise NotImplementedError()

    def __init__(self) -> None:
        # Unusable on the base class; subclasses provide their own constructor
        # and must not delegate to this one.
        raise NotImplementedError()
class SequentialSampler(Sampler[int]):
    r"""Yield the indices ``0, 1, ..., len(data_source) - 1`` in ascending order.

    The order is identical on every iteration.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __len__(self) -> int:
        """Return the current size of ``data_source``."""
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over ``range(len(data_source))``."""
        # The size is read when iteration starts, not when the sampler is built.
        size = len(self.data_source)
        return iter(range(size))
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly.

    Without replacement, indices are drawn from successive random permutations
    of the dataset. With replacement, indices are drawn independently and
    uniformly, and :attr:`num_samples` controls how many are drawn.

    Each call to :meth:`__iter__` creates its own random generator from
    ``seed``, so the process-wide :mod:`random` state is never modified. With a
    fixed ``seed`` every iteration produces the same sequence; with
    ``seed=None`` each iteration is seeded from system entropy.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if
            ``True``, default=``False``
        num_samples (int): number of samples to draw,
            default=`len(dataset)`.
        seed (int): A number to set the seed for the random draws.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the effective ``num_samples`` is not a positive ``int``
            (``bool`` values are rejected as well).
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.seed = seed

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, "
                f"but got replacement={self.replacement}"
            )

        # bool is a subclass of int; reject it explicitly, consistent with
        # WeightedRandomSampler's validation.
        requested = self.num_samples
        if (
            not isinstance(requested, int)
            or isinstance(requested, bool)
            or requested <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, "
                f"but got num_samples={self.num_samples}"
            )

    @property
    def num_samples(self) -> int:
        """Number of indices one iteration yields.

        Falls back to ``len(data_source)`` when no explicit value was given.
        """
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def _is_valid_random_state(self, state) -> bool:
        """Return ``True`` if ``state`` is a non-empty tuple."""
        return isinstance(state, tuple) and len(state) > 0

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` random indices into ``data_source``.

        Raises:
            ValueError: if ``data_source`` is empty while a positive number of
                samples is requested.
        """
        # Snapshot both sizes once so a single pass is internally consistent
        # even if the dataset is resized while iterating.
        n = len(self.data_source)
        total = self.num_samples

        if n == 0:
            if total == 0:
                return
            raise ValueError("Cannot draw samples from an empty data_source.")

        # Private generator: same seeded sequences as the module-level
        # functions, but without reseeding the global RNG as a side effect.
        rng = random.Random(self.seed)
        population = range(n)

        if self.replacement:
            # Draw in small chunks so indices are produced lazily.
            chunk = 32
            for _ in range(total // chunk):
                yield from rng.choices(population, k=chunk)
            yield from rng.choices(population, k=total % chunk)
        else:
            # Whole permutations first, then a partial one for the remainder.
            for _ in range(total // n):
                yield from rng.sample(population, n)
            yield from rng.sample(population, total % n)

    def __len__(self) -> int:
        return self.num_samples
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Each call to :meth:`__iter__` creates its own random generator from
    ``seed``, so the process-wide :mod:`random` state is never modified.

    Args:
        weights (sequence): a sequence of non-negative weights, not necessarily
            summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
            Indices with a weight of zero are never drawn without replacement.
        seed (int): A number to set the seed for the random draws.

    Raises:
        ValueError: if ``num_samples`` is not a positive ``int``, if
            ``replacement`` is not a ``bool``, if ``weights`` is empty, contains
            non-numeric or negative entries, or does not have a positive sum, or
            if ``replacement`` is ``False`` and ``num_samples`` exceeds the
            number of entries (or of strictly positive entries) in ``weights``.

    Example (outputs are illustrative; actual draws depend on the seed):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """  # noqa: E501

    weights: Sequence[float]
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        replacement: bool = True,
        seed: Optional[int] = None,
    ) -> None:
        if (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, "
                f"but got num_samples={num_samples}"
            )
        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, "
                f"but got replacement={replacement}"
            )
        if len(weights) == 0 or not all(isinstance(w, (float, int)) for w in weights):
            raise ValueError("Weights must be a non-empty sequence of numbers.")
        if any(w < 0 for w in weights):
            raise ValueError("Weights must be non-negative.")
        # `not (x > 0)` also rejects a NaN sum, which `x <= 0` would let through.
        if not sum(weights) > 0:
            raise ValueError("The sum of weights must be greater than zero.")
        if not replacement and num_samples > len(weights):
            raise ValueError(
                f"num_samples ({num_samples}) cannot be greater than "
                f"the population size ({len(weights)}) when replacement is False."
            )
        if not replacement:
            # Zero-weight entries can never be selected, so only strictly
            # positive weights count towards the drawable population.
            drawable = sum(1 for w in weights if w > 0)
            if num_samples > drawable:
                raise ValueError(
                    f"num_samples ({num_samples}) cannot be greater than "
                    f"the number of positive weights ({drawable}) "
                    f"when replacement is False."
                )

        self.weights = weights
        self.num_samples = num_samples
        self.replacement = replacement
        self.seed = seed

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` indices drawn according to ``weights``."""
        # Private generator: avoids reseeding the global RNG as a side effect.
        rng = random.Random(self.seed)

        if self.replacement:
            total_weight = sum(self.weights)
            probabilities = [w / total_weight for w in self.weights]
            yield from rng.choices(
                population=range(len(self.weights)),
                weights=probabilities,
                k=self.num_samples,
            )
            return

        # Weighted sampling without replacement via "exponential clocks":
        # give each index an Exp(weight) arrival time and take the earliest
        # ones. The index with the smallest time is i with probability
        # w_i / sum(w), and the same holds recursively for the rest, so the
        # sorted order matches sequential weighted draws without replacement.
        arrivals = sorted(
            (rng.expovariate(w), index)
            for index, w in enumerate(self.weights)
            if w > 0
        )
        for _, index in arrivals[: self.num_samples]:
            yield index

    def __len__(self) -> int:
        return self.num_samples