Why FastAPI?
FastAPI has become the go-to framework for deploying ML models as APIs. It offers high performance (on par with NodeJS and Go), automatic API documentation, and built-in data validation.
- Fast: One of the fastest Python frameworks available
- Modern: Uses Python type hints for validation
- Auto-docs: Swagger UI and ReDoc out of the box
- Async: Native async/await support
Getting Started
# Install FastAPI
pip install fastapi uvicorn
# main.py
from fastapi import FastAPI
# Create the application; this metadata feeds the auto-generated docs
# (Swagger UI at /docs — ReDoc is generated as well).
app = FastAPI(
    title="ML Model API",
    description="API for ML predictions",
    version="1.0.0"
)
@app.get("/")
def read_root():
    """Root endpoint: returns a simple welcome payload."""
    payload = {"message": "Welcome to ML API"}
    return payload
@app.get("/health")
def health_check():
    """Liveness probe (e.g. for container healthchecks and load balancers)."""
    status = {"status": "healthy"}
    return status
# Run with: uvicorn main:app --reload
# Docs at: http://localhost:8000/docs
Request/Response Models with Pydantic
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List, Optional
app = FastAPI()  # bare app — see the first example for title/description metadata
# Request model
class PredictionRequest(BaseModel):
    """Input payload for the /predict endpoint."""

    features: List[float] = Field(..., min_length=1, description="Input features")
    model_version: Optional[str] = "v1"  # defaults to "v1"

    # Pydantic v2 configuration (requirements pin pydantic==2.5.2): the v1-style
    # `class Config: schema_extra` is silently ignored by v2, and `min_items`
    # is deprecated there in favor of `min_length` (used above).
    model_config = {
        # `model_version` starts with the `model_` prefix that Pydantic v2
        # reserves by default; clear protected namespaces to silence the warning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "features": [5.1, 3.5, 1.4, 0.2],
                "model_version": "v1",
            }
        },
    }
# Response model
class PredictionResponse(BaseModel):
    """Payload returned by /predict."""

    prediction: int    # predicted class index
    probability: float # confidence score for the predicted class
    class_name: str    # human-readable label for `prediction`
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Validate the request, run inference, and return a typed response."""
    # Your model inference here — these are placeholder values.
    response = PredictionResponse(
        prediction=0,
        probability=0.95,
        class_name="setosa",
    )
    return response
Deploying a Scikit-learn Model
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
from typing import List
app = FastAPI()

# Load model at startup — `model` stays None until the startup hook below runs.
model = None
@app.on_event("startup")
def load_model():
    """Load the serialized sklearn model from disk once, at server startup.

    NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    favor of lifespan handlers — confirm against the pinned version before
    migrating.
    """
    global model  # rebinds the module-level `model` placeholder
    model = joblib.load("model.joblib")
    print("Model loaded successfully")
class Features(BaseModel):
    """Request body: a batch of samples, one inner list of feature values per row."""

    data: List[List[float]]  # (n_samples, n_features); converted to np.array by the endpoint
class Prediction(BaseModel):
    """Response body for batch prediction."""

    predictions: List[int]            # one predicted label per input row
    probabilities: List[List[float]]  # per-row class probabilities (predict_proba)
@app.post("/predict", response_model=Prediction)
def predict(features: Features):
    """Run batch prediction on the loaded sklearn model.

    Returns labels and per-class probabilities for every input row.
    """
    if model is None:
        # Startup hook has not run (or failed) — report service-unavailable
        # instead of the generic 500 an AttributeError would produce below.
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        X = np.array(features.data)
        predictions = model.predict(X).tolist()
        probabilities = model.predict_proba(X).tolist()
        return Prediction(
            predictions=predictions,
            probabilities=probabilities
        )
    except Exception as e:
        # Bad shapes/dtypes land here; echoing str(e) is fine for a demo but
        # avoid leaking internals in production.
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/model-info")
def model_info():
    """Expose basic metadata about the loaded sklearn model."""
    if model is None:
        # Nothing to report before the startup hook completes (or if it failed).
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {
        "model_type": type(model).__name__,
        "n_features": model.n_features_in_,
        "classes": model.classes_.tolist()
    }
Deploying a PyTorch Model
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import torch
import torchvision.transforms as transforms
import io
app = FastAPI()

# Load model
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. Loading a full model (rather than a state_dict) also requires
# the original class definition to be importable; confirm this matches how
# the checkpoint was saved.
model = torch.load("model.pth", map_location="cpu")  # map_location forces CPU inference
model.eval()  # disable dropout / batch-norm updates for inference

# Image preprocessing
# 224x224 resize plus the standard ImageNet normalization constants —
# TODO confirm these match what this model was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Class labels, index-aligned with the model's output logits.
class_names = ["cat", "dog", "bird"]
@app.post("/predict/image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image; returns the top label plus all class scores."""
    # Read and preprocess image
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add the batch dimension

    # Inference
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    top_idx = torch.argmax(probs, dim=1).item()

    return {
        "prediction": class_names[top_idx],
        "confidence": probs[0][top_idx].item(),
        "all_probabilities": dict(
            zip(class_names, (p.item() for p in probs[0]))
        ),
    }
Async for Better Performance
from fastapi import FastAPI
import asyncio
import httpx
app = FastAPI()

# Async endpoint for I/O-bound operations
@app.get("/async-predict")
async def async_predict(data: str):
    """Demo async endpoint: awaits a simulated non-blocking model call."""
    await asyncio.sleep(0.1)  # stand-in for an async model call or external API
    response = {"result": "prediction"}
    return response
