Alright, let's put the theory into practice. We'll build a simple FastAPI service that loads a pre-trained machine learning model and exposes an endpoint to make predictions. This exercise combines data validation using Pydantic with the model integration techniques discussed earlier in this chapter.
First, you need a trained machine learning model saved to a file. For this example, we'll assume you have a simple classifier trained on the Iris dataset using scikit-learn and saved using `joblib`.
If you don't have one readily available, you can create a basic model like this:
```python
# train_save_model.py
import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Load the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Train a simple Logistic Regression model
model = LogisticRegression(max_iter=200)
model.fit(X, y)

# Save the trained model
model_filename = 'iris_classifier.joblib'
joblib.dump(model, model_filename)
print(f"Model trained and saved to {model_filename}")

# Expected output mapping: {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
# .tolist() prints plain strings (list() shows NumPy string scalars on NumPy 2)
print(f"Target names: {iris.target_names.tolist()}")
```
Run this script (`python train_save_model.py`) to generate the `iris_classifier.joblib` file in your project directory. This model expects four input features: sepal length, sepal width, petal length, and petal width (all in cm). It predicts one of three Iris species: setosa, versicolor, or virginica.
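Before wiring the model into an API, it can be worth confirming that the saved file loads back and predicts sensibly. The sketch below is self-contained, so it retrains the same model and writes it to a temporary directory rather than your project folder; the file name `check_saved_model.py` is only illustrative.

```python
# check_saved_model.py (illustrative name)
# Sketch: save the model, reload it, and confirm the reloaded copy
# predicts the same class as the in-memory original.
import os
import tempfile

import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X, y = iris.data, iris.target

model = LogisticRegression(max_iter=200)
model.fit(X, y)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "iris_classifier.joblib")
    joblib.dump(model, path)
    reloaded = joblib.load(path)

# Feature order: sepal length, sepal width, petal length, petal width (cm)
sample = [[5.1, 3.5, 1.4, 0.2]]
original_id = int(model.predict(sample)[0])
reloaded_id = int(reloaded.predict(sample)[0])
predicted_name = str(iris.target_names[reloaded_id])

print(f"Original: {original_id}, reloaded: {reloaded_id} ({predicted_name})")
```

The sample is the first row of the dataset, a setosa flower, so both copies should agree on class 0.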
Let's organize our service. Create a small project directory with the following structure:
```text
fastapi_ml_service/
├── iris_classifier.joblib  # Your saved model file
├── models.py               # Pydantic models for request/response
└── main.py                 # Your FastAPI application code
```
Define the Data Models (`models.py`)

We need Pydantic models to define the structure of our input data and the prediction response. Create the `models.py` file:
```python
# models.py
from typing import List

from pydantic import BaseModel, ConfigDict, Field


class IrisFeatures(BaseModel):
    """Input features for Iris prediction."""

    sepal_length: float = Field(..., gt=0, description="Sepal length in cm")
    sepal_width: float = Field(..., gt=0, description="Sepal width in cm")
    petal_length: float = Field(..., gt=0, description="Petal length in cm")
    petal_width: float = Field(..., gt=0, description="Petal width in cm")

    # Example payload for the FastAPI documentation (Pydantic v2 syntax;
    # Pydantic v1 used `class Config: schema_extra = {...}` instead)
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "sepal_length": 5.1,
                "sepal_width": 3.5,
                "petal_length": 1.4,
                "petal_width": 0.2,
            }
        }
    )


class PredictionOut(BaseModel):
    """Prediction output schema."""

    predicted_class_id: int = Field(..., description="Predicted class index (0, 1, or 2)")
    predicted_class_name: str = Field(..., description="Predicted class name ('setosa', 'versicolor', 'virginica')")
    probabilities: List[float] = Field(..., description="List of probabilities for each class [setosa, versicolor, virginica]")
```
Here:

- `IrisFeatures` defines the four input features required by our model. We use `Field` to add validation (each value must be greater than 0) and descriptions, which FastAPI uses for automatic documentation. We also include an example payload.
- `PredictionOut` defines the structure of our response, including the predicted class index, the corresponding name, and the probabilities for each class.

Create the FastAPI Application (`main.py`)

Now, let's write the core application logic in `main.py`. We will load the model at startup and create a `/predict` endpoint.
```python
# main.py
import joblib
import numpy as np
from fastapi import FastAPI, HTTPException

from models import IrisFeatures, PredictionOut  # Import Pydantic models

# --- Application Setup ---
app = FastAPI(
    title="Iris Prediction Service",
    description="A simple API to predict Iris species using a pre-trained model.",
    version="0.1.0",
)

# --- Model Loading ---
# Load the model once, when the application starts (module import time).
# For larger applications, consider dependency injection.
model_path = "iris_classifier.joblib"

# Class names in the Iris dataset's standard order; defined outside the
# try block so they exist even if model loading fails.
class_names = ['setosa', 'versicolor', 'virginica']

try:
    model = joblib.load(model_path)
    print(f"Model loaded successfully from {model_path}")
except FileNotFoundError:
    print(f"Error: Model file not found at {model_path}")
    model = None  # /predict responds with 503 while no model is loaded
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# --- API Endpoints ---
@app.get("/")
def read_root():
    """Root endpoint providing basic API information."""
    return {"message": "Welcome to the Iris Prediction API!"}


@app.post("/predict", response_model=PredictionOut)
async def predict_iris(features: IrisFeatures):
    """
    Predict the Iris species based on input features.

    Takes sepal length, sepal width, petal length, and petal width,
    returns the predicted class ID, class name, and probabilities.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")

    # 1. Convert input data to the format expected by the model
    #    (scikit-learn models expect a 2D array: one row per sample)
    input_data = np.array([[
        features.sepal_length,
        features.sepal_width,
        features.petal_length,
        features.petal_width,
    ]])

    # 2. Make prediction
    try:
        prediction_id = model.predict(input_data)
        probabilities = model.predict_proba(input_data)
    except Exception as e:
        # Handle potential errors during prediction
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")

    # 3. Format the response
    predicted_class_index = int(prediction_id[0])  # Only one sample was sent
    if predicted_class_index < 0 or predicted_class_index >= len(class_names):
        raise HTTPException(status_code=500, detail="Prediction index out of bounds.")

    predicted_class_name = class_names[predicted_class_index]
    prediction_probabilities = probabilities[0].tolist()  # Probabilities for the first (and only) input

    return PredictionOut(
        predicted_class_id=predicted_class_index,
        predicted_class_name=predicted_class_name,
        probabilities=prediction_probabilities,
    )


# --- Run the Application (Optional, for direct execution) ---
# Typically, you'd run this using Uvicorn from the command line.
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
```
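The nested list in step 1 matters because scikit-learn estimators expect input shaped `(n_samples, n_features)`. The sketch below illustrates this with a plain `dict` standing in for the validated `IrisFeatures` object; `FEATURE_ORDER` and `to_model_input` are hypothetical names, not part of the chapter's code.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
model = LogisticRegression(max_iter=200).fit(iris.data, iris.target)

# Hypothetical: the column order used at training time, written down once
FEATURE_ORDER = ("sepal_length", "sepal_width", "petal_length", "petal_width")


def to_model_input(payload: dict) -> np.ndarray:
    """Turn one request payload into a 2D array holding a single row."""
    return np.array([[payload[name] for name in FEATURE_ORDER]], dtype=float)


payload = {"sepal_length": 5.1, "sepal_width": 3.5,
           "petal_length": 1.4, "petal_width": 0.2}
row = to_model_input(payload)
print("2D input shape:", row.shape)

# The same four numbers as a flat 1D vector are rejected by scikit-learn
flat = row[0]
rejected_1d = False
try:
    model.predict(flat)
except ValueError:
    rejected_1d = True
print("1D input rejected:", rejected_1d)

# predict_proba returns one row per sample and one column per class
probabilities = model.predict_proba(row)
predicted_id = int(model.predict(row)[0])
print("probabilities shape:", probabilities.shape)
print(predicted_id, probabilities[0].round(3).tolist())
```

Listing the fields explicitly, as `main.py` does, keeps the column order obvious; a single ordering constant like this is one way to keep training and serving in sync if the feature list grows.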
Key steps in `main.py`:

- Imports: We import `FastAPI`, `joblib`, `numpy`, `HTTPException`, and our Pydantic models.
- Model loading: We use `joblib.load()` to load the `iris_classifier.joblib` file when the application starts. Basic error handling is included. We also define `class_names` corresponding to the model's output.
- `/predict` endpoint:
  - Input: The request body must conform to the `IrisFeatures` Pydantic model. FastAPI automatically handles parsing the incoming JSON and validating it against this model. If validation fails, FastAPI returns a 422 Unprocessable Entity error automatically.
  - Output: The route declares `response_model=PredictionOut`. FastAPI uses this to validate the outgoing response, filter data (only fields defined in `PredictionOut` are returned), and generate documentation.
  - Preprocessing: The `IrisFeatures` object is converted into a 2D NumPy array, as expected by scikit-learn's `predict` and `predict_proba` methods.
  - Prediction: We call `model.predict()` to get the class ID and `model.predict_proba()` to get the class probabilities.
  - Response: We return an instance of the `PredictionOut` model. FastAPI automatically converts this Pydantic object into a JSON response.

Start the Server: Open your terminal in the `fastapi_ml_service` directory and run:
```bash
uvicorn main:app --reload --host 0.0.0.0 --port 8000
```

- `main:app`: Tells Uvicorn to find the `app` object inside the `main.py` file.
- `--reload`: Automatically restarts the server when code changes (useful during development).
- `--host 0.0.0.0`: Makes the server accessible from other machines on your network (or from Docker containers later).
- `--port 8000`: Specifies the port to run on.

Access the Documentation: Open your web browser and go to http://localhost:8000/docs. You should see the interactive Swagger UI documentation generated by FastAPI, showing your `/` and `/predict` endpoints, including the schemas defined by Pydantic.
Test with curl: Open another terminal and send a POST request to the `/predict` endpoint:

```bash
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "sepal_length": 5.1,
  "sepal_width": 3.5,
  "petal_length": 1.4,
  "petal_width": 0.2
}'
```
You should receive a JSON response similar to this (probabilities might vary slightly):
```json
{
  "predicted_class_id": 0,
  "predicted_class_name": "setosa",
  "probabilities": [0.97..., 0.02..., 0.00...]
}
```
Test Validation: Try sending invalid data (e.g., missing a field or providing non-numeric data):
```bash
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "sepal_length": 5.1,
  "sepal_width": "not-a-number",
  "petal_length": 1.4,
  "petal_width": 0.2
}'
```
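FastAPI (thanks to Pydantic) will automatically return a `422 Unprocessable Entity` error with details about the validation failure. With Pydantic v2, which current FastAPI releases use, the body looks similar to this:

```json
{
  "detail": [
    {
      "type": "float_parsing",
      "loc": ["body", "sepal_width"],
      "msg": "Input should be a valid number, unable to parse string as a number",
      "input": "not-a-number"
    }
  ]
}
```

Under Pydantic v1, the same error is reported with `"msg": "value is not a valid float"` and `"type": "type_error.float"`.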
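FastAPI builds that 422 response from Pydantic's validation errors, which are raised before your endpoint function runs, so you can try the same rules directly on the model class. A minimal sketch using the Pydantic v2 API (`model_validate` and `ValidationError.errors()`); the `check` helper is hypothetical. Note that inside FastAPI the `loc` gains a leading `"body"` entry.

```python
from pydantic import BaseModel, Field, ValidationError


class IrisFeatures(BaseModel):
    sepal_length: float = Field(..., gt=0, description="Sepal length in cm")
    sepal_width: float = Field(..., gt=0, description="Sepal width in cm")
    petal_length: float = Field(..., gt=0, description="Petal length in cm")
    petal_width: float = Field(..., gt=0, description="Petal width in cm")


def check(payload: dict) -> list:
    """Hypothetical helper: (location, error type) pairs, empty if valid."""
    try:
        IrisFeatures.model_validate(payload)
    except ValidationError as exc:
        return [(err["loc"], err["type"]) for err in exc.errors()]
    return []


valid = {"sepal_length": 5.1, "sepal_width": 3.5,
         "petal_length": 1.4, "petal_width": 0.2}
print(check(valid))

# A numeric string is coerced to float in Pydantic's default (lax) mode
print(check({**valid, "sepal_width": "3.5"}))

# Non-numeric text, a value violating gt=0, and a missing field all fail
print(check({**valid, "sepal_width": "not-a-number"}))
print(check({**valid, "petal_width": -0.2}))
print(check({k: v for k, v in valid.items() if k != "petal_length"}))
```

Validating in isolation like this is a quick way to confirm a constraint such as `gt=0` behaves as intended before exercising it through the HTTP layer.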
You have now successfully built a functioning ML prediction service using FastAPI! It loads a model, defines clear data contracts with Pydantic for input and output, and serves predictions through a well-defined API endpoint. This forms a solid foundation for deploying more complex models. In later chapters, we will explore structuring larger applications, testing, handling asynchronous operations, and containerization.
© 2025 ApX Machine Learning