Testing endpoints that perform machine learning predictions requires strategies beyond standard API logic validation. While the previous section introduced `TestClient` for general endpoint testing, prediction endpoints involve external dependencies such as loaded models and potentially complex inference logic, which can make direct testing slow or impractical. This section focuses on techniques for effectively unit and integration testing your ML prediction endpoints.
The primary goal when testing prediction endpoints is usually not to re-validate the model's accuracy (that belongs to the ML training and evaluation phases) but to ensure the API surrounding the model works correctly. We want to verify:
- Valid input is accepted and parsed according to the request schema.
- The prediction logic (or a stand-in for it) is invoked and its result is returned with the expected structure and status code.
- Invalid or malformed input is rejected with an appropriate error response (e.g., 422).
- The response body conforms to the declared response model.
As with other endpoints, FastAPI's `TestClient` is the primary tool for testing prediction endpoints. You simulate HTTP requests (e.g., POST requests with prediction input) and assert the expected HTTP status codes and response bodies.
```python
# Example structure (assuming pytest and a TestClient fixture named 'client')
from fastapi import status
from pydantic import BaseModel

# Assume these are defined elsewhere in your application:
# from my_app.schemas import PredictionInput, PredictionOutput
# from my_app.main import app


class PredictionInput(BaseModel):
    feature1: float
    feature2: str


class PredictionOutput(BaseModel):
    prediction: float
    probability: float | None = None
```
```python
# In your test file (e.g., test_predictions.py)
def test_predict_endpoint_success(client):
    # Define valid input data based on the PredictionInput schema
    input_data = {"feature1": 10.5, "feature2": "categoryA"}
    response = client.post("/predict", json=input_data)
    assert response.status_code == status.HTTP_200_OK
    # Assuming the endpoint returns data matching PredictionOutput
    response_data = response.json()
    assert "prediction" in response_data
    # Further assertions based on the expected output structure...
    # For example, check the type:
    assert isinstance(response_data["prediction"], float)


def test_predict_endpoint_invalid_input(client):
    # Input data missing a required field
    invalid_input_data = {"feature1": 10.5}
    response = client.post("/predict", json=invalid_input_data)
    # FastAPI/Pydantic automatically handle validation errors
    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
```
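These tests assume a basic `client` fixture that wraps the application. Under that assumption, a minimal `conftest.py` might look like the sketch below (the import path `my_app.main` is illustrative):

```python
# conftest.py -- minimal sketch of the assumed 'client' fixture
import pytest
from fastapi.testclient import TestClient

from my_app.main import app  # adjust to your project layout


@pytest.fixture()
def client():
    # TestClient drives the ASGI app in-process, so no running server is needed
    return TestClient(app)
```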
The main challenge arises from the ML model itself. Loading and running a real model during unit tests can be slow, resource-intensive, and introduce external factors (like file paths) that complicate testing. We need ways to isolate the API logic from the actual model inference for faster, more reliable unit tests.
Mocking involves replacing the actual prediction function or model object with a substitute (a "mock") during the test run. This mock can be programmed to return predefined outputs, allowing you to test the API's behavior without executing the real model inference.
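Outside of FastAPI's own machinery, pytest's built-in `monkeypatch` fixture is one generic way to swap a module-level model object for a stand-in. The sketch below assumes a hypothetical `my_app.predictor.model` attribute holding the loaded model; the dependency-override approach described next is usually the cleaner option in FastAPI applications.

```python
# Generic mocking sketch using pytest's monkeypatch fixture.
# Assumption: the (hypothetical) attribute my_app.predictor.model holds the
# loaded model at module level -- adapt the target to your own code.
class FakeModel:
    def predict(self, features):
        return [0.42]  # fixed, known output


def test_predict_with_monkeypatched_model(client, monkeypatch):
    monkeypatch.setattr("my_app.predictor.model", FakeModel())
    response = client.post("/predict", json={"feature1": 1.0, "feature2": "a"})
    assert response.status_code == 200
```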
FastAPI's dependency injection system provides an elegant way to achieve this using `app.dependency_overrides`: you can override the dependency that provides the model, or the prediction function itself.
Let's assume your prediction endpoint uses a dependency to get the prediction result:
```python
# In your application (e.g., my_app/predictor.py)
from fastapi import Depends

from .schemas import PredictionInput


def get_model():
    # Logic to load your actual ML model
    # ...
    # return loaded_model
    pass  # Placeholder


def perform_prediction(data: PredictionInput, model=Depends(get_model)):
    # Actual prediction logic using the loaded model, for example:
    # prediction_result = model.predict(processed_data)
    # return {"prediction": prediction_result, "probability": 0.95}
    # For demonstration, return a fixed structure:
    return {"prediction": data.feature1 * 2, "probability": 0.9}
```
```python
# In your main FastAPI file (e.g., my_app/main.py)
from fastapi import Depends, FastAPI

from .predictor import perform_prediction
from .schemas import PredictionInput, PredictionOutput

app = FastAPI()


@app.post("/predict", response_model=PredictionOutput)
async def predict(data: PredictionInput, result: dict = Depends(perform_prediction)):
    # Dependency injection calls perform_prediction for us.
    # Note: perform_prediction returns a dict directly for simplicity here;
    # a more robust approach might involve a class-based dependency.
    return result
```
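The comment above mentions a class-based dependency as a more robust alternative. A minimal sketch of that idea follows; the `Predictor` class and its wiring are illustrative assumptions rather than part of the example application:

```python
# Sketch of a class-based dependency (illustrative; adapt to your app).
# A single instance holds the loaded model, and tests can override it with
# app.dependency_overrides[predictor] = ...
class Predictor:
    def __init__(self) -> None:
        # self.model = load_model()  # load the real model once at startup
        self.model = None  # placeholder for this sketch

    def __call__(self, data: PredictionInput) -> dict:
        # Real code would run self.model.predict(...) here
        return {"prediction": data.feature1 * 2, "probability": 0.9}


predictor = Predictor()

# The endpoint would then declare:
#     result: dict = Depends(predictor)
```

Keeping the model on a long-lived instance also means it is loaded once rather than on every request.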
Now, in your test file, you can override the `perform_prediction` dependency:
```python
# In your test file (e.g., test_predictions.py)
from fastapi.testclient import TestClient

from my_app.main import app  # Import your FastAPI app instance
from my_app.predictor import perform_prediction  # Import the original dependency
from my_app.schemas import PredictionInput


# Create a mock prediction function for testing
async def mock_perform_prediction(data: PredictionInput):
    # Simulate prediction logic: return a fixed, known result.
    # You can also add assertions about 'data' here if needed.
    print(f"Mock prediction called with: {data}")
    return {"prediction": 123.45, "probability": 0.88}


# Use FastAPI's dependency override feature
app.dependency_overrides[perform_prediction] = mock_perform_prediction

client = TestClient(app)  # Module-level client used by the test below


def test_predict_with_mock():
    input_data = {"feature1": 10.5, "feature2": "categoryA"}
    response = client.post("/predict", json=input_data)
    assert response.status_code == 200
    response_data = response.json()
    # Assert against the output defined in the mock function
    assert response_data["prediction"] == 123.45
    assert response_data["probability"] == 0.88


# Remember to clear the override if other tests need the original dependency.
# This is often handled better with pytest fixtures (see below).
def teardown_function():  # Example using pytest's module-level teardown hook
    app.dependency_overrides.clear()
```
Using pytest Fixtures for Cleaner Overrides:

pytest fixtures provide a cleaner way to manage setup and teardown, including dependency overrides:
```python
# In conftest.py or your test file
import pytest
from fastapi.testclient import TestClient

from my_app.main import app
from my_app.predictor import perform_prediction
from my_app.schemas import PredictionInput


@pytest.fixture(scope="function")  # Scope can be adjusted
def client_with_mock_predictor():
    # Define the mock function inside the fixture
    async def mock_perform_prediction_fixture(data: PredictionInput):
        return {"prediction": 123.45, "probability": 0.88}

    # Apply the override
    app.dependency_overrides[perform_prediction] = mock_perform_prediction_fixture

    # Yield the TestClient
    yield TestClient(app)

    # Teardown: clear the override after the test using this fixture finishes
    app.dependency_overrides.clear()
```
```python
# In your test file (e.g., test_predictions.py)
def test_predict_with_fixture(client_with_mock_predictor):  # Use the fixture
    input_data = {"feature1": 10.5, "feature2": "categoryA"}
    # Use the client provided by the fixture
    response = client_with_mock_predictor.post("/predict", json=input_data)
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["prediction"] == 123.45
    assert response_data["probability"] == 0.88
```
Pros of Mocking:
- Tests run fast because no real model is loaded or executed.
- Outputs are deterministic, so assertions are simple and stable.
- No model artifacts (weights, files) are required in the test environment.
Cons of Mocking:
- The integration between the API and the real model code is never exercised.
- If the real prediction function's interface or output format changes, the mock can silently drift out of sync with it.
An alternative to mocking is to use a very simple, fast "dummy" model during tests. This dummy model should mimic the interface of your real model (e.g., a `predict` method) but perform a trivial operation.

You can achieve this using dependency injection, similar to mocking, but instead of overriding the prediction function you override the dependency that loads the model (`get_model` in our earlier example).
```python
# In your test setup (e.g., conftest.py or test file)
class DummyModel:
    """A simple model stand-in for testing."""

    def predict(self, input_data):
        # Simple logic, e.g., a value derived from the input
        # (assumes input_data is a dict of numeric features)
        print("DummyModel predict called")
        return [sum(input_data.values())]  # Example dummy prediction

    def predict_proba(self, input_data):
        # Dummy probabilities
        return [[0.1, 0.9]]  # Example


def get_dummy_model():
    print("Providing DummyModel")
    return DummyModel()
```
```python
# In your pytest fixture or test setup
@pytest.fixture(scope="function")
def client_with_dummy_model():
    # get_model is the dependency used to load the real model, so override it:
    # from my_app.predictor import get_model
    # app.dependency_overrides[get_model] = get_dummy_model
    #
    # Note: this only helps if the prediction code actually receives the model
    # through that dependency, e.g.:
    #     async def perform_prediction(data: PredictionInput, model=Depends(get_model)): ...
    # Adapting this to your actual app structure is required.
    yield TestClient(app)  # Assuming the override above has been applied
    app.dependency_overrides.clear()  # Cleanup
```
```python
# In your test file
# def test_predict_with_dummy_model(client_with_dummy_model):
#     # ... perform the test using the client ...
#     # Assertions will depend on the dummy model's behavior
```
Note: The exact implementation depends heavily on how your model is loaded and accessed within your endpoint logic. Ensure your dependencies are structured to allow overriding the model provider.
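For reference, a structure that supports this kind of override routes all model access through `get_model`. The sketch below shows how the earlier `perform_prediction` could be adjusted; the feature handling is an assumption for illustration only:

```python
# my_app/predictor.py -- sketch of a structure that allows overriding get_model
from fastapi import Depends

from .schemas import PredictionInput


def get_model():
    # Load and return the real model here (cache it in practice)
    ...


def perform_prediction(data: PredictionInput, model=Depends(get_model)):
    # All model access goes through the injected 'model', so a test can swap in
    # DummyModel simply by overriding get_model.
    features = {"feature1": data.feature1}  # assumption: model takes a feature dict
    prediction = float(model.predict(features)[0])
    return {"prediction": prediction, "probability": None}
```

With this structure in place, the commented-out line in the fixture above, `app.dependency_overrides[get_model] = get_dummy_model`, is all that is needed to run the endpoint against `DummyModel`.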
Pros of Dummy Model:
- Exercises more of the real code path than a mock: the surrounding prediction logic runs against an object with the model's interface.
- Still fast and free of heavy model artifacts.
Cons of Dummy Model:
- Requires writing and maintaining a stand-in that mirrors the real model's interface.
- Its outputs are still artificial, so it cannot catch problems specific to the real model's behavior.
Regardless of the strategy (mocking or dummy model), structure your tests to cover various scenarios (a parametrized sketch follows this list):
- Valid input: the endpoint returns a success status code and a body that conforms to the declared `response_model`.
- Invalid input: missing or wrongly typed fields produce a 422 validation error.
- Edge cases: boundary or unusual but valid feature values are handled sensibly.
- Error handling: failures raised during prediction result in an appropriate error response rather than an unhandled exception.
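As a sketch of how several of these scenarios can be covered compactly, the parametrized test below reuses the mocked client fixture from earlier; the specific payloads and expected status codes are illustrative:

```python
# Sketch: covering several input scenarios with one parametrized test
# (assumes the client_with_mock_predictor fixture defined earlier).
import pytest


@pytest.mark.parametrize(
    "payload, expected_status",
    [
        ({"feature1": 10.5, "feature2": "categoryA"}, 200),    # valid input
        ({"feature1": 10.5}, 422),                              # missing field
        ({"feature1": "not-a-number", "feature2": "x"}, 422),   # wrong type
    ],
)
def test_predict_scenarios(client_with_mock_predictor, payload, expected_status):
    response = client_with_mock_predictor.post("/predict", json=payload)
    assert response.status_code == expected_status
```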
Testing prediction endpoints effectively involves isolating the API logic from the complexities of the ML model itself during unit tests. Using `TestClient` combined with dependency overrides to inject mocks or dummy models provides a robust way to ensure your API behaves correctly, handles inputs and outputs as expected, and integrates properly with the prediction mechanism, paving the way for more reliable ML deployments. Remember to complement these unit and integration tests with separate, dedicated tests for your model's performance and accuracy within your ML workflow.