While MirroredStrategy effectively utilizes multiple GPUs on a single machine, some machine learning tasks demand even more computational power or have datasets too large to fit comfortably within a single node's memory. When you need to scale beyond one machine, TensorFlow provides tf.distribute.MultiWorkerMirroredStrategy.
This strategy implements synchronous data parallelism across multiple machines, often referred to as "workers". It's similar to MirroredStrategy: each worker gets a full copy of the model, processes a unique slice of the input data, computes gradients locally, and then participates in a collective operation to synchronize these gradients across all workers before updating the model variables. The key difference is that the communication and synchronization now happen over a network connecting the different machines.
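To make the synchronization step concrete, here is a small plain-Python sketch (with made-up gradient values, not TensorFlow API calls) of what the all-reduce conceptually computes: each worker contributes the gradients from its own data shard, and every replica receives the same averaged result.

# Conceptual sketch only: two workers, two gradient components, invented numbers.
local_gradients = {
    "worker_0": [0.20, -0.10],  # gradients computed on worker 0's data shard
    "worker_1": [0.40, 0.30],   # gradients computed on worker 1's data shard
}

# An all-reduce with a mean combines the per-worker gradients so that every
# replica applies the identical update and stays synchronized.
num_workers = len(local_gradients)
averaged = [
    sum(grads[i] for grads in local_gradients.values()) / num_workers
    for i in range(2)
]
print(averaged)  # approximately [0.3, 0.1] on every replica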
The main elements of this setup are:

Cluster configuration: the participating processes learn about each other through the TF_CONFIG environment variable.

Variable mirroring: as in MirroredStrategy, the model's variables are created and mirrored across all GPUs on all participating workers when defined within the strategy's scope.

Data sharding: the input data (tf.data.Dataset) is automatically sharded, usually based on the number of workers and GPUs per worker. Each GPU processes a distinct portion of the global batch. The tf.data.experimental.AutoShardPolicy is often used to handle this distribution correctly.

Diagram: overview of MultiWorkerMirroredStrategy with two workers, each having two GPUs. Data is sharded, gradients are computed locally, synchronized via all-reduce over the network, and used to update model replicas identically.
The TF_CONFIG Variable

For workers to find each other and coordinate, TensorFlow relies on a cluster configuration specified via the TF_CONFIG environment variable. This variable must be set on each worker machine participating in the training job. It's a JSON string containing two main parts:

cluster: defines the network addresses (hostname/IP and port) of all participating workers, assigning them roles (typically just worker).

task: specifies the role (type) and index (index) of the current worker process within the cluster definition.

Here's an example TF_CONFIG for a setup with two workers:
On Worker 0:
export TF_CONFIG='{
  "cluster": {
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"]
  },
  "task": {"type": "worker", "index": 0}
}'
On Worker 1:
export TF_CONFIG='{
  "cluster": {
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"]
  },
  "task": {"type": "worker", "index": 1}
}'
"cluster"
dictionary lists all workers under the "worker"
key. The hostnames (worker0.example.com
, worker1.example.com
) and port (2222
) must be reachable between the machines."task"
dictionary tells each process its specific identity within this cluster. Worker 0 has index: 0
, and Worker 1 has index: 1
.Setting up TF_CONFIG
correctly is fundamental for multi-worker training. Orchestration systems like Kubernetes (often used with Kubeflow) typically handle injecting the appropriate TF_CONFIG
into each worker container automatically. If running manually, you must ensure this variable is set before your Python script starts.
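If your script needs to inspect this configuration itself, for example to decide which process handles logging or saving, you can parse the variable directly with the standard library. A minimal sketch, assuming TF_CONFIG has already been set as shown above:

import json
import os

# Read and parse the TF_CONFIG this process was started with.
tf_config = json.loads(os.environ["TF_CONFIG"])

workers = tf_config["cluster"]["worker"]   # list of "host:port" strings
task_type = tf_config["task"]["type"]      # e.g. "worker"
task_index = tf_config["task"]["index"]    # this process's position in the list

print(f"Cluster has {len(workers)} workers: {workers}")
print(f"This process is {task_type} #{task_index}")

# By convention, worker 0 acts as the 'chief' for duties such as saving
# checkpoints or writing summaries.
is_chief = task_type == "worker" and task_index == 0
print(f"Acting as chief: {is_chief}")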
Integrating MultiWorkerMirroredStrategy into your Keras code is quite similar to using MirroredStrategy. The primary steps are:

Ensure TF_CONFIG is set: this happens outside your Python code, in the environment where the script runs.

Instantiate the strategy: create tf.distribute.MultiWorkerMirroredStrategy(). TensorFlow will automatically parse the TF_CONFIG environment variable.

Build the model within the strategy's scope: define and compile the model under with strategy.scope():. This ensures variables are created in a distributed manner.

Prepare the input pipeline: use a tf.data.Dataset for your input pipeline. The strategy usually works best with automatic sharding policies. Ensure your dataset loading logic is robust and efficient, as it can become a bottleneck in distributed settings.

Train with model.fit: use the standard Keras model.fit API. The strategy handles the gradient aggregation and variable updates behind the scenes.

The following example puts these steps together:

import tensorflow as tf
import os
import json
# Assume TF_CONFIG is set in the environment
# Example: For worker 0
# os.environ['TF_CONFIG'] = json.dumps({
# 'cluster': {
# 'worker': ['host1:port', 'host2:port']
# },
# 'task': {'type': 'worker', 'index': 0}
# })
# 1. Instantiate the strategy
# Communication options can be specified, e.g., NCCL for GPUs
# strategy = tf.distribute.MultiWorkerMirroredStrategy(
# communication_options=tf.distribute.experimental.CommunicationOptions(
# implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
# )
# )
strategy = tf.distribute.MultiWorkerMirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
# Prepare a distributed dataset
BUFFER_SIZE = 10000
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync # Scale batch size
# Example: Create a dummy dataset
features = tf.random.uniform((1000, 10))
labels = tf.random.uniform((1000, 1))
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
# Define options for dataset distribution
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
# 2. Define model and optimizer within the strategy's scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1)
    ])
    optimizer = tf.keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])

print("Model and optimizer created within strategy scope.")
# 3. Train the model using model.fit
# The strategy handles distribution automatically
print("Starting training...")
history = model.fit(dataset, epochs=5, verbose=2) # Verbose=2 is often better for multi-worker
print("Training finished.")
# Saving the model usually requires saving on only one worker (chief)
# or using specific save options. Refer to TensorFlow documentation for details.
# Example: Saving on worker 0 only
# tf_config_str = os.environ.get('TF_CONFIG')
# if tf_config_str:
#     tf_config = json.loads(tf_config_str)
#     if tf_config['task']['type'] == 'worker' and tf_config['task']['index'] == 0:
#         model.save('my_multi_worker_model.keras')
# else:  # Single-worker case (no TF_CONFIG set)
#     model.save('my_single_worker_model.keras')
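As the comments note, the TensorFlow multi-worker tutorial recommends that every worker participate in saving, since saving can involve collective operations; non-chief workers simply write to throwaway locations. A sketch of that pattern, reusing the strategy and model defined above (the file paths are placeholders):

import os
import tempfile

def is_chief(task_type, task_id):
    # With MultiWorkerMirroredStrategy, worker 0 acts as the chief unless the
    # cluster defines a dedicated 'chief' task. task_type is None when running
    # without TF_CONFIG (single-worker case).
    return (task_type is None or task_type == 'chief'
            or (task_type == 'worker' and task_id == 0))

resolver = strategy.cluster_resolver
task_type = resolver.task_type if resolver is not None else None
task_id = resolver.task_id if resolver is not None else None

if is_chief(task_type, task_id):
    save_path = 'my_multi_worker_model.keras'           # final destination
else:
    # Non-chief workers still call save(), but into a temporary directory.
    save_path = os.path.join(tempfile.mkdtemp(), 'model.keras')

model.save(save_path)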
A few practical considerations when using this strategy:

Effective batch size: the global batch size is the per-replica batch size multiplied by the total number of replicas (num_replicas_in_sync). You might need to adjust learning rates or other hyperparameters to account for this larger effective batch size. A common practice is linear scaling of the learning rate, although this isn't universally optimal (see the sketch after this list).

Data sharding: ensure your tf.data pipeline correctly shards data across workers. Using AutoShardPolicy.DATA or AutoShardPolicy.FILE (if reading from multiple files) is generally recommended. Improper sharding can lead to workers processing overlapping data or some workers being idle.

Fault tolerance: worker failures can interrupt training. Consider using tf.keras.callbacks.BackupAndRestore or implementing custom training loops with checkpointing strategies that can handle worker restarts.

Consistent configuration: every worker process must be launched with its own correctly set TF_CONFIG.
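As a rough illustration of the first and third points, the following sketch scales a base learning rate by the number of replicas and adds BackupAndRestore for fault tolerance. It assumes the strategy, model, and dataset from the example above; the base learning rate and backup directory are placeholder values.

import tensorflow as tf

# Linear learning-rate scaling: a heuristic, not a guarantee of good convergence.
BASE_LEARNING_RATE = 1e-3  # hypothetical single-replica learning rate
scaled_lr = BASE_LEARNING_RATE * strategy.num_replicas_in_sync

with strategy.scope():
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=scaled_lr),
        loss='mse',
        metrics=['mae'],
    )

# BackupAndRestore periodically checkpoints training state so a restarted job
# resumes from the last backup instead of from scratch. The backup directory
# should live on storage reachable by all workers (a shared filesystem or
# cloud bucket), not a purely local path.
callbacks = [
    tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/multiworker_backup'),
]

model.fit(dataset, epochs=5, callbacks=callbacks, verbose=2)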
MultiWorkerMirroredStrategy is a powerful tool for scaling synchronous training beyond a single machine. Its setup requires careful configuration of the TF_CONFIG environment variable and consideration of network performance, but it allows leveraging significantly more compute resources for demanding training tasks using the familiar Keras API.