grain.sharding module

APIs for sharding pipelines for distributed training.

List of Members

class grain.sharding.ShardOptions(shard_index, shard_count, drop_remainder=False)

Dataclass to hold options for sharding a data source.

Parameters:
  • shard_index (int)

  • shard_count (int)

  • drop_remainder (bool)

shard_index

The index of the shard to use in this process. Must be in [0, shard_count - 1].

Type:

int

shard_count

The total number of shards.

Type:

int

drop_remainder

If True, shard() creates even splits and drops the remainder examples, so all shards have the same number of examples. If False, the remainder of N examples is distributed over the first N shards, one extra example each.

Type:

bool
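
The effect of drop_remainder can be illustrated with a small pure-Python sketch. The helper shard_interval below is hypothetical and is not part of grain.sharding. It assumes that each shard receives a contiguous [start, end) range of example indices, which the description above does not state. Only the remainder rule itself is taken from the text.

```python
def shard_interval(num_examples, shard_index, shard_count, drop_remainder=False):
    """Hypothetical helper: returns the [start, end) index range for one shard.

    Assumes contiguous ranges per shard; this is an illustration of the
    remainder rule, not Grain's implementation.
    """
    if not 0 <= shard_index < shard_count:
        raise ValueError("shard_index must be in [0, shard_count - 1]")
    per_shard, remainder = divmod(num_examples, shard_count)
    start = shard_index * per_shard
    end = start + per_shard
    if not drop_remainder:
        # The first `remainder` shards each take one extra example, so every
        # later shard is shifted right by the extras handed out before it.
        start += min(shard_index, remainder)
        end += min(shard_index + 1, remainder)
    return start, end


# 10 examples over 3 shards leaves a remainder of 1.
kept = [shard_interval(10, i, 3, drop_remainder=False) for i in range(3)]
dropped = [shard_interval(10, i, 3, drop_remainder=True) for i in range(3)]
print(kept)     # [(0, 4), (4, 7), (7, 10)]
print(dropped)  # [(0, 3), (3, 6), (6, 9)]
```

With drop_remainder=False the first shard holds four examples and the others three. With drop_remainder=True every shard holds three and the last example is unused.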

class grain.sharding.NoSharding

Doesn’t shard data. Each process will load all data.

class grain.sharding.ShardByJaxProcess(drop_remainder=False)

Shards the data across JAX processes.

Parameters:
  • drop_remainder (bool)