# Background tasks
from fastapi import BackgroundTasks
def log_prediction(prediction: dict):
    """Append one prediction record to the local log file.

    Intended to run as a FastAPI background task so the response
    is not blocked on disk I/O.
    """
    record = f"{prediction}\n"
    with open("predictions.log", "a") as log_file:
        log_file.write(record)
@app.post("/predict-with-logging")
async def predict_with_logging(
    features: Features,
    background_tasks: BackgroundTasks
):
    """Predict, then log the request/response pair without blocking the reply."""
    # .tolist(): sklearn returns a numpy array, which FastAPI's default JSON
    # encoder cannot serialize.
    prediction = model.predict(features.data).tolist()
    # Log in background (non-blocking)
    background_tasks.add_task(log_prediction, {"input": features.data, "output": prediction})
    return {"prediction": prediction}
Error Handling
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
app = FastAPI()

# Custom exception
class ModelNotLoadedError(Exception):
    """Raised by endpoints when the ML model has not been loaded yet."""
    pass
@app.exception_handler(ModelNotLoadedError)
async def model_not_loaded_handler(request: Request, exc: ModelNotLoadedError):
    """Translate ModelNotLoadedError into a 503 Service Unavailable response."""
    payload = {"error": "Model is not loaded", "detail": str(exc)}
    return JSONResponse(status_code=503, content=payload)
# Validation error handling
from fastapi.exceptions import RequestValidationError
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request-validation failures as a structured 422 body."""
    body = {"error": "Validation Error", "details": exc.errors()}
    return JSONResponse(status_code=422, content=body)
@app.post("/predict")
def predict(features: Features):
    """Prediction endpoint demonstrating layered error handling."""
    if model is None:
        # Routed to the ModelNotLoadedError handler registered above -> 503.
        raise ModelNotLoadedError("Model failed to load at startup")
    try:
        # .tolist(): the numpy array returned by sklearn is not
        # JSON-serializable by FastAPI's default encoder.
        result = model.predict(features.data).tolist()
        return {"prediction": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
API Security
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, APIKeyHeader
import jwt
app = FastAPI()

# API Key authentication
api_key_header = APIKeyHeader(name="X-API-Key")

async def verify_api_key(api_key: str = Security(api_key_header)):
    """Dependency: allow the request only when X-API-Key matches the expected key.

    NOTE(review): the key is hard-coded for the demo — load it from
    configuration/env in production, and consider secrets.compare_digest
    instead of != to avoid timing side channels.
    """
    if api_key != "your-secret-key":
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
@app.post("/predict", dependencies=[Depends(verify_api_key)])
def predict(features: Features):
    """Prediction endpoint gated by the API-key dependency."""
    # .tolist(): numpy arrays are not JSON-serializable by the default encoder.
    return {"prediction": model.predict(features.data).tolist()}
# JWT authentication
security = HTTPBearer()

def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Dependency: decode the bearer JWT and return its claims on success."""
    token = credentials.credentials
    try:
        # NOTE: the signing key is a demo placeholder — load it from config.
        return jwt.decode(token, "secret", algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
@app.get("/protected", dependencies=[Depends(verify_token)])
def protected_route():
    """Reachable only with a valid bearer token (see verify_token)."""
    payload = {"message": "Access granted"}
    return payload
Rate Limiting
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# One shared limiter, keyed on the client's remote IP address.
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
# NOTE(review): slowapi expects the limiter to be attached to app.state —
# confirm against the slowapi docs for the pinned version.
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Return a 429 with the limit description when a client exceeds its quota.

    Requires `from fastapi.responses import JSONResponse` in this module —
    the snippet previously used the name without importing it.
    """
    # NOTE(review): exc.detail is slowapi's human-readable limit string
    # (e.g. "10 per 1 minute"), not a seconds value — convert it if clients
    # treat "retry_after" as a delay.
    return JSONResponse(
        status_code=429,
        content={"error": "Rate limit exceeded", "retry_after": exc.detail}
    )
@app.post("/predict")
@limiter.limit("10/minute")  # 10 requests per minute, per client IP
async def predict(request: Request, features: Features):
    """Rate-limited prediction endpoint; slowapi requires the `request` param."""
    # .tolist(): numpy arrays are not JSON-serializable by the default encoder.
    return {"prediction": model.predict(features.data).tolist()}
Production Deployment
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Copy and install dependencies first so Docker layer-caches them
# across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Run with Gunicorn + Uvicorn workers
# 4 worker processes; each runs the ASGI app via uvicorn's worker class.
CMD ["gunicorn", "main:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
# docker-compose.yml
version: '3.8'  # NOTE(review): top-level `version` is obsolete in Compose v2 — confirm target
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/app/models/model.joblib
    volumes:
      - ./models:/app/models  # mount models from the host so the image stays generic
    healthcheck:
      # python:3.11-slim does not ship curl; probe /health with the stdlib instead.
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
# requirements.txt
fastapi==0.104.1
uvicorn[standard]==0.24.0
gunicorn==21.2.0
pydantic==2.5.2
joblib==1.3.2
numpy==1.26.2
scikit-learn==1.3.2
# NOTE(review): this list covers only the sklearn deployment. The other
# sections additionally need torch/torchvision/Pillow (image example),
# PyJWT (auth), slowapi (rate limiting), httpx, and python-multipart
# (required by FastAPI for UploadFile/form parsing).
Deploy ML Models Like a Pro
Our Data Science program covers MLOps and deployment, including FastAPI, Docker, and cloud deployment. Build production-ready ML systems.
Explore Data Science Program