Why FastAPI?

FastAPI has become the go-to framework for deploying ML models as APIs. It offers high performance (on par with NodeJS and Go), automatic API documentation, and built-in data validation.

  • Fast: One of the fastest Python frameworks available
  • Modern: Uses Python type hints for validation
  • Auto-docs: Swagger UI and ReDoc out of the box
  • Async: Native async/await support

Getting Started

# Install FastAPI
pip install fastapi uvicorn

# main.py
from fastapi import FastAPI

app = FastAPI(
    title="ML Model API",
    description="API for ML predictions",
    version="1.0.0"
)

@app.get("/")
def read_root():
    return {"message": "Welcome to ML API"}

@app.get("/health")
def health_check():
    return {"status": "healthy"}

# Run with: uvicorn main:app --reload
# Docs at: http://localhost:8000/docs
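
Before wiring in a real model, you can smoke-test these routes in-process. A minimal sketch using FastAPI's TestClient (it needs httpx installed, and pytest to run it as a test):

# test_main.py
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

def test_root_and_health():
    assert client.get("/").json() == {"message": "Welcome to ML API"}
    assert client.get("/health").status_code == 200

# Run with: pytest test_main.py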

Request/Response Models with Pydantic

from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict, Field
from typing import List, Optional

app = FastAPI()

# Request model
class PredictionRequest(BaseModel):
    # Pydantic v2 config: json_schema_extra replaces the old Config.schema_extra,
    # and protected_namespaces=() permits field names starting with "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "features": [5.1, 3.5, 1.4, 0.2],
                "model_version": "v1"
            }
        },
    )

    features: List[float] = Field(..., min_length=1, description="Input features")
    model_version: Optional[str] = "v1"

# Response model
class PredictionResponse(BaseModel):
    prediction: int
    probability: float
    class_name: str

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    # Your model inference here
    prediction = 0
    probability = 0.95
    class_name = "setosa"

    return PredictionResponse(
        prediction=prediction,
        probability=probability,
        class_name=class_name
    )
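
Pydantic enforces these schemas at the boundary: a request with a missing field, a wrong type, or an empty features list is rejected with a 422 before your handler runs. A quick sketch of that behavior with TestClient, assuming the code above lives in main.py:

from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

# Well-formed request reaches the handler
ok = client.post("/predict", json={"features": [5.1, 3.5, 1.4, 0.2]})
print(ok.status_code)        # 200

# Empty list violates min_length=1, so validation fails first
bad = client.post("/predict", json={"features": []})
print(bad.status_code)       # 422
print(bad.json()["detail"])  # which field failed and why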

Deploying a Scikit-learn Model

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
from typing import List

model = None

# Load the model once at startup; the lifespan handler replaces the
# deprecated @app.on_event("startup") hook in recent FastAPI versions
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    model = joblib.load("model.joblib")
    print("Model loaded successfully")
    yield

app = FastAPI(lifespan=lifespan)

class Features(BaseModel):
    data: List[List[float]]

class Prediction(BaseModel):
    predictions: List[int]
    probabilities: List[List[float]]

@app.post("/predict", response_model=Prediction)
def predict(features: Features):
    try:
        X = np.array(features.data)
        predictions = model.predict(X).tolist()
        probabilities = model.predict_proba(X).tolist()

        return Prediction(
            predictions=predictions,
            probabilities=probabilities
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/model-info")
def model_info():
    return {
        "model_type": type(model).__name__,
        "n_features": model.n_features_in_,
        "classes": model.classes_.tolist()
    }
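
The example assumes a model.joblib file already exists. A sketch of producing one with an iris classifier (any estimator exposing predict and predict_proba will work with the endpoint above):

# train.py - creates the model.joblib the API loads at startup
import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)

joblib.dump(model, "model.joblib")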

Deploying a PyTorch Model

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import torch
import torchvision.transforms as transforms
import io

app = FastAPI()

# Load model (assumes the full model object was saved with
# torch.save(model, "model.pth"); if you saved a state_dict instead,
# instantiate the architecture first and call load_state_dict)
model = torch.load("model.pth", map_location="cpu")
model.eval()

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

class_names = ["cat", "dog", "bird"]

@app.post("/predict/image")
async def predict_image(file: UploadFile = File(...)):
    # Read and preprocess image
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    input_tensor = transform(image).unsqueeze(0)

    # Inference
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_idx = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_idx].item()

    return {
        "prediction": class_names[predicted_idx],
        "confidence": confidence,
        "all_probabilities": {
            name: prob.item()
            for name, prob in zip(class_names, probabilities[0])
        }
    }
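
Calling the endpoint from Python looks like this (a sketch assuming the server is running locally and a cat.jpg file exists on disk):

import httpx

with open("cat.jpg", "rb") as f:
    response = httpx.post(
        "http://localhost:8000/predict/image",
        files={"file": ("cat.jpg", f, "image/jpeg")},
    )
print(response.json())  # e.g. {"prediction": "cat", "confidence": 0.97, ...}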

Async for Better Performance

from fastapi import FastAPI
import asyncio

app = FastAPI()

# Async endpoint for I/O-bound operations
@app.get("/async-predict")
async def async_predict(data: str):
    # Simulate async model call or external API
    await asyncio.sleep(0.1)
    return {"result": "prediction"}

# Background tasks
from fastapi import BackgroundTasks

def log_prediction(prediction: dict):
    # Log to database or file
    with open("predictions.log", "a") as f:
        f.write(f"{prediction}\n")

# Features and model as defined in the scikit-learn example above
@app.post("/predict-with-logging")
async def predict_with_logging(
    features: Features,
    background_tasks: BackgroundTasks
):
    prediction = model.predict(features.data).tolist()

    # Runs after the response has been sent, so logging never delays the client
    background_tasks.add_task(log_prediction, {"input": features.data, "output": prediction})

    return {"prediction": prediction}
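
One caveat: model.predict is CPU-bound, so calling it directly inside an async def handler blocks the event loop for every other request. Either declare the endpoint with plain def (FastAPI then runs it in a threadpool for you) or offload the call explicitly. A sketch of the explicit version, reusing the same Features model:

from fastapi.concurrency import run_in_threadpool

@app.post("/predict-offloaded")
async def predict_offloaded(features: Features):
    # Run the blocking inference in a worker thread so the event loop
    # stays free to handle other requests
    prediction = await run_in_threadpool(model.predict, features.data)
    return {"prediction": prediction.tolist()}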

Error Handling

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()

# Custom exception
class ModelNotLoadedError(Exception):
    pass

@app.exception_handler(ModelNotLoadedError)
async def model_not_loaded_handler(request: Request, exc: ModelNotLoadedError):
    return JSONResponse(
        status_code=503,
        content={"error": "Model is not loaded", "detail": str(exc)}
    )

# Validation error handling
from fastapi.exceptions import RequestValidationError

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    return JSONResponse(
        status_code=422,
        content={
            "error": "Validation Error",
            "details": exc.errors()
        }
    )

# Features and model as defined in the scikit-learn example above
@app.post("/predict")
def predict(features: Features):
    if model is None:
        raise ModelNotLoadedError("Model failed to load at startup")

    try:
        result = model.predict(features.data).tolist()
        return {"prediction": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")
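
Validation can also reject malformed input before it ever reaches the model. A sketch of a stricter request model (hypothetical, using Pydantic v2's field_validator; set N_FEATURES to your model's n_features_in_):

from pydantic import BaseModel, field_validator
from typing import List

N_FEATURES = 4  # e.g. match model.n_features_in_

class StrictFeatures(BaseModel):
    data: List[List[float]]

    @field_validator("data")
    @classmethod
    def check_feature_count(cls, v):
        for row in v:
            if len(row) != N_FEATURES:
                raise ValueError(f"each row needs exactly {N_FEATURES} features, got {len(row)}")
        return v

A request failing this check comes back as a 422 through the RequestValidationError handler above, so the model never sees bad shapes.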

API Security

from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, APIKeyHeader
import jwt  # PyJWT: pip install pyjwt

app = FastAPI()

# API Key authentication (in production, load the key from an env var
# or secret manager instead of hardcoding it)
api_key_header = APIKeyHeader(name="X-API-Key")

async def verify_api_key(api_key: str = Security(api_key_header)):
    if api_key != "your-secret-key":
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key

@app.post("/predict", dependencies=[Depends(verify_api_key)])
def predict(features: Features):
    return {"prediction": model.predict(features.data)}

# JWT authentication
security = HTTPBearer()

def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    try:
        payload = jwt.decode(credentials.credentials, "secret", algorithms=["HS256"])
        return payload
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")

@app.get("/protected", dependencies=[Depends(verify_token)])
def protected_route():
    return {"message": "Access granted"}
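
To exercise the protected route, you can mint a token with the same secret (a sketch using PyJWT, which provides the jwt module imported above):

import datetime
import jwt

token = jwt.encode(
    {
        "sub": "demo-user",
        "exp": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1),
    },
    "secret",  # must match the key used in jwt.decode above
    algorithm="HS256",
)
print(token)  # send as: Authorization: Bearer <token>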

Rate Limiting

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter

@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"error": "Rate limit exceeded", "limit": exc.detail}
    )

@app.post("/predict")
@limiter.limit("10/minute")  # 10 requests per minute
async def predict(request: Request, features: Features):
    return {"prediction": model.predict(features.data)}
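
With the API running, you can watch the limiter kick in (a sketch; with a 10/minute limit the 11th and 12th calls should return 429):

import httpx

for i in range(12):
    r = httpx.post(
        "http://localhost:8000/predict",
        json={"data": [[5.1, 3.5, 1.4, 0.2]]},
    )
    print(i + 1, r.status_code)  # 200 for the first ten, then 429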

Production Deployment

# Dockerfile
FROM python:3.11-slim

# curl is needed for the compose healthcheck below (slim images don't include it)
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Run with Gunicorn + Uvicorn workers
CMD ["gunicorn", "main:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]

# docker-compose.yml
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/app/models/model.joblib
    volumes:
      - ./models:/app/models
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3

# requirements.txt
fastapi==0.104.1
uvicorn[standard]==0.24.0
gunicorn==21.2.0
pydantic==2.5.2
joblib==1.3.2
numpy==1.26.2
scikit-learn==1.3.2
# plus, if you use those examples: slowapi, PyJWT, torch, torchvision, pillow

Deploy ML Models Like a Pro

Our Data Science program covers MLOps end to end, including FastAPI, Docker, and cloud deployment. Build production-ready ML systems.

Explore Data Science Program
