You've learned how to construct efficient input pipelines using the tf.data API, transforming raw data sources into shuffled, batched, and prefetched streams ready for consumption. Now, let's connect these pipelines to the Keras training infrastructure. The Keras API, particularly the model.fit(), model.evaluate(), and model.predict() methods, is designed to work directly with tf.data.Dataset objects, making the integration straightforward and efficient.
When you pass a tf.data.Dataset object to model.fit(), Keras automatically iterates over the dataset to retrieve batches of data for each training step. This eliminates the need for manual batch iteration loops and integrates cleanly with Keras features like callbacks.
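Because Keras manages the iteration, callbacks plug in exactly as they do with in-memory arrays. Below is a minimal sketch, assuming a compiled model and (features, labels) datasets like the ones built later in this section:

# Sketch only: callbacks work unchanged when training from a tf.data.Dataset.
# Assumes 'model' is a compiled Keras model and 'train_dataset'/'val_dataset'
# yield (features_batch, labels_batch) tuples.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',          # watch the validation loss
    patience=3,                  # stop after 3 epochs without improvement
    restore_best_weights=True)   # roll back to the best epoch's weights

history = model.fit(train_dataset,
                    epochs=50,
                    validation_data=val_dataset,
                    callbacks=[early_stopping])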
For training with model.fit(), your dataset should typically yield tuples of the form (inputs, targets). Keras expects each element yielded by the dataset iterator to represent one batch of data.
- inputs: This can be a single tensor (for models with one input) or a tuple/dictionary of tensors (for models with multiple inputs), as shown in the sketch after this list. The structure must match the model's input signature.
- targets: Similarly, this can be a single tensor or a tuple/dictionary of tensors, corresponding to the model's output(s) and the loss function(s) being used.

If your dataset yields batches like (feature_batch, label_batch), Keras will correctly map feature_batch to the model's inputs and label_batch to the expected outputs for calculating the loss.
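For a model with multiple named inputs, a convenient pattern is to have the dataset yield a dictionary whose keys match the model's input names. A minimal sketch; the input names 'numeric' and 'category' and the feature arrays are illustrative, not part of the example above:

# Sketch: dictionary inputs must use keys that match the model's input names.
numeric_input = tf.keras.Input(shape=(10,), name='numeric')
category_input = tf.keras.Input(shape=(5,), name='category')
combined = tf.keras.layers.concatenate([numeric_input, category_input])
output = tf.keras.layers.Dense(1, activation='sigmoid')(combined)
multi_input_model = tf.keras.Model(inputs=[numeric_input, category_input], outputs=output)

# 'numeric_features', 'category_features', and 'labels' are assumed in-memory arrays.
# Each element is ({'numeric': ..., 'category': ...}, labels) -- one batch after .batch().
multi_input_dataset = tf.data.Dataset.from_tensor_slices((
    {'numeric': numeric_features, 'category': category_features},  # inputs (dict)
    labels                                                          # targets
)).batch(32)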
Consider a dataset train_dataset created using methods like tf.data.Dataset.from_tensor_slices((features, labels)), followed by .shuffle(), .batch(), and .prefetch(). You can directly pass this dataset to model.fit():
# Assume 'model' is a compiled Keras model
# Assume 'train_dataset' yields (features_batch, labels_batch) tuples
# Assume 'val_dataset' yields (features_batch, labels_batch) tuples
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)
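For context, a train_dataset like the one above might be assembled as follows; the buffer and batch sizes are placeholder values:

# One possible way to build 'train_dataset' ('features' and 'labels' are assumed in-memory arrays)
train_dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
                 .shuffle(buffer_size=1000)
                 .batch(32)
                 .prefetch(tf.data.AUTOTUNE))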
Keras handles the iteration, feeding batches to the training process automatically. The same applies to model.evaluate():
loss, accuracy = model.evaluate(val_dataset)
print(f"Validation Loss: {loss}, Validation Accuracy: {accuracy}")
For model.predict(), the dataset should yield only the input features. If the dataset yields (inputs, targets) tuples, Keras will simply ignore the targets part during prediction.
# Assume 'test_dataset' yields batches of features only, or (features, ...) tuples
predictions = model.predict(test_dataset)
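If the only dataset at hand yields (features, labels) pairs and you prefer to be explicit, you can strip the labels with a map before predicting. A small sketch, assuming val_dataset yields (features_batch, labels_batch) tuples:

# Keep only the features so the dataset yields exactly what predict() needs
features_only_dataset = val_dataset.map(lambda features, labels: features)
val_predictions = model.predict(features_only_dataset)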
steps_per_epoch and steps Arguments

When using tf.data, you often work with datasets whose length might not be easily determined upfront, especially if you use transformations like repeat() (to loop infinitely over the data) or if the dataset is sourced from generators.
Finite Datasets: If Keras can determine the cardinality (number of batches) of your dataset (e.g., created from NumPy arrays or TFRecord files without repeat()), it will automatically run through the entire dataset once per epoch. You don't need to specify the number of steps.

Infinite Datasets or Unknown Cardinality: If your dataset is infinite (e.g., uses .repeat()) or its size cannot be determined, Keras doesn't know when one epoch ends. In this case, you must provide the steps_per_epoch argument to model.fit(). This integer value tells Keras how many batches to draw from the dataset to constitute one training epoch.
# Create a dataset and repeat it indefinitely
train_dataset_repeated = train_dataset.repeat()
# Define how many batches constitute one epoch
STEPS_PER_EPOCH = num_training_samples // BATCH_SIZE # Example calculation
history = model.fit(train_dataset_repeated,
                    epochs=10,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    validation_data=val_dataset)  # val_dataset is usually finite
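If you are unsure which case applies, you can ask the dataset itself before calling model.fit(). A quick sketch, assuming eager execution (the default in TensorFlow 2):

# Inspect how many batches Keras will be able to infer per epoch
cardinality = train_dataset_repeated.cardinality()

if cardinality == tf.data.INFINITE_CARDINALITY:
    print("Infinite dataset: pass steps_per_epoch to model.fit().")
elif cardinality == tf.data.UNKNOWN_CARDINALITY:
    print("Unknown size: pass steps_per_epoch to model.fit().")
else:
    print(f"Finite dataset with {cardinality.numpy()} batches per epoch.")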
Similarly, model.evaluate() and model.predict() accept a steps argument. If you pass a dataset with unknown cardinality or an infinite dataset to these methods, you must specify the steps argument to indicate how many batches should be used for evaluation or prediction. If the dataset is finite and steps is not provided, they will run until the dataset is exhausted.
# Evaluate on a specific number of batches from the validation set
EVALUATION_STEPS = num_validation_samples // BATCH_SIZE # Example calculation
loss, accuracy = model.evaluate(val_dataset, steps=EVALUATION_STEPS)
# Predict on a specific number of batches from the test set
PREDICTION_STEPS = num_test_samples // BATCH_SIZE # Example calculation
predictions = model.predict(test_dataset, steps=PREDICTION_STEPS)
Choosing the right value for steps_per_epoch is important. A common practice is to set it so that the model sees roughly the equivalent of the entire training dataset once per epoch: steps_per_epoch = total_training_samples // batch_size.
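If the sample count is not an exact multiple of the batch size and the pipeline was batched before repeating (so a partial final batch exists), rounding up ensures that last partial batch is still visited each epoch. A small sketch:

import math

# Round up so the final partial batch still counts as a step
steps_per_epoch = math.ceil(total_training_samples / batch_size)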
Let's illustrate with a simple example using NumPy data.
import tensorflow as tf
import numpy as np
# 1. Generate some dummy data
num_samples = 1000
input_dim = 10
num_classes = 2
batch_size = 32
X_train = np.random.rand(num_samples, input_dim).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=num_samples).astype(np.int32)
X_val = np.random.rand(200, input_dim).astype(np.float32)
y_val = np.random.randint(0, num_classes, size=200).astype(np.int32)
# 2. Create tf.data Datasets
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=num_samples).batch(batch_size).prefetch(tf.data.AUTOTUNE)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
# 3. Build a simple Keras model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')  # Use softmax for multi-class
])
# 4. Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # Use sparse for integer labels
              metrics=['accuracy'])
# 5. Train the model using the datasets
print("Training the model using tf.data.Dataset...")
# No steps_per_epoch needed as datasets are finite
history = model.fit(train_dataset, epochs=5, validation_data=val_dataset)
print("Training finished.")
# 6. Evaluate the model
print("\nEvaluating the model...")
loss, accuracy = model.evaluate(val_dataset) # No steps needed here either
print(f"Validation Loss: {loss:.4f}, Validation Accuracy: {accuracy:.4f}")
# 7. Make predictions (using a dataset derived from validation data for simplicity)
pred_dataset = tf.data.Dataset.from_tensor_slices(X_val).batch(batch_size)
print("\nMaking predictions...")
predictions = model.predict(pred_dataset)
print(f"Predictions shape: {predictions.shape}") # Shape: (num_val_samples, num_classes)
This example demonstrates how seamlessly tf.data.Dataset objects plug into the standard Keras workflow. The shuffle, batch, and prefetch operations ensure data is efficiently prepared and fed to the model, maximizing hardware utilization, especially when combined with GPU or TPU acceleration. This integration is a fundamental aspect of building scalable machine learning workflows in TensorFlow.