# Chapter 18: Model Training with `fit()` and `evaluate()`
> “With great abstraction comes great productivity. `fit()` lets you train deep models with a single line.”
After understanding the fundamentals of backpropagation and gradient descent, you now know what goes on under the hood. In this chapter, we shift our focus to TensorFlow's high-level training interface: the `fit()` and `evaluate()` methods of the `tf.keras.Model` class.
By the end, you’ll be able to:
- Train models with `model.fit()` using real datasets
- Track performance with `model.evaluate()` and `model.predict()`
- Customize training with validation splits, callbacks, and batch sizes
- Monitor overfitting and training speed
## Setup: A Simple Model
Let’s use the MNIST dataset and a fully connected neural net.
```python
import tensorflow as tf

# Load & preprocess data: scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Define model: flatten 28x28 images, one hidden layer, logits output
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)  # raw logits; no softmax here
])
```
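To sanity-check the architecture before training, it can help to print a layer-by-layer summary (possible at this point because `input_shape` was given, so all output shapes are already known):

```python
# Print layer output shapes and parameter counts
# (this architecture has 101,770 trainable parameters)
model.summary()
```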
## Training with `fit()`
```python
model.compile(
    optimizer='adam',
    # from_logits=True because the final Dense layer outputs raw logits, not probabilities
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.fit(
    x_train, y_train,
    epochs=5,
    batch_size=64,
    validation_split=0.2
)
```
- `batch_size`: how many samples are used per gradient update (here, 64)
- `validation_split`: the fraction of the training data held out for validation; 0.2 means the last 20% of `x_train` is used only for validation, never for weight updates
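As a quick back-of-the-envelope check on what these settings imply (the numbers simply follow from MNIST's 60,000 training images):

```python
import math

num_samples = 60_000    # MNIST training set size
val_fraction = 0.2
batch_size = 64

train_samples = int(num_samples * (1 - val_fraction))    # 48,000 used for updates
steps_per_epoch = math.ceil(train_samples / batch_size)  # 750 gradient updates per epoch
print(train_samples, steps_per_epoch)                    # 48000 750
```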
## Evaluation and Prediction
Once trained, evaluate the model:
```python
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")
```
Make predictions:
```python
logits = model.predict(x_test)           # shape (10000, 10): raw logits per image
predictions = tf.argmax(logits, axis=1)  # predicted class per image
```
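Because the final `Dense` layer emits raw logits (which is why we compiled with `from_logits=True`), pass them through a softmax if you want class probabilities rather than just the argmax:

```python
# Convert logits to per-class probabilities (each row sums to 1)
probs = tf.nn.softmax(logits, axis=1)
print(probs[0])  # probability distribution over the 10 digits for the first test image
```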
## Using Callbacks
Callbacks allow you to hook into the training process. Common uses:
- Early stopping
- Model checkpointing
- Logging
```python
# Stop training once val_loss hasn't improved for 3 consecutive epochs
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

model.fit(
    x_train, y_train,
    epochs=20,
    validation_split=0.2,
    callbacks=[callback]
)
```
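Early stopping is only one of the uses listed above. Here is a minimal sketch covering the other two, checkpointing and logging (the file names `best_model.keras` and `training_log.csv` are arbitrary choices):

```python
callbacks = [
    # Stop when val_loss stops improving
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
    # Save the model whenever val_loss hits a new best
    tf.keras.callbacks.ModelCheckpoint('best_model.keras',
                                       monitor='val_loss',
                                       save_best_only=True),
    # Append per-epoch metrics to a CSV file
    tf.keras.callbacks.CSVLogger('training_log.csv'),
]

model.fit(x_train, y_train, epochs=20,
          validation_split=0.2, callbacks=callbacks)
```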
## Custom Batching with `tf.data`

For more control over shuffling and batching, build a `tf.data.Dataset`:
```python
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(buffer_size=1024).batch(64)

model.fit(train_ds, epochs=5)  # no batch_size argument; the dataset is already batched
```
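For larger datasets, the same pipeline is commonly extended with caching and prefetching so that input preparation overlaps with training; this is a standard `tf.data` pattern rather than anything specific to this model:

```python
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .cache()                        # keep samples in memory after the first epoch
    .shuffle(buffer_size=1024)      # shuffle within a 1,024-sample window
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)     # prepare the next batch while the current one trains
)
```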
## Monitoring Training
`fit()` returns a `History` object whose `history` dict records each metric per epoch. To track progress visually:
```python
history = model.fit(...)  # any of the fit() calls above

import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='train')
plt.plot(history.history['val_accuracy'], label='val')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()
```
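The `history` dict also records losses, which give the clearest view of overfitting: a training loss that keeps falling while validation loss climbs is the classic warning sign.

```python
# Diverging curves here suggest overfitting
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.legend()
plt.show()
```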
## Summary
In this chapter, you learned:
- How `fit()` abstracts the entire training loop
- How to validate, evaluate, and predict using simple APIs
- How callbacks improve training flexibility
- How to use `tf.data` for custom batching
The `fit()` interface doesn't replace your understanding of backprop; it supercharges it. Now you can focus on what to train, not just how to train it.