Once you've trained a machine learning model, it exists only in your computer's memory during that active session. To use this trained model later, especially within a separate application like a web API, you need a way to save its learned state and then load it back when needed. This process of converting the in-memory model object into a format that can be stored on disk or transmitted is called serialization, and the reverse process of reconstructing the object from the stored format is called deserialization.
Think of it like saving a document you're working on. You save it to a file (serialize) so you can close the application and open the exact same document later (deserialize) without losing your work. For ML models, "saving the work" means preserving the learned parameters, structure, and any other necessary information captured during training.
Without serialization, you would need to retrain your model every single time your API server starts, which is computationally expensive and impractical for most real-world applications. Serialization allows you to train a model once, save it as a file (often called a model artifact), and then simply load this file whenever your FastAPI application needs to make predictions.
Python offers several ways to serialize objects. For machine learning models, two libraries are particularly common:

- `pickle`: Python's built-in module for object serialization. It can serialize almost any Python object, including complex objects like trained scikit-learn models.
- `joblib`: A library that provides utilities for pipelining Python jobs. It includes replacements for `pickle` (`joblib.dump` and `joblib.load`) that are often more efficient for objects containing large NumPy arrays, which are very common in machine learning models, particularly those from scikit-learn.

### pickle
Let's see how you might save a simple scikit-learn model using `pickle`. Assume you have a trained model object named `model`:
```python
import pickle

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# --- Training a Dummy Model (Illustrative) ---
X, y = make_classification(n_samples=100, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
model = LogisticRegression()
model.fit(X_train, y_train)
# --- Model is now 'trained' ---

# Define the filename for the saved model
model_filename = 'logistic_regression_model.pkl'

# Serialize (save) the model to a file
# 'wb' mode opens the file for writing in binary mode
with open(model_filename, 'wb') as file:
    pickle.dump(model, file)

print(f"Model saved to {model_filename}")

# --- Later, in a different script or session ---

# Deserialize (load) the model from the file
# 'rb' mode opens the file for reading in binary mode
try:
    with open(model_filename, 'rb') as file:
        loaded_model = pickle.load(file)
    print(f"Model loaded from {model_filename}")
    # Now you can use loaded_model to make predictions
    # Example: predictions = loaded_model.predict(X_test)
except FileNotFoundError:
    print(f"Error: Model file '{model_filename}' not found.")
except Exception as e:
    print(f"Error loading model: {e}")
```
**Security Note:** `pickle` is powerful but has a significant security implication. Deserializing a pickle file involves executing code embedded within the file. Never load a pickle file from an untrusted or unauthenticated source, as it could contain malicious code. `joblib` shares similar security considerations.
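To make the risk concrete, here is a minimal sketch showing that unpickling can invoke arbitrary callables. The `NotReallyAModel` class is a hypothetical example; the payload here merely prints a message, but an attacker could substitute any function, such as one that runs shell commands:

```python
import pickle

# pickle calls whatever __reduce__ specifies during deserialization,
# so loading untrusted data can execute attacker-controlled code.
class NotReallyAModel:
    def __reduce__(self):
        # Harmless here, but this could be os.system or anything else.
        return (print, ("This code ran just by calling pickle.loads!",))

payload = pickle.dumps(NotReallyAModel())
pickle.loads(payload)  # Prints the message: code executed on load
```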
### joblib

`joblib` is often preferred for scikit-learn models because it can handle large NumPy arrays more efficiently, potentially resulting in smaller file sizes and faster load times compared to `pickle`. The interface is very similar:
```python
import joblib

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# --- Training a Dummy Model (Illustrative) ---
X, y = make_classification(n_samples=100, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
model = LogisticRegression()
model.fit(X_train, y_train)
# --- Model is now 'trained' ---

# Define the filename for the saved model
model_filename_joblib = 'logistic_regression_model.joblib'

# Serialize (save) the model using joblib
joblib.dump(model, model_filename_joblib)
print(f"Model saved to {model_filename_joblib}")

# --- Later, in a different script or session ---

# Deserialize (load) the model using joblib
try:
    loaded_model_joblib = joblib.load(model_filename_joblib)
    print(f"Model loaded from {model_filename_joblib}")
    # Now you can use loaded_model_joblib to make predictions
    # Example: predictions = loaded_model_joblib.predict(X_test)
except FileNotFoundError:
    print(f"Error: Model file '{model_filename_joblib}' not found.")
except Exception as e:
    print(f"Error loading model: {e}")
```
For most scikit-learn use cases, `joblib` is the recommended choice.
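One practical feature worth knowing: `joblib.dump` accepts a `compress` argument that can shrink artifacts containing large arrays. A minimal sketch, assuming the trained `model` from the example above:

```python
import joblib

# Compression level 0-9; 3 is a common balance of file size vs. speed.
# (Assumes 'model' is the trained estimator from the previous example.)
joblib.dump(model, 'logistic_regression_model_compressed.joblib', compress=3)

# Loading is unchanged; joblib detects the compression automatically.
loaded = joblib.load('logistic_regression_model_compressed.joblib')
```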
While `pickle` and `joblib` are general Python serialization tools, many machine learning libraries provide their own dedicated functions or formats for saving and loading models.

- **TensorFlow/Keras**: Models are commonly saved as `.h5` (HDF5) files via `model.save()` and restored with `tf.keras.models.load_model()`. These formats store not just the model weights but also the model architecture and training configuration.
- **PyTorch**: Models typically use `.pt` or `.pth` file extensions. PyTorch allows saving the entire model (`torch.save(model, PATH)`) or just the learned parameters (the state dictionary, `torch.save(model.state_dict(), PATH)`), which is often preferred for flexibility. Loading is done via `torch.load()`.

When using these frameworks, it's generally best practice to use their native saving/loading mechanisms, as they are optimized for the specific structures and requirements of those libraries and often handle compatibility across versions more gracefully.
Keep in mind that a model rarely stands alone: input data usually needs the same preprocessing (scaling, encoding, and so on) at prediction time as during training. A common approach is to serialize a scikit-learn `Pipeline` object that includes both the preprocessing steps and the model, ensuring the exact same transformations are applied during prediction as during training. Alternatively, preprocessing logic might be reimplemented directly within your FastAPI application code.
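Here is a minimal sketch of the `Pipeline` approach, assuming an illustrative scaler-plus-classifier setup:

```python
import joblib
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100, n_features=10, random_state=42)

# Bundle preprocessing and the model into one object
pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("classifier", LogisticRegression()),
])
pipeline.fit(X, y)

# A single artifact now captures the scaler's learned statistics
# and the classifier's parameters together
joblib.dump(pipeline, "prediction_pipeline.joblib")

# Loading restores the full chain: raw inputs are scaled automatically
# before being passed to the classifier
loaded_pipeline = joblib.load("prediction_pipeline.joblib")
predictions = loaded_pipeline.predict(X)
```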
Successfully serializing your model is the first essential step. The next step, covered in the following section, is to load this serialized artifact into your running FastAPI application so it's ready to serve predictions through your API endpoints.