grain.experimental module#
Experimental Grain APIs.
List of Members#
- grain.experimental.FlatMapTransform#
alias of FlatMap
- class grain.experimental.DatasetOptions(*, filter_warn_threshold_ratio=_Default(value=0.9), filter_raise_threshold_ratio=_Default(value=None), execution_tracking_mode=_Default(value=<ExecutionTrackingMode.DISABLED: 1>), min_shm_size=_Default(value=0))#
Holds options used by dataset transformations.
- Parameters:
filter_warn_threshold_ratio (float | None | _Default[float])
filter_raise_threshold_ratio (float | None | _Default[None])
execution_tracking_mode (ExecutionTrackingMode | _Default[ExecutionTrackingMode])
min_shm_size (int | _Default[int])
- filter_warn_threshold_ratio#
If the ratio of filtered-out elements in a window exceeds this threshold, a warning is issued. A value of None disables the check. The ratio is calculated over non-overlapping windows of 1000 elements. For instance, with filter_warn_threshold_ratio=0.9 and 901 of the first 1000 elements (or of elements 1000…2000) filtered out, a warning is issued.
- Type:
float | None | grain._src.python.dataset.base._Default[float]
- filter_raise_threshold_ratio#
If the ratio of filtered-out elements in a window exceeds this threshold, an exception is raised. A value of None disables the check.
- Type:
float | None | grain._src.python.dataset.base._Default[None]
- execution_tracking_mode#
Controls the collection of execution statistics, such as the total processing time taken by each transformation and the number of elements produced. If DISABLED, no statistics are collected. If STAGE_TIMING, the time it takes to process each transformation is collected. See ExecutionTrackingMode for more details.
- Type:
grain._src.python.dataset.base.ExecutionTrackingMode | grain._src.python.dataset.base._Default[grain._src.python.dataset.base.ExecutionTrackingMode]
- min_shm_size#
The minimum size below which NumPy arrays are copied between processes rather than passed via shared memory. For smaller arrays, the overhead of using shared memory can be higher than the cost of copying.
- Type:
int | grain._src.python.dataset.base._Default[int]
- merge(other)#
Merges these options with the other.
Explicitly set options in self take precedence over options in other.
- Parameters:
other (DatasetOptions | None) – Options to merge.
- Returns:
Merged options.
- Return type:
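Example usage (a minimal sketch of the merge precedence described above; the option values are illustrative):
```python
import grain

opts_a = grain.experimental.DatasetOptions(filter_warn_threshold_ratio=0.7)
opts_b = grain.experimental.DatasetOptions(
    filter_warn_threshold_ratio=0.9,
    filter_raise_threshold_ratio=0.99,
)
merged = opts_a.merge(opts_b)
# Explicitly set options in `opts_a` take precedence; unset ones fall back to `opts_b`:
#   merged.filter_warn_threshold_ratio == 0.7
#   merged.filter_raise_threshold_ratio == 0.99
```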
- class grain.experimental.ExecutionTrackingMode(*values)#
Represents different modes for tracking execution statistics.
- Available modes:
- DISABLED:
No execution statistics are measured. This mode is the default.
- STAGE_TIMING:
The time taken for each transformation stage to execute is measured and recorded. This recorded time reflects the duration spent within the specific transformation to return an element, excluding the time spent in any parent transformations. The recorded time can be retrieved using grain.experimental.get_execution_summary method.
- grain.experimental.apply_transformations(ds, transformations)#
Applies transformations to a dataset.
DEPRECATED: Use ds.apply(transformations) instead.
- Parameters:
ds (_ConsistentDatasetType) – MapDataset or IterDataset to apply the transformations to.
transformations (Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex | Sequence[Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex]) – one or more transformations to apply.
- Returns:
Dataset of the same type with transformations applied.
- Return type:
_ConsistentDatasetType
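Example of the recommended replacement (a minimal sketch; it assumes the transform base classes are exposed as grain.transforms, e.g. grain.transforms.Filter with an abstract filter method):
```python
import grain


class KeepEven(grain.transforms.Filter):
  """Keeps only even elements."""

  def filter(self, element: int) -> bool:
    return element % 2 == 0


ds = grain.MapDataset.range(10)
# Preferred over grain.experimental.apply_transformations(ds, [KeepEven()]).
ds = ds.apply([KeepEven()])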
- class grain.experimental.ElasticIterator(ds, global_batch_size, shard_options, *, read_options=ReadOptions(num_threads=16, prefetch_buffer_size=500), multiprocessing_options=None)#
Iterator supporting recovery from a checkpoint after changes in sharding.
The input dataset is expected to be unbatched and unsharded. In order to provide the elasticity guarantee, this iterator performs both batching and sharding itself. The iterator supports elastic re-configuration by having each shard produce exactly the same checkpoint (while producing different data) as long as the shards are advanced the same number of steps.
The state of any single shard can be used to restore the state of all shards after changes in sharding and global batch size.
This iterator explicitly disallows many-to-one transformations without a fixed ratio, like filter and generic IterDataset transformations.
- Parameters:
ds (MapDataset)
global_batch_size (int)
shard_options (ShardOptions)
read_options (ReadOptions)
multiprocessing_options (MultiprocessingOptions | None)
- __init__(ds, global_batch_size, shard_options, *, read_options=ReadOptions(num_threads=16, prefetch_buffer_size=500), multiprocessing_options=None)#
- Parameters:
ds (MapDataset)
global_batch_size (int)
shard_options (ShardOptions)
read_options (ReadOptions)
multiprocessing_options (MultiprocessingOptions | None)
- __next__()#
Return the next item from the iterator. When exhausted, raise StopIteration
- Return type:
Any
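Example usage (a minimal sketch; the sharding values are illustrative, and it assumes the iterator exposes get_state() for checkpointing):
```python
import grain

ds = grain.MapDataset.range(1024).map(lambda x: {"x": x})
it = grain.experimental.ElasticIterator(
    ds,
    global_batch_size=8,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=2),
)
batch = next(it)     # per-shard batch of global_batch_size / shard_count elements
state = it.get_state()  # identical on every shard after the same number of steps
```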
- class grain.experimental.WithOptionsIterDataset(parent, options)#
Applies options to transformations in the pipeline.
The options will apply to all transformations in the pipeline (before and after WithOptionsIterDataset). The options can be set multiple times in the pipeline, in which case they are merged. If the same option is set multiple times, the latest value takes precedence.
Example:
```python
ds = MapDataset.range(5).to_iter_dataset()
ds = WithOptionsIterDataset(
    ds,
    DatasetOptions(
        filter_warn_threshold_ratio=0.6,
        filter_raise_threshold_ratio=0.8,
    ),
)
ds = ds.filter(...)
ds = WithOptionsIterDataset(
    ds, DatasetOptions(filter_warn_threshold_ratio=0.7)
)
ds = ds.filter(...)
```
In this case, the options will be:
```
filter_warn_threshold_ratio=0.7
filter_raise_threshold_ratio=0.8
```
- Parameters:
parent (IterDataset[T])
options (base.DatasetOptions)
- __init__(parent, options)#
- Parameters:
parent (IterDataset[T])
options (DatasetOptions)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- class grain.experimental.ParquetIterDataset(path, **read_kwargs)#
An IterDataset for a parquet format file.
- Parameters:
path (str)
- __init__(path, **read_kwargs)#
Initializes ParquetIterDataset.
- Parameters:
path (str) – A path to a parquet format file.
**read_kwargs – Keyword arguments to pass to pyarrow.parquet.ParquetFile.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
_ParquetDatasetIterator[T]
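Example usage (a minimal sketch; the file path is hypothetical):
```python
import grain

ds = grain.experimental.ParquetIterDataset("/tmp/records.parquet")
for record in ds:
  ...
```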
- class grain.experimental.TFRecordIterDataset(path)#
An IterDataset for a TFRecord format file.
- Parameters:
path (str)
- __init__(path)#
- Parameters:
path (str)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- grain.experimental.batch_and_pad(values, *, batch_size, pad_value=0)#
Batches the given values and, if needed, pads the batch to the given size.
Can be passed to ds.batch as batch_fn in order to pad the final partial batch instead of dropping it.
Example usage:
```python
ds = grain.MapDataset.range(1, 5)
batch_size = 3
batch_fn = functools.partial(
    grain.experimental.batch_and_pad, batch_size=batch_size
)
ds = ds.batch(batch_size, batch_fn=batch_fn)
list(ds) == [np.array([1, 2, 3]), np.array([4, 0, 0])]
```
- Parameters:
values (Sequence[T]) – The values to batch.
batch_size (int) – Target batch size. If the number of values is smaller than this, the batch is padded with pad_value to the given size.
pad_value (Any) – The value to use for padding.
- Returns:
A batch of values with a new batch dimension at the front.
- Return type:
T
- class grain.experimental.CacheIterDataset(parent)#
Caches elements of an IterDataset in memory.
- Parameters:
parent (dataset.IterDataset[T])
- __init__(parent)#
Caches elements of an IterDataset in memory.
- Parameters:
parent (IterDataset[T]) – The parent IterDataset whose elements are to be cached.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- class grain.experimental.FlatMapMapDataset(parent, transform)#
Flat map for one-to-many split.
- Parameters:
parent (MapDataset)
transform (FlatMap)
- __getitem__(index)#
Returns the element for the index or None if missing.
- __init__(parent, transform)#
- Parameters:
parent (MapDataset)
transform (FlatMap)
- class grain.experimental.FlatMapIterDataset(parent, transform)#
Flat map for one-to-many split.
- Parameters:
parent (IterDataset)
transform (FlatMap)
- __init__(parent, transform)#
- Parameters:
parent (IterDataset)
transform (FlatMap)
- __iter__()#
Returns an iterator for this dataset.
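Example usage (a minimal sketch; it assumes grain.experimental.FlatMapTransform subclasses define a max_fan_out attribute and a flat_map method, as in current Grain releases):
```python
import grain


class SplitInTwo(grain.experimental.FlatMapTransform):
  """Maps each element to two output elements."""

  max_fan_out = 2

  def flat_map(self, element):
    return [element, element + 100]


ds = grain.MapDataset.range(3).to_iter_dataset()
ds = grain.experimental.FlatMapIterDataset(ds, SplitInTwo())
# Expected to yield 0, 100, 1, 101, 2, 102 under these assumptions.
```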
- class grain.experimental.InterleaveIterDataset(datasets, *, cycle_length, num_make_iter_threads=1, make_iter_buffer_size=1, iter_buffer_size=1)#
Interleaves the given sequence of datasets.
The sequence can be a MapDataset.
Concurrently processes at most cycle_length iterators and interleaves their elements. If cycle_length is larger than the number of datasets, then the behavior is similar to mixing the datasets with equal proportions. If cycle_length is 1, the datasets are chained.
Can be used with mp_prefetch to parallelize reading from sources that do not support random access and are implemented as IterDataset:
```python
def make_source(filename: str) -> grain.IterDataset:
  ...

ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
ds = grain.experimental.InterleaveIterDataset(ds, cycle_length=4)
ds = ...
ds = ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=2))
for element in ds:
  ...
```
Element spec inference assumes that input datasets have the same element spec.
- Parameters:
datasets (Sequence[IterDataset[T] | MapDataset[T]])
cycle_length (int)
num_make_iter_threads (int)
make_iter_buffer_size (int)
iter_buffer_size (int)
- __init__(datasets, *, cycle_length, num_make_iter_threads=1, make_iter_buffer_size=1, iter_buffer_size=1)#
Initializes the InterleaveIterDataset.
- Parameters:
datasets (Sequence[IterDataset[T] | MapDataset[T]]) – A sequence of IterDataset or MapDataset objects, or a MapDataset of datasets to be interleaved.
cycle_length (int) – The maximum number of input datasets from which elements will be processed concurrently. If cycle_length is greater than the total number of datasets, all available datasets will be interleaved. If cycle_length is 1, the datasets will be processed sequentially.
num_make_iter_threads (int) – Optional. The number of threads used to asynchronously create new iterators and start prefetching elements (for each iterator) from the underlying datasets. Defaults to 1, meaning a single background thread creates iterators asynchronously.
make_iter_buffer_size (int) – Optional. The number of iterators to create and keep ready in advance in each preparation thread. This reduces latency by ensuring iterators are available when needed. Defaults to 1, meaning the next iterator is always kept ready in advance.
iter_buffer_size (int) – Optional. The number of elements to prefetch from each iterator. Default value is 1.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- class grain.experimental.LimitIterDataset(parent, count)#
Limits the number of elements in the dataset.
Example usage:
```python
list(LimitIterDataset(MapDataset.range(5).to_iter_dataset(), 2)) == [0, 1]
```
- Parameters:
parent (IterDataset[T])
count (int)
- parent#
The dataset to limit.
- count#
The maximum number of elements to include in the dataset.
- __init__(parent, count)#
Initializes the limit dataset.
- Parameters:
parent (IterDataset[T])
count (int)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
_LimitDatasetIterator[T]
- class grain.experimental.RngPool(seed)#
RNG pool.
- Parameters:
seed (int)
- acquire_rng(index, *, op_seed=0)#
Acquire RNG.
- Parameters:
index (int)
op_seed (int)
- Return type:
numpy.random.Generator
- class grain.experimental.FirstFitPackIterDataset(parent, *, length_struct, num_packing_bins, seed=0, shuffle_bins=True, shuffle_bins_group_by_feature=None, meta_features=(), pack_alignment_struct=None, padding_struct=None, max_sequences_per_bin=None)#
Implements first-fit packing of sequences.
Compared to concat-then-split, packing avoids splitting sequences by padding instead. A larger number of packing bins reduces the amount of padding. If the number of bins is large, this can cause epoch leakage (data points from multiple epochs getting packed together).
This uses a simple first-fit packing algorithm that:
1. Creates N bins.
2. Adds elements (in the order coming from the parent) to the first bin that has enough space.
3. Once an element doesn't fit, emits all N bins as elements.
4. (optional) Shuffles bins.
5. Loops back to 1 and starts with the element that didn't fit.
- Parameters:
parent (IterDataset)
length_struct (Any)
num_packing_bins (int)
seed (int)
shuffle_bins (bool)
shuffle_bins_group_by_feature (str | None)
meta_features (Sequence[str])
pack_alignment_struct (Any)
padding_struct (Any)
max_sequences_per_bin (int | None)
- __init__(parent, *, length_struct, num_packing_bins, seed=0, shuffle_bins=True, shuffle_bins_group_by_feature=None, meta_features=(), pack_alignment_struct=None, padding_struct=None, max_sequences_per_bin=None)#
Creates a dataset that packs sequences using the first-fit strategy.
- Parameters:
parent (IterDataset) – Parent dataset with variable length sequences.
length_struct (Any) – Target sequence length for each feature.
num_packing_bins (int) – Number of bins to pack sequences into.
seed (int) – Random seed for shuffling bins.
shuffle_bins (bool) – Whether to shuffle bins after packing.
shuffle_bins_group_by_feature (str | None) – Feature to group by for shuffling.
meta_features (Sequence[str]) – Meta features that do not need packing logic.
pack_alignment_struct (Any) – Optional per-feature alignment values.
padding_struct (Any) – Optional per-feature padding values.
max_sequences_per_bin (int | None) – Optional maximum number of input sequences that can be packed into a bin.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
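Example usage (a minimal sketch with illustrative sequence lengths):
```python
import numpy as np
import grain

# Variable-length token sequences, packed to a fixed length of 8 using 4 bins.
ds = grain.MapDataset.source(
    [{"tokens": np.arange(n)} for n in (5, 3, 6, 2, 7)]
).to_iter_dataset()
ds = grain.experimental.FirstFitPackIterDataset(
    ds,
    length_struct={"tokens": 8},
    num_packing_bins=4,
)
for packed in ds:
  ...  # each packed element holds sequences packed/padded to length 8
```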
- class grain.experimental.BestFitPackIterDataset(parent, *, length_struct, num_packing_bins, seed=0, shuffle_bins=True, shuffle_bins_group_by_feature=None, meta_features=(), pack_alignment_struct=None, padding_struct=None, max_sequences_per_bin=None)#
Implements best-fit packing of sequences.
The best-fit algorithm attempts to pack elements more efficiently than first-fit by placing each new element into the bin that will leave the smallest remaining space (i.e., the “tightest” fit). This can lead to less overall padding compared to the simpler first-fit approach, especially when element sizes vary significantly.
- Parameters:
parent (IterDataset)
length_struct (Any)
num_packing_bins (int)
seed (int)
shuffle_bins (bool)
shuffle_bins_group_by_feature (str | None)
meta_features (Sequence[str])
pack_alignment_struct (Any)
padding_struct (Any)
max_sequences_per_bin (int | None)
- __init__(parent, *, length_struct, num_packing_bins, seed=0, shuffle_bins=True, shuffle_bins_group_by_feature=None, meta_features=(), pack_alignment_struct=None, padding_struct=None, max_sequences_per_bin=None)#
Creates a dataset that packs sequences using the best-fit strategy.
- Parameters:
parent (IterDataset) – Parent dataset with variable length sequences.
length_struct (Any) – Target sequence length for each feature.
num_packing_bins (int) – Number of bins to pack sequences into.
seed (int) – Random seed for shuffling bins.
shuffle_bins (bool) – Whether to shuffle bins after packing.
shuffle_bins_group_by_feature (str | None) – Feature to group by for shuffling.
meta_features (Sequence[str]) – Meta features that do not need packing logic.
pack_alignment_struct (Any) – Optional per-feature alignment values.
padding_struct (Any) – Optional per-feature padding values.
max_sequences_per_bin (int | None) – Optional maximum number of input sequences that can be packed into a bin.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- class grain.experimental.BOSHandling(*values)#
The BOS handling done inside a packing algorithm.
- class grain.experimental.ConcatThenSplitIterDataset(parent, *, length_struct, meta_features=(), split_full_length_features=True, bos_handling=BOSHandling.DO_NOTHING, bos_features=(), bos_token_id=None)#
Implements concat-then-split packing for sequence features.
This assumes that elements of the parent dataset are unnested dictionaries and entries are either scalars or NumPy arrays. The first dimension is considered the sequence dimension and its size may vary between elements. All other dimensions must be the same size for all elements. Scalars are treated as 1-dimensional arrays of size 1.
On a high level this concatenates the underlying dataset and then splits it at target sequence lengths intervals. This is well defined for the case of a single feature. For multiple features we start with an empty buffer and concatenate elements until at least one feature is fully packed. As an optimization, elements from the parent dataset that are already fully packed are passed through in priority. When the buffer contains enough elements to fill at least one feature to its target sequence length, we pack the buffer. The last element might not fully fit and will be split. The remainder of the split stays in the buffer.
When packing features we also create {feature_name}_positions and {feature_name}_segment_ids features. They are 1D arrays of size sequence_length. Segment IDs start at 1 and enumerate the elements of the packed element. Positions indicate the position within the unpacked sequence.
Features can be “meta features” in which case they are never split and we do not create *_positions and *_segment_ids features for them.
- Parameters:
parent (dataset.IterDataset)
length_struct (Mapping[str, int])
meta_features (Collection[str])
split_full_length_features (bool)
bos_handling (BOSHandling)
bos_features (Collection[str])
bos_token_id (int | None)
- __init__(parent, *, length_struct, meta_features=(), split_full_length_features=True, bos_handling=BOSHandling.DO_NOTHING, bos_features=(), bos_token_id=None)#
Creates a dataset that concat-then-splits sequences from the parent.
- Parameters:
parent (IterDataset) – The parent dataset.
length_struct (Mapping[str, int]) – Mapping from feature name to target sequence length.
meta_features (Collection[str]) – Set of feature names that are considered meta features. Meta features are never split and will be duplicated when other features of the same element are split. Otherwise, meta features are packed normally (they have their own sequence length). No *_positions and *_segment_ids features are created for meta features.
split_full_length_features (bool) – Whether full-length features are split, or they are considered packed and passed through in priority. Setting split_full_length_features=False is an optimization when some sequences already have the target length, and you don’t want them to be split. This optimization is not used by default.
bos_handling (BOSHandling) – The instructions for handling BOS tokens (by default, no BOS token is added).
bos_features (Collection[str]) – The features to which BOS handling is applied in case BOS is used.
bos_token_id (int | None) – The token indicating BOS in case BOS is used.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
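Example usage (a minimal sketch with illustrative sequence lengths):
```python
import numpy as np
import grain

ds = grain.MapDataset.source(
    [{"tokens": np.arange(n)} for n in (4, 5, 3)]
).to_iter_dataset()
ds = grain.experimental.ConcatThenSplitIterDataset(
    ds, length_struct={"tokens": 6}
)
for element in ds:
  # Per the description above, `element` also carries "tokens_segment_ids"
  # and "tokens_positions", each of length 6.
  ...
```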
- grain.experimental.multithread_prefetch(ds, num_threads, buffer_size, sequential_slice=False)#
Uses a pool of threads to prefetch elements ahead of time.
This is a thread-based alternative to multiprocess_prefetch intended to be used with free-threaded Python.
It works by sharding the input dataset into num_threads shards, and interleaving them. Each shard is read by a separate thread inside InterleaveIterDataset.
- Parameters:
ds (IterDataset[T]) – The parent dataset to prefetch from.
num_threads (int) – The number of threads to use for prefetching. If 0, prefetching is disabled and this is a no-op.
buffer_size (int) – The size of the prefetch buffer for each thread.
sequential_slice (bool) – Whether to use sequential slicing.
- Returns:
An IterDataset that prefetches elements from ds using multiple threads.
- Return type:
IterDataset[T]
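Example usage (a minimal sketch; the thread and buffer counts are illustrative):
```python
import grain

ds = grain.MapDataset.range(1_000).map(lambda x: x * x).to_iter_dataset()
ds = grain.experimental.multithread_prefetch(ds, num_threads=4, buffer_size=16)
for element in ds:
  ...
```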
- class grain.experimental.ThreadPrefetchIterDataset(parent, *, prefetch_buffer_size)#
Iterable dataset that uses a synchronized queue for prefetching.
This is a thread-based alternative to MultiprocessPrefetchIterDataset.
- Parameters:
parent (dataset.IterDataset[T])
prefetch_buffer_size (int)
- parent#
The parent dataset to prefetch from.
- prefetch_buffer_size#
The size of the prefetch buffer. Must be greater than or equal to 0. If 0, prefetching is disabled and this is a no-op.
- __init__(parent, *, prefetch_buffer_size)#
- Parameters:
parent (IterDataset[T])
prefetch_buffer_size (int)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
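Example usage (a minimal sketch):
```python
import grain

ds = grain.MapDataset.range(100).to_iter_dataset()
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=8)
for element in ds:
  ...
```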
- class grain.experimental.ThreadPrefetchDatasetIterator(parent, prefetch_buffer_size)#
Iterator that performs prefetching using a synchronized queue.
- Parameters:
parent (CheckpointableIterator[T])
prefetch_buffer_size (int)
- __init__(parent, prefetch_buffer_size)#
- Parameters:
parent (CheckpointableIterator[T])
prefetch_buffer_size (int)
- __next__()#
Return the next item from the iterator. When exhausted, raise StopIteration
- class grain.experimental.RebatchIterDataset(parent, batch_size, drop_remainder=False)#
Rebatches the input PyTree elements.
- Parameters:
parent (dataset.IterDataset)
batch_size (int)
drop_remainder (bool)
- __init__(parent, batch_size, drop_remainder=False)#
An IterDataset that rebatches elements.
- Parameters:
parent (IterDataset) – The parent IterDataset whose elements are to be rebatched.
batch_size (int) – The number of elements to batch together.
drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.
- __iter__()#
Returns an iterator for this dataset.
- Return type:
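Example usage (a minimal sketch; batch sizes are illustrative):
```python
import grain

# Parent yields arrays of shape (4,); rebatch them to size 2.
ds = grain.MapDataset.range(8).batch(4).to_iter_dataset()
ds = grain.experimental.RebatchIterDataset(ds, batch_size=2)
# Expected to yield arrays [0, 1], [2, 3], [4, 5], [6, 7] under these assumptions.
```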
- class grain.experimental.RepeatIterDataset(parent, num_epochs=None)#
Repeats the underlying dataset for num_epochs.
If num_epochs is None, repeats indefinitely. Note that unlike RepeatMapDataset, RepeatIterDataset does not support re-seeding for each epoch. Each epoch will be identical.
- Parameters:
parent (IterDataset[T])
num_epochs (int | None)
- __init__(parent, num_epochs=None)#
- Parameters:
parent (IterDataset[T])
num_epochs (int | None)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
_RepeatDatasetIterator[T]
- class grain.experimental.WindowShuffleMapDataset(parent, *, window_size, seed)#
Shuffles the parent dataset within a given window.
Shuffles the retrieval index within a range, given by window_size. Each unique index corresponds to exactly one shuffled index (i.e. there is a one-to-one mapping and hence a guarantee that no shuffled indices are repeated within a given window).
- Parameters:
parent (dataset.MapDataset)
window_size (int)
seed (int)
- __getitem__(index)#
Returns the element for the index or None if missing.
- __init__(parent, *, window_size, seed)#
- Parameters:
parent (MapDataset)
window_size (int)
seed (int)
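Example usage (a minimal sketch):
```python
import grain

ds = grain.MapDataset.range(8)
ds = grain.experimental.WindowShuffleMapDataset(ds, window_size=4, seed=0)
# Indices 0..3 are permuted only among themselves, as are indices 4..7.
```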
- class grain.experimental.WindowShuffleIterDataset(parent, *, window_size, seed)#
Shuffles the parent dataset within a given window.
Fetches window_size elements from the parent iterator and returns them in shuffled order. Each window is shuffled with a different seed derived from the input seed.
- Parameters:
parent (dataset.IterDataset)
window_size (int)
seed (int)
- __init__(parent, *, window_size, seed)#
- Parameters:
parent (IterDataset)
window_size (int)
seed (int)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- class grain.experimental.ZipMapDataset(parents)#
Combines MapDatasets of the same length to return a tuple of items.
- Parameters:
parents (Sequence[dataset.MapDataset[T]])
- __getitem__(index)#
Returns the element for the index or None if missing.
- __init__(parents)#
- Parameters:
parents (Sequence[MapDataset[T]])
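Example usage (a minimal sketch):
```python
import grain

ds = grain.experimental.ZipMapDataset(
    [grain.MapDataset.range(3), grain.MapDataset.range(3).map(str)]
)
# ds[1] == (1, "1")
```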
- class grain.experimental.ZipIterDataset(parents, *, strict=True)#
Combines IterDatasets of the same length to return a tuple of items.
- Parameters:
parents (Sequence[dataset.IterDataset[T]])
strict (bool)
- __init__(parents, *, strict=True)#
- Parameters:
parents (Sequence[IterDataset[T]])
strict (bool)
- __iter__()#
Returns an iterator for this dataset.
- Return type:
- grain.experimental.index_shuffle()#
- grain.experimental.assert_equal_output_after_checkpoint(ds)#
Tests restoring an iterator to various checkpointed states.
- Parameters:
ds (Any) – The dataset to test. It is recommended to use a small dataset, potentially created using grain.python.experimental.LimitIterDataset, to restrict the number of steps being tested. The underlying dataset iterator must implement get_state and set_state for checkpointing.
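Example usage (a minimal sketch, following the recommendation above to limit the dataset):
```python
import grain

ds = grain.MapDataset.range(16).to_iter_dataset()
ds = grain.experimental.LimitIterDataset(ds, count=4)
grain.experimental.assert_equal_output_after_checkpoint(ds)
```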
- grain.experimental.device_put(ds, device, *, cpu_buffer_size=4, device_buffer_size=2)#
Moves the data to the given devices with prefetching.
- Stage 1: A CPU-side prefetch buffer.
- Stage 2: Per-device buffers for elements already transferred to the device.
- Parameters:
ds (IterDataset) – Dataset to prefetch.
device – same arguments as in jax.device_put.
cpu_buffer_size (int) – Number of elements to prefetch on CPU.
device_buffer_size (int) – Number of elements to prefetch per device.
- Returns:
Dataset with the elements prefetched to the devices.
- Return type:
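Example usage (a minimal sketch; it assumes a JAX runtime with at least one device):
```python
import jax
import grain

ds = grain.MapDataset.range(32).batch(8).to_iter_dataset()
ds = grain.experimental.device_put(ds, jax.devices()[0])
for batch in ds:
  ...  # `batch` already lives on the target device
```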
- class grain.experimental.PerformanceConfig(multiprocessing_options: grain._src.python.options.MultiprocessingOptions | None = None, read_options: grain._src.python.options.ReadOptions | None = None)#
- Parameters:
multiprocessing_options (MultiprocessingOptions | None)
read_options (ReadOptions | None)
- grain.experimental.pick_performance_config(ds, *, ram_budget_mb, max_workers, max_buffer_size, samples_to_check=5)#
Analyzes element size to choose an optimal number of workers for a MultiprocessPrefetchIterDataset.
- Parameters:
ds (IterDataset) – The input dataset.
ram_budget_mb (int | None) – The user predicted RAM budget in megabytes.
max_workers (int | None) – The maximum number of processes to use.
max_buffer_size (int | None) – The maximum buffer size to use.
samples_to_check (int) – The number of samples to check to estimate element size.
- Returns:
A PerformanceConfig object containing the optimal number of workers.
- Return type:
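Example usage (a minimal sketch; it assumes ds.mp_prefetch accepts the resulting MultiprocessingOptions):
```python
import grain

ds = grain.MapDataset.range(10_000).map(lambda x: x * 2).to_iter_dataset()
config = grain.experimental.pick_performance_config(
    ds,
    ram_budget_mb=1024,
    max_workers=8,
    max_buffer_size=None,
)
ds = ds.mp_prefetch(config.multiprocessing_options)
```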
- grain.experimental.get_element_spec(ds)#
Returns specification of the elements produced by this dataset.
Does not instantiate an iterator or perform any data reads or transformations.
- Parameters:
ds (MapDataset | IterDataset) – MapDataset or IterDataset to get the element spec from.
- Return type:
Any
- grain.experimental.set_next_index(ds_iter, index)#
Sets the next index for the dataset iterator.
- Parameters:
ds_iter (DatasetIterator)
index (int)
- Return type:
None
- grain.experimental.get_next_index(ds_iter)#
Returns the next index for the dataset iterator.
- Parameters:
ds_iter (DatasetIterator)
- Return type:
int
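Example usage (a minimal sketch; it assumes the iterator produced by to_iter_dataset tracks a per-element index):
```python
import grain

it = iter(grain.MapDataset.range(10).to_iter_dataset())
next(it)
next(it)
grain.experimental.get_next_index(it)    # expected: 2
grain.experimental.set_next_index(it, 0)  # rewind to the start
next(it)                                  # expected: 0
```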