grain.samplers module

Sampler APIs.

List of Members

class grain.samplers.IndexSampler(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)

Base index sampler for training on a single data source.

This index sampler supports the following operations:
  • Sharding of the dataset.

  • Global shuffle of the dataset.

  • Repeating the dataset for a fixed number of epochs or infinitely.

Parameters:
  • num_records (int)

  • shard_options (ShardOptions)

  • shuffle (bool)

  • num_epochs (int | None)

  • seed (int | None)

num_records

Number of records in the data source.

shard_options

Sharding options for the dataset.

shuffle

Whether to globally shuffle the dataset.

num_epochs

Number of epochs to repeat the dataset. If None, the dataset will be repeated infinitely.

seed

Seed for shuffling the dataset.

__init__(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)

Parameters:
  • num_records (int)

  • shard_options (ShardOptions)

  • shuffle (bool)

  • num_epochs (int | None)

  • seed (int | None)
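
The three operations listed for IndexSampler (sharding, global shuffle, and repeating for a fixed or unlimited number of epochs) can be pictured with a small pure-Python model. This is only a sketch and not Grain's implementation: `ToyIndexSampler` and `ToyRecordMetadata` are hypothetical names, the contiguous equal-size shard split is an assumption, and so are the per-epoch reshuffle and the requirement of a seed when shuffling.

```python
import random
from typing import NamedTuple, Optional


class ToyRecordMetadata(NamedTuple):
    """Hypothetical stand-in for RecordMetadata."""
    index: int       # position in the sampled stream
    record_key: int  # position of the record in the data source


class ToyIndexSampler:
    """Toy model of sharding, shuffling and repeating (not Grain's code)."""

    def __init__(self, num_records: int, shard_index: int = 0,
                 shard_count: int = 1, shuffle: bool = False,
                 num_epochs: Optional[int] = None,
                 seed: Optional[int] = None):
        if shuffle and seed is None:
            raise ValueError("this sketch requires a seed when shuffle=True")
        # Assumption: each shard owns a contiguous, equally sized slice
        # and any leftover records are dropped.
        self._per_shard = num_records // shard_count
        self._start = shard_index * self._per_shard
        self._shuffle = shuffle
        self._num_epochs = num_epochs
        self._seed = seed

    def __getitem__(self, index: int) -> ToyRecordMetadata:
        epoch, offset = divmod(index, self._per_shard)
        # num_epochs=None means the stream never ends.
        if self._num_epochs is not None and epoch >= self._num_epochs:
            raise IndexError(f"index {index} is past the last epoch")
        order = list(range(self._start, self._start + self._per_shard))
        if self._shuffle:
            # A different but reproducible permutation for every epoch.
            random.Random(self._seed * 1_000_003 + epoch).shuffle(order)
        return ToyRecordMetadata(index=index, record_key=order[offset])


sampler = ToyIndexSampler(num_records=10, shard_index=1, shard_count=2,
                          shuffle=True, num_epochs=2, seed=42)
first_epoch = [sampler[i].record_key for i in range(5)]
second_epoch = [sampler[i].record_key for i in range(5, 10)]
```

In this model every epoch visits each record of the shard (keys 5 to 9 here) exactly once, and indexing past the last epoch raises `IndexError`. With `num_epochs=None` the same lookup keeps succeeding for any non-negative index.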

class grain.samplers.Sampler(*args, **kwargs)

Interface for PyGrain-compatible samplers.

__getitem__(index)

Returns the RecordMetadata for a global index.

Parameters:
  • index (int)

Return type:

RecordMetadata
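
Since the interface consists of a single `__getitem__` lookup from a global index to a RecordMetadata, a custom sampler can be sketched as any object with that method. The example below is an illustration only: `ToyRecordMetadata`, its two fields, `SamplerLike`, and `ReverseSampler` are hypothetical names used here so that the snippet runs without Grain installed.

```python
from typing import NamedTuple, Protocol


class ToyRecordMetadata(NamedTuple):
    """Hypothetical stand-in for RecordMetadata."""
    index: int       # global index that was requested
    record_key: int  # record to read from the data source


class SamplerLike(Protocol):
    """Structural type mirroring the documented __getitem__ signature."""

    def __getitem__(self, index: int) -> ToyRecordMetadata: ...


class ReverseSampler:
    """Hypothetical sampler that walks the data source back to front."""

    def __init__(self, num_records: int):
        self._num_records = num_records

    def __getitem__(self, index: int) -> ToyRecordMetadata:
        if not 0 <= index < self._num_records:
            raise IndexError(index)
        return ToyRecordMetadata(index=index,
                                 record_key=self._num_records - 1 - index)


def record_keys(sampler: SamplerLike, count: int) -> list[int]:
    """Collects the record keys for the first `count` global indices."""
    return [sampler[i].record_key for i in range(count)]


keys = record_keys(ReverseSampler(4), 4)
```

Because the lookup is a pure function of the index, the same index always yields the same record in this sketch, which is what makes the sampled stream reproducible.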

__init__(*args, **kwargs)

class grain.samplers.SequentialSampler(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), seed=None)

Basic sampler implementation that provides records in order.

Parameters:
  • num_records (int)

  • shard_options (ShardOptions)

  • seed (int | None)

__init__(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), seed=None)

Parameters:
  • num_records (int)

  • shard_options (ShardOptions)

  • seed (int | None)
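
To make "records in order" concrete when sharding is enabled, the sketch below shows one plausible way the `shard_index`, `shard_count`, and `drop_remainder` fields could divide an ordered range of record keys. The split rule is an assumption made for illustration (contiguous slices, with extra records going to the first shards), and `shard_range` is a hypothetical helper, not part of Grain.

```python
def shard_range(num_records: int, shard_index: int, shard_count: int,
                drop_remainder: bool) -> range:
    """Returns the ordered record keys owned by one shard (toy rule)."""
    base, extra = divmod(num_records, shard_count)
    if drop_remainder:
        # Every shard gets exactly `base` records; the rest are unused.
        start = shard_index * base
        return range(start, start + base)
    # Assumption: the first `extra` shards each take one more record.
    start = shard_index * base + min(shard_index, extra)
    size = base + (1 if shard_index < extra else 0)
    return range(start, start + size)


# Ten records split over three shards, with and without the remainder.
kept = [list(shard_range(10, i, 3, drop_remainder=False)) for i in range(3)]
dropped = [list(shard_range(10, i, 3, drop_remainder=True)) for i in range(3)]
```

Under this rule `kept` covers all ten records with shard sizes 4, 3, and 3, while `dropped` gives every shard three records and leaves record 9 unread.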