grain DataLoader#
DataLoader is responsible for loading and transforming input data.
List of Members#
- grain.load(source, *, num_epochs=None, shuffle=False, seed=None, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), transformations=(), batch_size=None, drop_remainder=False, worker_count=0, read_options=None)#
Convenience method for simple pipelines on top of a data source.
- Parameters:
source (RandomAccessDataSource) – Data source to load from. This can be one of the file data sources provided by Grain, a TFDS data source (tfds.data_source(…)) or your custom data source.
num_epochs (int | None) – See IndexSampler.
shuffle (bool) – See IndexSampler.
seed (int | None) – See IndexSampler.
shard_options (ShardOptions) – See IndexSampler.
transformations (Sequence[Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex]) – List of local (stateless) transformations to apply.
batch_size (int | None) – Optional batch size. If provided, BatchOperation() is applied.
drop_remainder (bool) – Whether to drop partial batches.
worker_count (int | None) – Number of child processes across which the transformations are parallelized. Zero means processing runs in the same process.
read_options (ReadOptions | None) – Read options for the data loader. See ReadOptions.
- Returns:
DataLoader for this dataset.
- Return type:
DataLoader
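A minimal sketch of a grain.load pipeline. It assumes that a plain Python sequence satisfies RandomAccessDataSource (it supports __len__ and __getitem__) and that the Map base class from the union above is importable as grain.transforms.Map; treat both as assumptions, not confirmed API:

```python
import grain

# Hypothetical element-wise transform; assumes the Map base class is
# exposed as grain.transforms.Map with a map() method to override.
class Square(grain.transforms.Map):
  def map(self, element):
    return element * element

# A plain Python list acts as a RandomAccessDataSource here: it
# supports __len__ and __getitem__.
loader = grain.load(
    list(range(100)),
    num_epochs=1,
    shuffle=True,
    seed=42,
    transformations=[Square()],
    batch_size=8,
    drop_remainder=True,
    worker_count=0,  # run transformations in the main process
)

for batch in loader:
  ...  # with batch_size set, each batch is a stacked array of 8 elements
```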
- class grain.DataLoader(*, data_source, sampler, operations=(), worker_count=0, worker_buffer_size=1, shard_options=None, read_options=None, enable_profiling=False)#
DataLoader loads and transforms input data.
- Parameters:
data_source (dataset_base.RandomAccessDataSource)
sampler (Sampler)
operations (Sequence[transforms.Transformation | Operation])
worker_count (int | None)
worker_buffer_size (int)
shard_options (sharding.ShardOptions | None)
read_options (options.ReadOptions | None)
enable_profiling (bool)
- __init__(*, data_source, sampler, operations=(), worker_count=0, worker_buffer_size=1, shard_options=None, read_options=None, enable_profiling=False)#
Loads and transforms input data.
- Parameters:
data_source (RandomAccessDataSource) – Responsible for retrieving individual records based on their indices.
sampler (Sampler) – Sampler is responsible for providing the index of the next record to read and transform.
operations (Sequence[Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex | Operation]) – Sequence of operations (e.g. Map, Filter) applied to the data.
worker_count (int | None) – Number of child processes across which the transformations are parallelized. Zero means processing runs in the same process. None lets the Python backend choose the value.
worker_buffer_size (int) – Count of output batches to produce in advance per worker. This ensures batches are ready when the consumer requests them.
shard_options (ShardOptions | None) – Options for how data should be sharded when using multiple machines (~ JAX processes) and data parallelism.
read_options (ReadOptions | None) – Options to use for reading. See ReadOptions.
enable_profiling (bool) – If True, profiling info is logged. Note that it currently only supports worker_count >= 1.
- __iter__()#
- Return type:
DataLoaderIterator
- property multiprocessing_options: MultiprocessingOptions#
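As a sketch, a DataLoader is typically constructed with an IndexSampler (referenced by the parameters above). The exact module paths and constructor signatures for IndexSampler, NoSharding, and Batch below are assumptions mirroring the grain.load parameters, not confirmed API:

```python
import grain

data_source = list(range(1_000))  # any RandomAccessDataSource

# Assumed location and signature for IndexSampler, mirroring the
# num_epochs/shuffle/seed/shard_options parameters of grain.load.
sampler = grain.samplers.IndexSampler(
    num_records=len(data_source),
    num_epochs=2,
    shuffle=True,
    seed=0,
    shard_options=grain.sharding.NoSharding(),
)

loader = grain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    operations=[grain.transforms.Batch(batch_size=32)],
    worker_count=2,        # two child processes
    worker_buffer_size=1,  # one prefetched batch per worker
)

for batch in loader:
  ...
```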
- class grain.DataLoaderIterator(data_loader, state, validate_state=True)#
DataLoader iterator providing get/set state functionality.
This is the only iterator we expose to users. It wraps the underlying MultipleProcessIterator. To set state, it recreates the underlying iterator with the new state.
Checkpointing for DataLoaderIterator: DataLoaderIterator uses GrainPool, which distributes RecordMetadata from produced records among worker processes in a round-robin fashion. In general, some workers can process more elements than others at a given training step. Checkpointing works as follows:
- With each output batch produced, GrainPool emits the worker_index of the worker that processed the batch.
- DataLoaderIterator keeps track of the last_seen_index at each worker.
- When restoring from a state, DataLoaderIterator determines the minimum last_seen_index among all workers and which worker processed it, and GrainPool is instructed to start distributing indices to the next worker.
- Parameters:
data_loader (DataLoader)
state (_IteratorState | None)
validate_state (bool)
- __init__(data_loader, state, validate_state=True)#
- Parameters:
data_loader (DataLoader)
state (dict[str, Any] | None)
validate_state (bool)
- __iter__()#
- Return type:
DataLoaderIterator
- __next__()#
Return the next item from the iterator. When exhausted, raises StopIteration.
- Return type:
_T
- get_state()#
- Return type:
bytes
- async load(directory)#
Loads the iterator state from a directory.
The state may be loaded and set in a background thread. The main thread should not alter the state content while the load is in progress.
- Parameters:
directory (Path) – The directory to load the state from.
- Returns:
A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.
- Return type:
Awaitable[None]
- async save(directory)#
Saves the iterator state to a directory.
The current state (get_state) is used for saving, so any updates to the state after returning from this method will not affect the saved checkpoint.
- Parameters:
directory (PathAwaitingCreation) – A path in the process of being created. Must call await_creation before accessing the physical path.
- Returns:
A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.
- Return type:
Awaitable[None]
- set_state(state)#
Sets the state for the underlying iterator.
Note that the state is an implementation detail and can change in the future.
- Parameters:
state (bytes) – State to restore the underlying iterator to.
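A short sketch of manual checkpointing with get_state / set_state, assuming a loader like the one constructed above (the Orbax-driven save / load methods are the higher-level path):

```python
import numpy as np

it = iter(loader)  # DataLoaderIterator

_ = next(it)
state = it.get_state()  # opaque bytes; the format is an implementation detail

skipped = next(it)   # advance one step past the snapshot
it.set_state(state)  # rewind the underlying iterator to the snapshot

resumed = next(it)
assert np.array_equal(resumed, skipped)  # iteration resumes at the snapshot
```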
- class grain.Record(metadata: grain._src.python.record.RecordMetadata, data: T)#
- Parameters:
metadata (RecordMetadata)
data (T)
- data: T#
- metadata: RecordMetadata#
- class grain.RecordMetadata(index, record_key=None, rng=None)#
RecordMetadata contains metadata about individual records.
Metadata can be emitted by the sampler to indicate which record to read next. It is also used to carry information about records as they flow through the pipeline from one operation to the next.
- Parameters:
index (int)
record_key (int | None)
rng (numpy.random.Generator | None)
- index: int#
- record_key: int | None#
- remove_record_key()#
Removes the record key if it exists.
- rng: numpy.random.Generator | None#
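A small illustration of the record types above. Whether remove_record_key mutates in place or returns new metadata is treated as an assumption here; the sketch assumes it returns metadata with record_key cleared:

```python
import numpy as np
import grain

# Metadata as a sampler might emit it: position in iteration order,
# the key of the source record, and an optional per-record RNG.
meta = grain.RecordMetadata(index=0, record_key=17,
                            rng=np.random.default_rng(42))
record = grain.Record(metadata=meta, data={"x": np.zeros(3)})

# Assumption: remove_record_key() returns metadata with record_key=None.
meta = meta.remove_record_key()
```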