Chapter 18: Inference Pipeline Design
“Training wins the accuracy race. Inference wins the deployment game.”
Why This Chapter Matters
No matter how well your model performs in training, it’s useless if it fails at inference time.
Here’s where things often go wrong:
- You trained with `[0,1]`-scaled images but used `[-1,1]` scaling at inference
- You forgot to switch to `.eval()` mode or `training=False`
- You tested with images of a different size than during training
- Your real-world images have noise, padding, or background not present in your dataset
Inference is a system, not just a `.predict()` call.
This chapter shows you how to:
- Design reusable, consistent, and fault-tolerant pipelines
- Align preprocessing between training and deployment
- Apply test-time augmentation (TTA) for better accuracy
- Defend your model against bad input or unexpected formats
Conceptual Breakdown
🔹 What is an Inference Pipeline?
It’s the entire path an image takes from user input to model prediction:
- Image is uploaded, captured, or streamed
- Preprocessing (resize, normalize, etc.)
- Passed through model (in eval mode, no gradients)
- Output is decoded (softmax, argmax, etc.)
- Results are returned in user-friendly format
A mistake at any step can produce wrong predictions, often without raising any error.
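The five stages above can be sketched as a chain of small, swappable functions. This is a minimal NumPy illustration, not a framework API: `InferencePipeline`, `softmax_argmax`, the lambda preprocessing, and the toy two-class "model" are hypothetical stand-ins chosen so the example runs without a trained CNN.

```python
import numpy as np


class InferencePipeline:
    """Hypothetical helper that chains the inference stages listed above."""

    def __init__(self, preprocess, model, decode, format_result):
        self.preprocess = preprocess
        self.model = model
        self.decode = decode
        self.format_result = format_result

    def __call__(self, raw_image):
        x = self.preprocess(raw_image)              # stage 2: resize / normalize
        logits = self.model(x)                      # stage 3: forward pass
        idx, confidence = self.decode(logits)       # stage 4: softmax + argmax
        return self.format_result(idx, confidence)  # stage 5: user-friendly output


def softmax_argmax(logits):
    """Decode a (1, num_classes) logit array into (class index, probability)."""
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    idx = int(probs.argmax(axis=-1)[0])
    return idx, float(probs[0, idx])


labels = ["cat", "dog"]  # hypothetical class names
pipeline = InferencePipeline(
    preprocess=lambda img: (img.astype(np.float32) / 255.0)[None, ...],  # [0,1] + batch dim
    model=lambda x: np.array([[x.mean(), 1.0 - x.mean()]]),             # toy stand-in for a CNN
    decode=softmax_argmax,
    format_result=lambda i, p: {"label": labels[i], "confidence": round(p, 3)},
)
print(pipeline(np.zeros((224, 224, 3), dtype=np.uint8)))  # {'label': 'dog', 'confidence': 0.731}
```

Keeping each stage separate lets you unit-test preprocessing and decoding on their own, without loading a real model.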
🔹 Training vs Inference: Matching Pipelines
Stage | During Training | During Inference |
---|---|---|
Resize | `Resize((224, 224))` | Same exact shape required |
Normalization | `Normalize(mean, std)` or rescale to [0,1] | Must match exactly |
Augmentations | RandomCrop, Flip, ColorJitter (for variety) | Disabled, or TTA only |
Mode | `model.train()` | `model.eval()` |
Gradients | `requires_grad=True` | `no_grad()` / tape disabled |
If you change any of the above during inference, your model may misbehave.
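One way to enforce the "must match exactly" rows is to record the training-time preprocessing settings next to the model weights and check them when inference starts. The sketch below assumes a plain JSON file; `TRAIN_PREPROCESS`, its key names, and both helper functions are hypothetical, not part of PyTorch or Keras.

```python
import json
import os
import tempfile

# Hypothetical record of the preprocessing used during training.
# Lists (not tuples) are used so the values survive a JSON round trip unchanged.
TRAIN_PREPROCESS = {
    "size": [224, 224],
    "scale": "0_1",                  # pixels divided by 255 before normalization
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
    "channel_order": "RGB",
}


def save_preprocess_config(cfg, path):
    """Write the training-time preprocessing settings next to the model weights."""
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2, sort_keys=True)


def load_and_check(path, inference_cfg):
    """Raise ValueError if any inference setting differs from the recorded training one."""
    with open(path) as f:
        train_cfg = json.load(f)
    mismatches = {
        key: (train_cfg.get(key), inference_cfg.get(key))
        for key in sorted(set(train_cfg) | set(inference_cfg))
        if train_cfg.get(key) != inference_cfg.get(key)
    }
    if mismatches:
        raise ValueError(f"preprocessing mismatch (train, inference): {mismatches}")
    return train_cfg


path = os.path.join(tempfile.mkdtemp(), "preprocess.json")
save_preprocess_config(TRAIN_PREPROCESS, path)
load_and_check(path, dict(TRAIN_PREPROCESS))       # identical settings: passes

wrong = dict(TRAIN_PREPROCESS, scale="minus1_1")    # the [-1,1] mistake from earlier
try:
    load_and_check(path, wrong)
except ValueError as err:
    print(err)  # reports the 'scale' mismatch
```

Failing loudly at startup is the point of this design: a mismatch becomes an error message instead of silently degraded predictions.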
PyTorch Implementation
🔸 1. Reusable Preprocessing Function
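```python
from torchvision import transforms

def get_inference_transform():
    # Deterministic transforms only: no random crops or flips at inference
    return transforms.Compose([
        transforms.Resize((224, 224)),        # same input size as training
        transforms.ToTensor(),                # PIL HWC uint8 [0,255] -> CHW float [0,1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # Must match training!
    ])
```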
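To make the "must match training" comment concrete, here is what `ToTensor()` followed by `Normalize()` computes for a uint8 RGB image, re-created in NumPy. It is an illustrative equivalent, not torchvision code, and it leaves out `Resize` (the input is assumed to already be 224x224).

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_tensor_and_normalize(img_hwc_uint8):
    """NumPy equivalent of ToTensor() + Normalize(MEAN, STD) for an HxWx3 uint8 image."""
    x = img_hwc_uint8.astype(np.float32) / 255.0           # ToTensor: [0,255] -> [0,1]
    x = x.transpose(2, 0, 1)                                # ToTensor: HWC -> CHW
    return (x - MEAN[:, None, None]) / STD[:, None, None]   # Normalize: per channel


white = np.full((224, 224, 3), 255, dtype=np.uint8)
out = to_tensor_and_normalize(white)
print(out.shape)     # (3, 224, 224)
print(out[:, 0, 0])  # one value per channel; ~2.25 for the red channel
```

If inference skips either step, or applies them with different constants, the model receives numbers in a range it never saw during training.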
🔸 2. Inference Function
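```python
from PIL import Image
import torch

def predict_image(model, image_path, transform):
    model.eval()                                    # disable dropout, use BatchNorm running stats
    image = Image.open(image_path).convert("RGB")   # force 3 channels (grayscale/RGBA inputs)
    tensor = transform(image).unsqueeze(0)          # add batch dimension: (1, C, H, W)
    with torch.no_grad():                           # no gradient tracking at inference
        output = model(tensor)
    prediction = torch.argmax(output, dim=1).item() # index of the highest-scoring class
    return prediction
```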
TensorFlow Implementation
🔸 1. Preprocessing Function
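For Keras Applications models such as MobileNetV2 or EfficientNet, use the `preprocess_input` function from the model's own module, since each architecture expects its own input scaling: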
```python
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image  # newer TF also offers tf.keras.utils.load_img / img_to_array
import numpy as np

def prepare_image_tf(img_path):
    img = image.load_img(img_path, target_size=(224, 224))  # RGB image at training size
    img_array = image.img_to_array(img)                     # float32, HWC, values in [0,255]
    img_array = preprocess_input(img_array)                 # MobileNetV2: scales to [-1,1]
    return np.expand_dims(img_array, axis=0)                # add batch dimension: (1, H, W, C)
```
🔸 2. Inference Function
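```python
def predict_tf(model, img_path):
    img_tensor = prepare_image_tf(img_path)
    predictions = model.predict(img_tensor)          # shape (1, num_classes), inference mode
    return int(np.argmax(predictions, axis=-1)[0])   # class index for the single image
```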
Additions for a Full Inference System
Component | Why It Matters |
---|---|
Input validation | Ensure correct shape and color channels |
Test-Time Augmentation | Improve prediction by averaging outputs |
Softmax thresholding | Reject or flag low-confidence predictions |
Postprocessing | Map label index → human-readable class |
Batching | Speed up inference for multiple inputs |
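The input validation, softmax thresholding, and postprocessing rows can be combined into a small layer around the model. This NumPy sketch validates a channels-first batch, converts logits to probabilities, and maps each row to a label, or to `None` when the top probability falls below a threshold. The function names, the (N, C, H, W) layout, the 224x224 size, and the 0.6 threshold are assumptions to adapt; softmax scores are often overconfident, so a threshold is best tuned on held-out data.

```python
import numpy as np


def validate_batch(x, size=(224, 224), channels=3):
    """Reject inputs that do not look like a channels-first (N, C, H, W) batch."""
    if x.ndim == 3:
        raise ValueError(f"missing batch dimension: got {x.shape}, expected (N, C, H, W)")
    if x.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {x.shape}")
    _, c, h, w = x.shape
    if c != channels:
        raise ValueError(f"expected {channels} color channels, got {c}")
    if (h, w) != tuple(size):
        raise ValueError(f"expected spatial size {tuple(size)}, got {(h, w)}")
    if not np.isfinite(x).all():
        raise ValueError("input contains NaN or Inf values")
    return x


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def decode_batch(logits, labels, threshold=0.6):
    """Map each row of (N, num_classes) logits to a label, or None if below threshold."""
    probs = softmax(logits)
    idx = probs.argmax(axis=1)
    conf = probs[np.arange(len(idx)), idx]
    return [
        {"label": labels[i] if p >= threshold else None, "confidence": float(p)}
        for i, p in zip(idx, conf)
    ]


labels = ["cat", "dog", "bird"]                  # hypothetical class names
validate_batch(np.zeros((2, 3, 224, 224), dtype=np.float32))
logits = np.array([[4.0, 0.0, 0.0],              # clear winner
                   [0.1, 0.0, 0.0]])             # nearly uniform
results = decode_batch(logits, labels)
for r in results:
    print(r)
```

Because every function works on whole batches, the same code covers the batching row: stack several preprocessed images along the first axis and decode them in one call.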
🔸 Optional: Test-Time Augmentation (TTA)
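Run the model on several transformed versions of the same image and average the predictions. Every transform in the list must still end with the training-time normalization.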
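```python
import torch

def tta_predict(model, image, transforms_list):
    """Average softmax probabilities over several transformed views of one PIL image."""
    model.eval()                                  # TTA still runs in inference mode
    probs = []
    with torch.no_grad():
        for t in transforms_list:
            img = t(image).unsqueeze(0)           # (1, C, H, W)
            probs.append(torch.softmax(model(img), dim=1))
    mean_probs = torch.stack(probs).mean(dim=0)   # (1, num_classes)
    return mean_probs.argmax(dim=1).item()
```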
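The loop above runs one forward pass per transform. An alternative design stacks all views into a single batch so the model runs once. The sketch below shows that variant in NumPy, with a hypothetical random linear "model" standing in for a CNN and a horizontal flip as the only augmentation. Flips are only safe when the label does not depend on left/right orientation; they would be wrong for text recognition, for example.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def tta_predict_batched(model, x_chw, views):
    """Average class probabilities over augmented views of one (C, H, W) image,
    running the model once on a batch containing all views."""
    batch = np.stack([view(x_chw) for view in views])  # (num_views, C, H, W)
    probs = softmax(model(batch))                       # (num_views, num_classes)
    mean_probs = probs.mean(axis=0)                     # average over views
    return int(mean_probs.argmax()), mean_probs


views = [
    lambda x: x,               # original image
    lambda x: x[:, :, ::-1],   # horizontal flip along the width axis
]

rng = np.random.default_rng(0)
weights = rng.normal(size=(3 * 8 * 8, 5))  # hypothetical toy model with 5 classes


def toy_model(batch):
    """Linear stand-in for a CNN: (N, C, H, W) -> (N, 5) logits."""
    return batch.reshape(len(batch), -1) @ weights


image = rng.random((3, 8, 8))
label_idx, mean_probs = tta_predict_batched(toy_model, image, views)
print(label_idx, mean_probs.round(3))
```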
Framework Comparison Table
Feature | PyTorch | TensorFlow / Keras |
---|---|---|
Eval mode | `model.eval()` | `training=False` in `model(x, training=...)` |
Gradient-free inference | `with torch.no_grad()` | Default in `model.predict()` |
Reusable preprocessing | `torchvision.transforms.Compose()` | Keras preprocessing layers (`Resizing`, `Rescaling`, `Normalization`) or `tf.image` |
Built-in TTA | Manual | Manual |
Model saving | `torch.save(model.state_dict(), path)` | `model.save()` (`.keras` format; SavedModel in older TF 2.x) |
Normalization consistency | User-defined | Use `keras.applications.*.preprocess_input()` |
Mini-Exercise
Build a full inference function that:
- Accepts an image path
- Applies exactly the same preprocessing used in training
- Loads a trained model
- Switches to inference mode
- Predicts the class and returns a human-readable label
Bonus:
- Add test-time augmentation
- Log input/output shape and prediction confidence
Gotchas to Watch Out For
Problem | Likely Cause |
---|---|
Model always predicts the same class | Forgot `.eval()`, or normalization doesn't match training |
Good validation accuracy, poor real-world predictions | Mismatched preprocessing (e.g., RGB vs. BGR channel order) |
Inference crashes with a shape error | Missing batch dimension, or input size differs from training |
Predictions vary between runs or with batch composition | Dropout or BatchNorm still in training mode (missing `.eval()` / `training=False`) |
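Some of these symptoms can be caught automatically before they reach users. The checks below are rough heuristics, not a standard tool: `diagnose_batch` and its thresholds (for example, treating absolute values above 10 as unnormalized pixels) are assumptions, and the "same class for every input" check is only meaningful on a batch of genuinely different images.

```python
import numpy as np


def diagnose_batch(inputs, probs):
    """Return heuristic warnings for common inference mistakes.

    inputs: the preprocessed batch fed to the model, expected shape (N, C, H, W)
    probs:  the model's softmax output for that batch, shape (N, num_classes)
    """
    warnings = []
    if inputs.ndim != 4:
        warnings.append(f"expected a 4-D batch (N, C, H, W), got shape {inputs.shape}")
    if np.abs(inputs).max() > 10.0:  # assumption: normalized pixels stay well below this
        warnings.append("inputs look unnormalized (raw 0-255 pixels?)")
    preds = probs.argmax(axis=1)
    if len(preds) > 1 and np.all(preds == preds[0]):
        warnings.append(
            f"every input predicted as class {int(preds[0])}: check .eval() and normalization"
        )
    return warnings


raw = np.full((4, 3, 8, 8), 200.0)       # forgot to normalize
stuck = np.tile([0.9, 0.1], (4, 1))      # always class 0
for w in diagnose_batch(raw, stuck):
    print("WARNING:", w)
```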
What You Can Now Do
- Write a robust inference script from scratch
- Detect input shape and channel mismatches
- Reuse training transforms to guarantee consistency
- Use test-time augmentation to make predictions more robust
- Ship CNNs in reproducible, traceable pipelines