Let's put the principles discussed in this chapter into practice. We'll take the hypothetical machine learning prediction service we might have built in Chapter 3 and refactor it for better organization using APIRouter. Then, we'll write tests using TestClient to ensure its functionality and reliability.
Assume our initial prediction service looks something like this (likely in a single main.py file):
# main_before_refactor.py (Illustrative example)
from fastapi import FastAPI
from pydantic import BaseModel
import joblib  # Or your preferred model loading library

# Assume model is pre-trained and saved as 'model.joblib'
# Assume necessary preprocessing steps are defined elsewhere or simple

# --- Data Models (from Chapter 2) ---
class InputFeatures(BaseModel):
    feature1: float
    feature2: float
    # ... other features

class PredictionOutput(BaseModel):
    prediction: float  # Or appropriate type

# --- Application Setup ---
app = FastAPI(title="Simple ML Prediction Service")

# --- Model Loading (from Chapter 3) ---
# In a real app, handle potential loading errors
model = joblib.load("model.joblib")

# --- Prediction Endpoint (from Chapter 3) ---
@app.post("/predict", response_model=PredictionOutput)
async def make_prediction(input_data: InputFeatures):
    """
    Accepts input features and returns a prediction.
    """
    # Convert Pydantic model to format expected by the model
    # This is simplified; real preprocessing might be more complex
    features = [[input_data.feature1, input_data.feature2]]
    prediction_result = model.predict(features)
    return PredictionOutput(prediction=prediction_result[0])

# --- Root Endpoint (Optional) ---
@app.get("/")
async def read_root():
    return {"message": "Prediction service is running"}

# To run (using uvicorn): uvicorn main_before_refactor:app --reload
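The comment above about handling potential loading errors deserves a concrete shape. Here is a minimal sketch of the idea; the `load_model` helper is hypothetical and not part of the service above:

```python
from pathlib import Path

def load_model(path: str):
    """Load a serialized model, raising a clear error if the file is missing."""
    if not Path(path).exists():
        raise RuntimeError(f"Model file not found: {path}")
    import joblib  # imported lazily; only needed once the file is known to exist
    return joblib.load(path)
```

Failing fast at startup with an explicit message like this is usually easier to diagnose than the lower-level exception joblib would raise on a missing file.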
This works for simple cases, but as we add more endpoints (e.g., for model info, batch predictions, different model versions), this single file becomes unwieldy. Let's structure the project using APIRouter.
Create a Project Structure: Organize your files like this:
your_project/
├── app/
│ ├── __init__.py
│ ├── main.py # Main application setup
│ ├── routers/
│ │ ├── __init__.py
│ │ └── predictions.py # Prediction-related routes
│ ├── models/
│ │ ├── __init__.py
│ │ └── schemas.py # Pydantic models
│ └── core/
│ ├── __init__.py
│ └── config.py # Configuration (optional for now)
├── tests/
│ ├── __init__.py
│ └── test_predictions.py # Tests for prediction routes
├── model.joblib # Your serialized model
└── requirements.txt # Project dependencies
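The optional app/core/config.py can start as a tiny module that reads settings from the environment. A minimal sketch, where the MODEL_PATH name and its default are illustrative choices, not requirements:

```python
# app/core/config.py (illustrative sketch)
import os

# Read the model path from the environment, falling back to a local default.
MODEL_PATH = os.getenv("MODEL_PATH", "model.joblib")
```

Other modules could then import the setting (e.g., `from app.core.config import MODEL_PATH`) instead of hard-coding the path.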
Define Pydantic Models: Move the Pydantic models to app/models/schemas.py:
# app/models/schemas.py
from pydantic import BaseModel

class InputFeatures(BaseModel):
    feature1: float
    feature2: float
    # ... other features

class PredictionOutput(BaseModel):
    prediction: float  # Or appropriate type
Create the Prediction Router: Move the prediction logic into app/routers/predictions.py. Notice we import APIRouter and apply the route decorator via router instead of app. We also adjust the import paths.
# app/routers/predictions.py
from fastapi import APIRouter
import joblib

from app.models.schemas import InputFeatures, PredictionOutput

# Assume model path is configured or known
MODEL_PATH = "model.joblib"
model = joblib.load(MODEL_PATH)

router = APIRouter(
    prefix="/predict",    # All routes in this router will start with /predict
    tags=["predictions"]  # Group endpoints in API docs
)

@router.post("/", response_model=PredictionOutput)  # Path is now relative to prefix
async def make_prediction(input_data: InputFeatures):
    """
    Accepts input features and returns a prediction.
    (Logic remains the same as before)
    """
    features = [[input_data.feature1, input_data.feature2]]
    prediction_result = model.predict(features)
    return PredictionOutput(prediction=prediction_result[0])

# You could add other prediction-related endpoints here later,
# e.g., @router.post("/batch", ...)
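If you do add a batch endpoint later, it helps to isolate the feature-matrix construction in a plain helper that is easy to unit test. A sketch of the idea, using a dataclass as a simplified stand-in for the Pydantic model; the `to_feature_matrix` name is illustrative:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class InputFeatures:
    feature1: float
    feature2: float

def to_feature_matrix(batch: List[InputFeatures]) -> List[List[float]]:
    """Build one row per input, preserving feature order."""
    return [[item.feature1, item.feature2] for item in batch]

rows = to_feature_matrix([InputFeatures(5.1, 3.5), InputFeatures(4.9, 3.0)])
# rows == [[5.1, 3.5], [4.9, 3.0]]
```

A batch route could then validate a list of inputs, call this helper once, and pass the whole matrix to `model.predict` in a single call rather than looping.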
Update the Main Application: Modify app/main.py to create the main FastAPI instance and include the router.
# app/main.py
from fastapi import FastAPI

from app.routers import predictions  # Import the router module

app = FastAPI(title="Refactored ML Prediction Service")

# Include the router from predictions.py
app.include_router(predictions.router)

@app.get("/")
async def read_root():
    return {"message": "Prediction service is running"}

# To run: uvicorn app.main:app --reload
Simplified diagram illustrating the refactored project structure and the interaction between main.py, the prediction router, and the Pydantic models.
Now, our prediction logic is neatly contained within app/routers/predictions.py, and main.py is cleaner, focusing on application setup and routing.
With the structure in place, let's write tests. We'll use pytest and FastAPI's TestClient.
Install Pytest: If you haven't already, install pytest. (Recent versions of FastAPI's TestClient also depend on the httpx package, so install it too if it isn't already present.)
pip install pytest
Create Test File: Create tests/test_predictions.py.
Write Tests:
# tests/test_predictions.py
from fastapi.testclient import TestClient

from app.main import app  # Import the FastAPI app instance
from app.models.schemas import InputFeatures  # Import for type hints if needed

# Create a TestClient instance using our FastAPI app
client = TestClient(app)

def test_read_root():
    """Test the root endpoint."""
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Prediction service is running"}

def test_make_prediction_success():
    """Test the prediction endpoint with valid input."""
    # Define valid input data matching the InputFeatures schema
    valid_input = {"feature1": 5.1, "feature2": 3.5}

    # Make a POST request to the /predict/ endpoint
    response = client.post("/predict/", json=valid_input)

    # Assert the request was successful (HTTP 200 OK)
    assert response.status_code == 200

    # Assert the response body structure matches PredictionOutput
    response_data = response.json()
    assert "prediction" in response_data
    # Optionally, assert the type of the prediction
    assert isinstance(response_data["prediction"], float)
    # Note: Asserting the exact prediction value depends on your model
    # and might require a fixed test dataset or mocking the model.
    # For simplicity here, we focus on structure and status.

def test_make_prediction_invalid_input_type():
    """Test the prediction endpoint with an incorrect input data type."""
    # Send data where a feature is a string instead of a float
    invalid_input = {"feature1": "wrong_type", "feature2": 3.5}
    response = client.post("/predict/", json=invalid_input)

    # FastAPI/Pydantic automatically handles validation errors
    # Expect HTTP 422 Unprocessable Entity
    assert response.status_code == 422

    # Check that the response body contains validation error details
    response_data = response.json()
    assert "detail" in response_data
    # You can add more specific checks on the error message if needed
    # e.g., assert "feature1" in str(response_data["detail"])

def test_make_prediction_missing_input_feature():
    """Test the prediction endpoint with missing input data."""
    # Send data missing 'feature2'
    missing_input = {"feature1": 5.1}
    response = client.post("/predict/", json=missing_input)

    # Expect HTTP 422 Unprocessable Entity
    assert response.status_code == 422
    response_data = response.json()
    assert "detail" in response_data
    # e.g., assert "feature2" in str(response_data["detail"])
    # e.g., assert "field required" in str(response_data["detail"])
Run Tests: Navigate to your project's root directory (your_project/) in the terminal and run pytest:
pytest
Pytest will discover and run the tests in the tests directory. You should see output indicating whether the tests passed or failed.
This hands-on exercise demonstrates how to apply the structuring principles using APIRouter to organize your prediction service and how to use TestClient to write effective tests. This approach significantly improves maintainability and ensures your API behaves as expected, even after refactoring or adding new features. As your application grows, this separation and testing become increasingly valuable.
© 2025 ApX Machine Learning