grain.samplers module
Sampler APIs.
List of Members
- class grain.samplers.IndexSampler(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)
Base index sampler for training on a single datasource.
This index sampler supports the following operations:
- Sharding of the dataset.
- Global shuffle of the dataset.
- Repetition of the dataset for a fixed number of epochs or indefinitely.
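To make the three operations concrete, here is a minimal pure-Python sketch of how a step index could be turned into a record key by combining a shard range, an optional per-epoch shuffle, and a finite or infinite epoch count. It is not Grain's implementation. The function `sketch_record_key` and its `shard_start`/`shard_end` arguments are hypothetical, and the per-epoch permutation is materialized with `random.Random` purely for readability.

```python
import random
from typing import Optional


def sketch_record_key(index: int, shard_start: int, shard_end: int,
                      shuffle: bool = False, num_epochs: Optional[int] = None,
                      seed: Optional[int] = None) -> int:
    """Hypothetical mapping from a step index to a record key (not Grain's code)."""
    shard_size = shard_end - shard_start
    if index < 0 or shard_size <= 0:
        raise IndexError(index)
    # Which pass over the shard we are in, and the position within that pass.
    epoch, position = divmod(index, shard_size)
    if num_epochs is not None and epoch >= num_epochs:
        raise IndexError(index)  # all requested epochs are exhausted
    keys = list(range(shard_start, shard_end))
    if shuffle:
        # A string seed built from (seed, epoch) gives each epoch its own
        # reproducible order.
        random.Random(f"{seed}:{epoch}").shuffle(keys)
    return keys[position]


# Five records, two epochs, no shuffle: the shard is simply read twice.
print([sketch_record_key(i, 0, 5, num_epochs=2) for i in range(10)])
# [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]

# With shuffle=True, each epoch is a reproducible permutation of keys 10..14.
print([sketch_record_key(i, 10, 15, shuffle=True, seed=42) for i in range(5)])
```

Materializing the permutation costs memory proportional to the shard size. A production sampler would more likely compute a shuffled position on the fly, which is why this should be read only as an illustration of the semantics.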
- Parameters:
num_records (int)
shard_options (ShardOptions)
shuffle (bool)
num_epochs (int | None)
seed (int | None)
- num_records
Number of records in the data source.
- shard_options
Sharding options for the dataset.
- shuffle
Whether to globally shuffle the dataset.
- num_epochs
Number of epochs to repeat the dataset. If None, the dataset will be repeated infinitely.
- seed
Seed for shuffling the dataset.
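The `shard_options` attribute determines which slice of the records a sampler sees. The sketch below shows one common even-split scheme in plain Python. It uses the field names that appear in the default `NoSharding(shard_index=0, shard_count=1, drop_remainder=False)`. The helper `sketch_shard_bounds` is hypothetical, and Grain's exact rule for distributing leftover records may differ.

```python
def sketch_shard_bounds(num_records: int, shard_index: int, shard_count: int,
                        drop_remainder: bool = False) -> tuple[int, int]:
    """Hypothetical even split of [0, num_records) into shard_count ranges."""
    if not 0 <= shard_index < shard_count:
        raise ValueError("shard_index must be in [0, shard_count)")
    per_shard, remainder = divmod(num_records, shard_count)
    start = per_shard * shard_index
    end = start + per_shard
    if not drop_remainder:
        # The first `remainder` shards each take one extra record, so no
        # record is lost; with drop_remainder=True the leftovers are skipped.
        start += min(shard_index, remainder)
        end += min(shard_index + 1, remainder)
    return start, end


print([sketch_shard_bounds(10, i, 3) for i in range(3)])
# [(0, 4), (4, 7), (7, 10)]
print([sketch_shard_bounds(10, i, 3, drop_remainder=True) for i in range(3)])
# [(0, 3), (3, 6), (6, 9)]  (record 9 is dropped)
```

With the default of one shard, the range is simply `(0, num_records)`. Dropping the remainder trades a few records for shards of identical length, which keeps workers in lockstep.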
- __init__(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)
- Parameters:
num_records (int)
shard_options (ShardOptions)
shuffle (bool)
num_epochs (int | None)
seed (int | None)
- class grain.samplers.Sampler(*args, **kwargs)
Interface for a PyGrain-compatible sampler.
- __getitem__(index)
Returns the RecordMetadata for a global index.
- Parameters:
index (int)
- Return type:
RecordMetadata
- __init__(*args, **kwargs)
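The interface requires only `__getitem__(index)` returning record metadata. A hypothetical sampler could satisfy it structurally as in the sketch below. `RecordMetadata` is a local stand-in whose fields (`index`, `record_key`) are assumptions, `SamplerLike` and `ReversedSampler` are illustrative names, and raising `IndexError` past the end is this sketch's own convention for ending the stream.

```python
import dataclasses
from typing import Optional, Protocol


@dataclasses.dataclass
class RecordMetadata:
    """Local stand-in for Grain's RecordMetadata (fields assumed)."""
    index: int
    record_key: Optional[int] = None


class SamplerLike(Protocol):
    """Hypothetical structural type mirroring the documented interface."""

    def __getitem__(self, index: int) -> RecordMetadata: ...


class ReversedSampler:
    """Hypothetical sampler that visits records last-to-first, for one epoch."""

    def __init__(self, num_records: int):
        self._num_records = num_records

    def __getitem__(self, index: int) -> RecordMetadata:
        if not 0 <= index < self._num_records:
            raise IndexError(index)  # this sketch's end-of-stream signal
        return RecordMetadata(index=index,
                              record_key=self._num_records - 1 - index)


sampler: SamplerLike = ReversedSampler(4)
print([sampler[i].record_key for i in range(4)])  # [3, 2, 1, 0]
```

Because the interface is structural, `ReversedSampler` does not need to inherit from anything. It only has to expose a compatible `__getitem__`.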
- class grain.samplers.SequentialSampler(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), seed=None)
Basic sampler implementation that provides records in order.
- Parameters:
num_records (int)
shard_options (ShardOptions)
seed (int | None)
- __init__(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), seed=None)
- Parameters:
num_records (int)
shard_options (ShardOptions)
seed (int | None)
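To build intuition for in-order sampling, the following sketch walks a hypothetical sequential sampler over one shard and reads it by increasing index until it is exhausted. `InOrderSampler`, `iterate_keys`, and the stand-in `RecordMetadata` are illustrative names, not Grain APIs. The sketch simply drops leftover records instead of modelling `drop_remainder`, and it ignores `seed`, whose role this reference does not describe.

```python
import dataclasses
from typing import Iterator, Optional


@dataclasses.dataclass
class RecordMetadata:
    """Local stand-in for Grain's RecordMetadata (fields assumed)."""
    index: int
    record_key: Optional[int] = None


class InOrderSampler:
    """Hypothetical sequential sampler over one contiguous shard."""

    def __init__(self, num_records: int, shard_index: int = 0,
                 shard_count: int = 1):
        per_shard = num_records // shard_count
        # Simplification: equal contiguous blocks; leftover records are dropped.
        self._start = per_shard * shard_index
        self._size = per_shard

    def __getitem__(self, index: int) -> RecordMetadata:
        if not 0 <= index < self._size:
            raise IndexError(index)
        return RecordMetadata(index=index, record_key=self._start + index)


def iterate_keys(sampler: InOrderSampler) -> Iterator[int]:
    """Yield record keys in index order until the sampler raises IndexError."""
    index = 0
    while True:
        try:
            yield sampler[index].record_key
        except IndexError:
            return
        index += 1


print(list(iterate_keys(InOrderSampler(10, shard_index=1, shard_count=2))))
# [5, 6, 7, 8, 9]
```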