FastAPI Dependency Injection
FastAPI依赖注入
Master FastAPI's dependency injection system for building modular,
testable APIs with reusable dependencies.
掌握FastAPI的依赖注入系统,构建具备可复用依赖项的模块化、可测试API。
Simple dependency injection patterns in FastAPI.
python
from fastapi import Depends, FastAPI
app = FastAPI()
def get_db():
db = Database()
try:
yield db
finally:
db.close()
@app.get('/users')
async def get_users(db = Depends(get_db)):
return await db.query('SELECT * FROM users')
FastAPI中的简单依赖注入模式。
python
from fastapi import Depends, FastAPI
app = FastAPI()
def get_db():
db = Database()
try:
yield db
finally:
db.close()
@app.get('/users')
async def get_users(db = Depends(get_db)):
return await db.query('SELECT * FROM users')
Simple function dependency
Simple function dependency
def common_parameters(q: str = None, skip: int = 0, limit: int = 100):
return {'q': q, 'skip': skip, 'limit': limit}
@app.get('/items')
async def read_items(commons: dict = Depends(common_parameters)):
return commons
def common_parameters(q: str = None, skip: int = 0, limit: int = 100):
return {'q': q, 'skip': skip, 'limit': limit}
@app.get('/items')
async def read_items(commons: dict = Depends(common_parameters)):
return commons
Async dependencies
Async dependencies
async def get_async_db():
db = await AsyncDatabase.connect()
try:
yield db
finally:
await db.disconnect()
@app.get('/async-users')
async def get_async_users(db = Depends(get_async_db)):
return await db.fetch_all('SELECT * FROM users')
async def get_async_db():
db = await AsyncDatabase.connect()
try:
yield db
finally:
await db.disconnect()
@app.get('/async-users')
async def get_async_users(db = Depends(get_async_db)):
return await db.fetch_all('SELECT * FROM users')
Understand dependency lifecycle and caching.
python
from fastapi import Depends
了解依赖项的生命周期与缓存机制。
python
from fastapi import Depends
Request-scoped dependency (default, cached per request)
Request-scoped dependency (default, cached per request)
def get_current_user(token: str = Depends(oauth2_scheme)):
user = decode_token(token)
return user
def get_current_user(token: str = Depends(oauth2_scheme)):
user = decode_token(token)
return user
Multiple uses in same request reuse the same instance
Multiple uses in same request reuse the same instance
@app.get('/profile')
async def get_profile(
user1 = Depends(get_current_user),
user2 = Depends(get_current_user) # Same instance as user1
):
assert user1 is user2 # True
return user1
@app.get('/profile')
async def get_profile(
user1 = Depends(get_current_user),
user2 = Depends(get_current_user) # Same instance as user1
):
assert user1 is user2 # True
return user1
No caching (use_cache=False)
No caching (use_cache=False)
def get_fresh_data(use_cache: bool = False):
return fetch_from_api()
@app.get('/data')
async def get_data(data = Depends(get_fresh_data, use_cache=False)):
return data # Fetches fresh data each time
def get_fresh_data(use_cache: bool = False):
return fetch_from_api()
@app.get('/data')
async def get_data(data = Depends(get_fresh_data, use_cache=False)):
return data # Fetches fresh data each time
Singleton pattern (application scope)
Singleton pattern (application scope)
class Settings:
def init(self):
self.app_name = "MyApp"
self.admin_email = "admin@example.com"
settings = Settings() # Created once at startup
def get_settings():
return settings
@app.get('/settings')
async def read_settings(settings: Settings = Depends(get_settings)):
return settings
class Settings:
def init(self):
self.app_name = "MyApp"
self.admin_email = "admin@example.com"
settings = Settings() # Created once at startup
def get_settings():
return settings
@app.get('/settings')
async def read_settings(settings: Settings = Depends(get_settings)):
return settings
Build complex dependency chains.
python
from fastapi import Depends, HTTPException, status
构建复杂的依赖项链。
python
from fastapi import Depends, HTTPException, status
Sub-dependency chain
Sub-dependency chain
def get_db():
db = Database()
try:
yield db
finally:
db.close()
def get_current_user(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
user = db.query_one('SELECT * FROM users WHERE token = ?', token)
if not user:
raise HTTPException(status_code=401, detail='Invalid token')
return user
def get_current_active_user(
current_user = Depends(get_current_user)
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail='Inactive user')
return current_user
def get_admin_user(
current_user = Depends(get_current_active_user)
):
if not current_user.is_admin:
raise HTTPException(status_code=403, detail='Not authorized')
return current_user
def get_db():
db = Database()
try:
yield db
finally:
db.close()
def get_current_user(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
user = db.query_one('SELECT * FROM users WHERE token = ?', token)
if not user:
raise HTTPException(status_code=401, detail='Invalid token')
return user
def get_current_active_user(
current_user = Depends(get_current_user)
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail='Inactive user')
return current_user
def get_admin_user(
current_user = Depends(get_current_active_user)
):
if not current_user.is_admin:
raise HTTPException(status_code=403, detail='Not authorized')
return current_user
Use in endpoint
Use in endpoint
@app.delete('/users/{user_id}')
async def delete_user(
user_id: int,
admin = Depends(get_admin_user),
db = Depends(get_db)
):
await db.execute('DELETE FROM users WHERE id = ?', user_id)
return {'status': 'deleted'}
@app.delete('/users/{user_id}')
async def delete_user(
user_id: int,
admin = Depends(get_admin_user),
db = Depends(get_db)
):
await db.execute('DELETE FROM users WHERE id = ?', user_id)
return {'status': 'deleted'}
Class-Based Dependencies
基于类的依赖项
Use classes for stateful dependencies.
python
from fastapi import Depends
class Database:
def __init__(self, connection_string: str):
self.connection_string = connection_string
self.connection = None
async def connect(self):
self.connection = await create_connection(self.connection_string)
return self
async def disconnect(self):
if self.connection:
await self.connection.close()
async def fetch_all(self, query: str):
return await self.connection.fetch_all(query)
使用类实现有状态的依赖项。
python
from fastapi import Depends
class Database:
def __init__(self, connection_string: str):
self.connection_string = connection_string
self.connection = None
async def connect(self):
self.connection = await create_connection(self.connection_string)
return self
async def disconnect(self):
if self.connection:
await self.connection.close()
async def fetch_all(self, query: str):
return await self.connection.fetch_all(query)
Callable class (using call)
Callable class (using call)
class DatabaseDependency:
def init(self):
self.db = None
async def __call__(self):
if not self.db:
self.db = Database('postgresql://localhost/db')
await self.db.connect()
return self.db
db_dependency = DatabaseDependency()
@app.get('/users')
async def get_users(db = Depends(db_dependency)):
return await db.fetch_all('SELECT * FROM users')
class DatabaseDependency:
def init(self):
self.db = None
async def __call__(self):
if not self.db:
self.db = Database('postgresql://localhost/db')
await self.db.connect()
return self.db
db_dependency = DatabaseDependency()
@app.get('/users')
async def get_users(db = Depends(db_dependency)):
return await db.fetch_all('SELECT * FROM users')
Class with initialization parameters
Class with initialization parameters
class Pagination:
def init(self, skip: int = 0, limit: int = 100):
self.skip = skip
self.limit = limit
@app.get('/items')
async def get_items(pagination: Pagination = Depends()):
return {'skip': pagination.skip, 'limit': pagination.limit}
class Pagination:
def init(self, skip: int = 0, limit: int = 100):
self.skip = skip
self.limit = limit
@app.get('/items')
async def get_items(pagination: Pagination = Depends()):
return {'skip': pagination.skip, 'limit': pagination.limit}
Advanced class-based dependency with configuration
Advanced class-based dependency with configuration
class ServiceClient:
def init(
self,
api_key: str,
timeout: int = 30,
max_retries: int = 3
):
self.api_key = api_key
self.timeout = timeout
self.max_retries = max_retries
self.client = None
async def __call__(self):
if not self.client:
self.client = await create_http_client(
api_key=self.api_key,
timeout=self.timeout
)
return self.client
class ServiceClient:
def init(
self,
api_key: str,
timeout: int = 30,
max_retries: int = 3
):
self.api_key = api_key
self.timeout = timeout
self.max_retries = max_retries
self.client = None
async def __call__(self):
if not self.client:
self.client = await create_http_client(
api_key=self.api_key,
timeout=self.timeout
)
return self.client
Factory function for configurable class-based dependency
Factory function for configurable class-based dependency
def create_service_dependency(api_key: str):
return ServiceClient(api_key=api_key, timeout=60)
service = create_service_dependency('my-api-key')
@app.get('/external-data')
async def get_external_data(client = Depends(service)):
return await client.fetch('/data')
def create_service_dependency(api_key: str):
return ServiceClient(api_key=api_key, timeout=60)
service = create_service_dependency('my-api-key')
@app.get('/external-data')
async def get_external_data(client = Depends(service)):
return await client.fetch('/data')
Generator Dependencies for Cleanup
用于资源清理的生成器依赖项
Use generator functions to ensure proper resource cleanup.
python
from contextlib import asynccontextmanager
from fastapi import Depends
import httpx
使用生成器函数确保资源被正确清理。
python
from contextlib import asynccontextmanager
from fastapi import Depends
import httpx
Database connection with cleanup
Database connection with cleanup
async def get_db_connection():
connection = await Database.connect('postgresql://localhost/db')
try:
yield connection
finally:
await connection.close()
print('Database connection closed')
async def get_db_connection():
connection = await Database.connect('postgresql://localhost/db')
try:
yield connection
finally:
await connection.close()
print('Database connection closed')
HTTP client with cleanup
HTTP client with cleanup
async def get_http_client():
async with httpx.AsyncClient(timeout=30.0) as client:
yield client
# Client automatically closed after yield
@app.get('/external-api')
async def call_external_api(client = Depends(get_http_client)):
response = await client.get('
https://api.example.com/data')
return response.json()
async def get_http_client():
async with httpx.AsyncClient(timeout=30.0) as client:
yield client
# Client automatically closed after yield
@app.get('/external-api')
async def call_external_api(client = Depends(get_http_client)):
response = await client.get('
https://api.example.com/data')
return response.json()
File handle with cleanup
File handle with cleanup
async def get_file_writer(filename: str):
file = await aiofiles.open(filename, mode='a')
try:
yield file
finally:
await file.close()
async def get_file_writer(filename: str):
file = await aiofiles.open(filename, mode='a')
try:
yield file
finally:
await file.close()
Multiple resources with cleanup
Multiple resources with cleanup
async def get_resources():
db = await Database.connect('postgresql://localhost/db')
cache = await Redis.connect('redis://localhost')
logger = Logger('app')
try:
yield {'db': db, 'cache': cache, 'logger': logger}
finally:
await cache.close()
await db.close()
logger.shutdown()
@app.post('/process')
async def process_data(
data: dict,
resources = Depends(get_resources)
):
db = resources['db']
cache = resources['cache']
logger = resources['logger']
logger.info('Processing data')
result = await db.save(data)
await cache.invalidate('data_cache')
return result
async def get_resources():
db = await Database.connect('postgresql://localhost/db')
cache = await Redis.connect('redis://localhost')
logger = Logger('app')
try:
yield {'db': db, 'cache': cache, 'logger': logger}
finally:
await cache.close()
await db.close()
logger.shutdown()
@app.post('/process')
async def process_data(
data: dict,
resources = Depends(get_resources)
):
db = resources['db']
cache = resources['cache']
logger = resources['logger']
logger.info('Processing data')
result = await db.save(data)
await cache.invalidate('data_cache')
return result
Context manager as dependency
Context manager as dependency
@asynccontextmanager
async def transaction_context(db = Depends(get_db)):
async with db.begin() as transaction:
try:
yield transaction
await transaction.commit()
except Exception:
await transaction.rollback()
raise
async def get_transaction(db = Depends(get_db)):
async with transaction_context(db) as txn:
yield txn
@app.post('/transfer')
async def transfer_funds(
from_account: int,
to_account: int,
amount: float,
txn = Depends(get_transaction)
):
await txn.execute(
'UPDATE accounts SET balance = balance - ? WHERE id = ?',
amount, from_account
)
await txn.execute(
'UPDATE accounts SET balance = balance + ? WHERE id = ?',
amount, to_account
)
return {'status': 'success'}
@asynccontextmanager
async def transaction_context(db = Depends(get_db)):
async with db.begin() as transaction:
try:
yield transaction
await transaction.commit()
except Exception:
await transaction.rollback()
raise
async def get_transaction(db = Depends(get_db)):
async with transaction_context(db) as txn:
yield txn
@app.post('/transfer')
async def transfer_funds(
from_account: int,
to_account: int,
amount: float,
txn = Depends(get_transaction)
):
await txn.execute(
'UPDATE accounts SET balance = balance - ? WHERE id = ?',
amount, from_account
)
await txn.execute(
'UPDATE accounts SET balance = balance + ? WHERE id = ?',
amount, to_account
)
return {'status': 'success'}
Dependency Chains and Sub-Dependencies
依赖项链与子依赖项
Build complex dependency hierarchies.
python
from fastapi import Depends, HTTPException
from typing import Optional
构建复杂的依赖项层级结构。
python
from fastapi import Depends, HTTPException
from typing import Optional
Level 1: Configuration
Level 1: Configuration
def get_config():
return {
'database_url': 'postgresql://localhost/db',
'redis_url': 'redis://localhost',
'secret_key': 'super-secret'
}
def get_config():
return {
'database_url': 'postgresql://localhost/db',
'redis_url': 'redis://localhost',
'secret_key': 'super-secret'
}
Level 2: Infrastructure (depends on config)
Level 2: Infrastructure (depends on config)
def get_db(config: dict = Depends(get_config)):
db = Database(config['database_url'])
try:
yield db
finally:
db.close()
def get_cache(config: dict = Depends(get_config)):
cache = Redis(config['redis_url'])
try:
yield cache
finally:
cache.close()
def get_db(config: dict = Depends(get_config)):
db = Database(config['database_url'])
try:
yield db
finally:
db.close()
def get_cache(config: dict = Depends(get_config)):
cache = Redis(config['redis_url'])
try:
yield cache
finally:
cache.close()
Level 3: Authentication (depends on infrastructure)
Level 3: Authentication (depends on infrastructure)
def get_token_from_header(authorization: str = Header(None)):
if not authorization:
raise HTTPException(status_code=401, detail='Missing token')
scheme, token = authorization.split()
if scheme.lower() != 'bearer':
raise HTTPException(status_code=401, detail='Invalid scheme')
return token
def verify_token(
token: str = Depends(get_token_from_header),
config: dict = Depends(get_config)
):
try:
payload = jwt.decode(
token,
config['secret_key'],
algorithms=['HS256']
)
return payload
except JWTError:
raise HTTPException(status_code=401, detail='Invalid token')
def get_current_user(
payload: dict = Depends(verify_token),
db = Depends(get_db)
):
user_id = payload.get('user_id')
user = db.query_one('SELECT * FROM users WHERE id = ?', user_id)
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
def get_token_from_header(authorization: str = Header(None)):
if not authorization:
raise HTTPException(status_code=401, detail='Missing token')
scheme, token = authorization.split()
if scheme.lower() != 'bearer':
raise HTTPException(status_code=401, detail='Invalid scheme')
return token
def verify_token(
token: str = Depends(get_token_from_header),
config: dict = Depends(get_config)
):
try:
payload = jwt.decode(
token,
config['secret_key'],
algorithms=['HS256']
)
return payload
except JWTError:
raise HTTPException(status_code=401, detail='Invalid token')
def get_current_user(
payload: dict = Depends(verify_token),
db = Depends(get_db)
):
user_id = payload.get('user_id')
user = db.query_one('SELECT * FROM users WHERE id = ?', user_id)
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
Level 4: Authorization (depends on authentication)
Level 4: Authorization (depends on authentication)
def get_active_user(user = Depends(get_current_user)):
if not user.is_active:
raise HTTPException(status_code=403, detail='Inactive user')
return user
def require_permission(permission: str):
def permission_checker(user = Depends(get_active_user)):
if permission not in user.permissions:
raise HTTPException(
status_code=403,
detail=f'Permission {permission} required'
)
return user
return permission_checker
def get_active_user(user = Depends(get_current_user)):
if not user.is_active:
raise HTTPException(status_code=403, detail='Inactive user')
return user
def require_permission(permission: str):
def permission_checker(user = Depends(get_active_user)):
if permission not in user.permissions:
raise HTTPException(
status_code=403,
detail=f'Permission {permission} required'
)
return user
return permission_checker
Level 5: Business logic (depends on authorization)
Level 5: Business logic (depends on authorization)
def get_user_service(
db = Depends(get_db),
cache = Depends(get_cache),
current_user = Depends(get_active_user)
):
return UserService(db=db, cache=cache, user=current_user)
def get_user_service(
db = Depends(get_db),
cache = Depends(get_cache),
current_user = Depends(get_active_user)
):
return UserService(db=db, cache=cache, user=current_user)
Use in endpoint
Use in endpoint
@app.post('/users/{user_id}/deactivate')
async def deactivate_user(
user_id: int,
admin = Depends(require_permission('users.deactivate')),
service = Depends(get_user_service)
):
result = await service.deactivate_user(user_id)
return result
@app.post('/users/{user_id}/deactivate')
async def deactivate_user(
user_id: int,
admin = Depends(require_permission('users.deactivate')),
service = Depends(get_user_service)
):
result = await service.deactivate_user(user_id)
return result
OAuth2 Dependencies
OAuth2依赖项
Implement OAuth2 authentication with dependencies.
python
from fastapi import Depends, HTTPException, status
from fastapi.security import (
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
HTTPBearer,
HTTPAuthorizationCredentials
)
from datetime import datetime, timedelta
from jose import JWTError, jwt
使用依赖项实现OAuth2认证。
python
from fastapi import Depends, HTTPException, status
from fastapi.security import (
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
HTTPBearer,
HTTPAuthorizationCredentials
)
from datetime import datetime, timedelta
from jose import JWTError, jwt
OAuth2 with password flow
OAuth2 with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
SECRET_KEY = 'your-secret-key-here'
ALGORITHM = 'HS256'
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({'exp': expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user_from_token(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get('sub')
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await db.get_user_by_id(user_id)
if user is None:
raise credentials_exception
return user
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
SECRET_KEY = 'your-secret-key-here'
ALGORITHM = 'HS256'
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({'exp': expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user_from_token(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get('sub')
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await db.get_user_by_id(user_id)
if user is None:
raise credentials_exception
return user
OAuth2 with bearer token
OAuth2 with bearer token
bearer_scheme = HTTPBearer()
async def verify_bearer_token(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
):
token = credentials.credentials
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Invalid authentication credentials'
)
bearer_scheme = HTTPBearer()
async def verify_bearer_token(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
):
token = credentials.credentials
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Invalid authentication credentials'
)
OAuth2 with scopes
OAuth2 with scopes
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
oauth2_scheme_scopes = OAuth2PasswordBearer(
tokenUrl='token',
scopes={
'users:read': 'Read user information',
'users:write': 'Modify user information',
'admin': 'Admin access'
}
)
async def get_current_user_with_scopes(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme_scopes),
db = Depends(get_db)
):
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
authenticate_value = 'Bearer'
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': authenticate_value},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get('sub')
if user_id is None:
raise credentials_exception
token_scopes = payload.get('scopes', [])
except JWTError:
raise credentials_exception
user = await db.get_user_by_id(user_id)
if user is None:
raise credentials_exception
for scope in security_scopes.scopes:
if scope not in token_scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Not enough permissions',
headers={'WWW-Authenticate': authenticate_value},
)
return user
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
oauth2_scheme_scopes = OAuth2PasswordBearer(
tokenUrl='token',
scopes={
'users:read': 'Read user information',
'users:write': 'Modify user information',
'admin': 'Admin access'
}
)
async def get_current_user_with_scopes(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme_scopes),
db = Depends(get_db)
):
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
authenticate_value = 'Bearer'
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': authenticate_value},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get('sub')
if user_id is None:
raise credentials_exception
token_scopes = payload.get('scopes', [])
except JWTError:
raise credentials_exception
user = await db.get_user_by_id(user_id)
if user is None:
raise credentials_exception
for scope in security_scopes.scopes:
if scope not in token_scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Not enough permissions',
headers={'WWW-Authenticate': authenticate_value},
)
return user
Use with scopes
Use with scopes
@app.get('/users/me', dependencies=[Security(
get_current_user_with_scopes,
scopes=['users:read']
)])
async def read_users_me(
current_user = Depends(get_current_user_with_scopes)
):
return current_user
@app.put('/users/me', dependencies=[Security(
get_current_user_with_scopes,
scopes=['users:write']
)])
async def update_user_me(
user_update: UserUpdate,
current_user = Depends(get_current_user_with_scopes),
db = Depends(get_db)
):
updated_user = await db.update_user(current_user.id, user_update)
return updated_user
@app.get('/users/me', dependencies=[Security(
get_current_user_with_scopes,
scopes=['users:read']
)])
async def read_users_me(
current_user = Depends(get_current_user_with_scopes)
):
return current_user
@app.put('/users/me', dependencies=[Security(
get_current_user_with_scopes,
scopes=['users:write']
)])
async def update_user_me(
user_update: UserUpdate,
current_user = Depends(get_current_user_with_scopes),
db = Depends(get_db)
):
updated_user = await db.update_user(current_user.id, user_update)
return updated_user
Database Dependencies
数据库依赖项
Manage database connections with dependencies.
python
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
使用依赖项管理数据库连接。
python
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
SQLAlchemy setup
SQLAlchemy setup
DATABASE_URL = 'postgresql+asyncpg://user:pass@localhost/db'
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
DATABASE_URL = 'postgresql+asyncpg://user:pass@localhost/db'
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
from sqlalchemy import select
@app.get('/users/{user_id}')
async def get_user(user_id: int, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
from sqlalchemy import select
@app.get('/users/{user_id}')
async def get_user(user_id: int, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail='User not found')
return user
With transaction
With transaction
@app.post('/users')
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
user = User(**user_data.dict())
db.add(user)
await db.flush() # Get the ID before commit
return user
@app.post('/users')
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
user = User(**user_data.dict())
db.add(user)
await db.flush() # Get the ID before commit
return user
Authentication Dependencies
认证依赖项
Implement authentication patterns.
python
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
SECRET_KEY = 'your-secret-key'
ALGORITHM = 'HS256'
实现认证模式。
python
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
SECRET_KEY = 'your-secret-key'
ALGORITHM = 'HS256'
Token verification
Token verification
async def get_current_user(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
email: str = payload.get('sub')
if email is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await db.fetch_one('SELECT * FROM users WHERE email = ?', email)
if user is None:
raise credentials_exception
return user
async def get_current_user(
token: str = Depends(oauth2_scheme),
db = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
email: str = payload.get('sub')
if email is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await db.fetch_one('SELECT * FROM users WHERE email = ?', email)
if user is None:
raise credentials_exception
return user
Active user check
Active user check
async def get_current_active_user(
current_user = Depends(get_current_user)
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail='Inactive user')
return current_user
async def get_current_active_user(
current_user = Depends(get_current_user)
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail='Inactive user')
return current_user
Role-based access
Role-based access
def require_role(required_role: str):
async def role_checker(current_user = Depends(get_current_active_user)):
if current_user.role != required_role:
raise HTTPException(
status_code=403,
detail=f'Role {required_role} required'
)
return current_user
return role_checker
@app.get('/admin')
async def admin_route(user = Depends(require_role('admin'))):
return {'message': 'Admin access granted'}
def require_role(required_role: str):
async def role_checker(current_user = Depends(get_current_active_user)):
if current_user.role != required_role:
raise HTTPException(
status_code=403,
detail=f'Role {required_role} required'
)
return current_user
return role_checker
@app.get('/admin')
async def admin_route(user = Depends(require_role('admin'))):
return {'message': 'Admin access granted'}
Caching Dependencies
缓存依赖项
Implement caching patterns with dependencies.
python
from fastapi import Depends
from functools import lru_cache
import redis.asyncio as redis
使用依赖项实现缓存模式。
python
from fastapi import Depends
from functools import lru_cache
import redis.asyncio as redis
In-memory cache
In-memory cache
@lru_cache()
def get_settings():
return Settings()
@app.get('/config')
async def get_config(settings: Settings = Depends(get_settings)):
return settings
@lru_cache()
def get_settings():
return Settings()
@app.get('/config')
async def get_config(settings: Settings = Depends(get_settings)):
return settings
Redis cache dependency
Redis cache dependency
class RedisCache:
def init(self):
self.redis = None
async def get_connection(self):
if not self.redis:
self.redis = await redis.from_url('redis://localhost')
return self.redis
async def get(self, key: str):
conn = await self.get_connection()
value = await conn.get(key)
return value.decode() if value else None
async def set(self, key: str, value: str, expire: int = 3600):
conn = await self.get_connection()
await conn.set(key, value, ex=expire)
cache = RedisCache()
async def get_cache():
return cache
@app.get('/cached-data/{key}')
async def get_cached_data(
key: str,
cache: RedisCache = Depends(get_cache)
):
value = await cache.get(key)
if value:
return {'value': value, 'cached': True}
# Fetch and cache
value = fetch_expensive_data(key)
await cache.set(key, value)
return {'value': value, 'cached': False}
class RedisCache:
def init(self):
self.redis = None
async def get_connection(self):
if not self.redis:
self.redis = await redis.from_url('redis://localhost')
return self.redis
async def get(self, key: str):
conn = await self.get_connection()
value = await conn.get(key)
return value.decode() if value else None
async def set(self, key: str, value: str, expire: int = 3600):
conn = await self.get_connection()
await conn.set(key, value, ex=expire)
cache = RedisCache()
async def get_cache():
return cache
@app.get('/cached-data/{key}')
async def get_cached_data(
key: str,
cache: RedisCache = Depends(get_cache)
):
value = await cache.get(key)
if value:
return {'value': value, 'cached': True}
# Fetch and cache
value = fetch_expensive_data(key)
await cache.set(key, value)
return {'value': value, 'cached': False}
Background Task Dependencies
后台任务依赖项
Use dependencies with background tasks.
python
from fastapi import BackgroundTasks, Depends
import asyncio
async def send_email(email: str, message: str):
# Simulate sending email
await asyncio.sleep(2)
print(f'Email sent to {email}: {message}')
def get_email_service():
# Could return configured email service
return send_email
@app.post('/users')
async def create_user(
user: UserCreate,
background_tasks: BackgroundTasks,
email_service = Depends(get_email_service),
db = Depends(get_db)
):
user_obj = await db.create_user(user)
background_tasks.add_task(
email_service,
user.email,
'Welcome to our service!'
)
return user_obj
将依赖项与后台任务结合使用。
python
from fastapi import BackgroundTasks, Depends
import asyncio
async def send_email(email: str, message: str):
# Simulate sending email
await asyncio.sleep(2)
print(f'Email sent to {email}: {message}')
def get_email_service():
# Could return configured email service
return send_email
@app.post('/users')
async def create_user(
user: UserCreate,
background_tasks: BackgroundTasks,
email_service = Depends(get_email_service),
db = Depends(get_db)
):
user_obj = await db.create_user(user)
background_tasks.add_task(
email_service,
user.email,
'Welcome to our service!'
)
return user_obj
Complex background task with dependencies
Complex background task with dependencies
class EmailService:
def init(self, db = Depends(get_db)):
self.db = db
async def send_welcome_email(self, user_id: int):
user = await self.db.get_user(user_id)
await send_email(user.email, f'Welcome {user.name}!')
@app.post('/users/v2')
async def create_user_v2(
user: UserCreate,
background_tasks: BackgroundTasks,
db = Depends(get_db)
):
user_obj = await db.create_user(user)
email_service = EmailService(db)
background_tasks.add_task(
email_service.send_welcome_email,
user_obj.id
)
return user_obj
class EmailService:
def init(self, db = Depends(get_db)):
self.db = db
async def send_welcome_email(self, user_id: int):
user = await self.db.get_user(user_id)
await send_email(user.email, f'Welcome {user.name}!')
@app.post('/users/v2')
async def create_user_v2(
user: UserCreate,
background_tasks: BackgroundTasks,
db = Depends(get_db)
):
user_obj = await db.create_user(user)
email_service = EmailService(db)
background_tasks.add_task(
email_service.send_welcome_email,
user_obj.id
)
return user_obj
Background task with cleanup
Background task with cleanup
async def process_file_with_cleanup(file_path: str):
try:
# Process file
await process_file(file_path)
finally:
# Cleanup
os.remove(file_path)
@app.post('/upload')
async def upload_file(
file: UploadFile,
background_tasks: BackgroundTasks
):
file_path = f'/tmp/{file.filename}'
with open(file_path, 'wb') as f:
f.write(await file.read())
background_tasks.add_task(process_file_with_cleanup, file_path)
return {'status': 'processing'}
async def process_file_with_cleanup(file_path: str):
try:
# Process file
await process_file(file_path)
finally:
# Cleanup
os.remove(file_path)
@app.post('/upload')
async def upload_file(
file: UploadFile,
background_tasks: BackgroundTasks
):
file_path = f'/tmp/{file.filename}'
with open(file_path, 'wb') as f:
f.write(await file.read())
background_tasks.add_task(process_file_with_cleanup, file_path)
return {'status': 'processing'}
WebSocket Dependencies
WebSocket依赖项
Use dependencies with WebSocket connections.
python
from fastapi import WebSocket, Depends, WebSocketDisconnect
from typing import List
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
def get_connection_manager():
return manager
将依赖项与WebSocket连接结合使用。
python
from fastapi import WebSocket, Depends, WebSocketDisconnect
from typing import List
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
def get_connection_manager():
return manager
WebSocket with authentication
WebSocket with authentication
async def get_ws_current_user(
websocket: WebSocket,
token: str = Query(...),
db = Depends(get_db)
):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get('sub')
user = await db.get_user_by_id(user_id)
if not user:
await websocket.close(code=1008)
raise HTTPException(status_code=401, detail='Invalid token')
return user
except JWTError:
await websocket.close(code=1008)
raise HTTPException(status_code=401, detail='Invalid token')
@app.websocket('/ws/{client_id}')
async def websocket_endpoint(
websocket: WebSocket,
client_id: int,
manager: ConnectionManager = Depends(get_connection_manager),
user = Depends(get_ws_current_user)
):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
message = f'User {user.name}: {data}'
await manager.broadcast(message)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f'User {user.name} left the chat')
async def get_ws_current_user(
websocket: WebSocket,
token: str = Query(...),
db = Depends(get_db)
):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get('sub')
user = await db.get_user_by_id(user_id)
if not user:
await websocket.close(code=1008)
raise HTTPException(status_code=401, detail='Invalid token')
return user
except JWTError:
await websocket.close(code=1008)
raise HTTPException(status_code=401, detail='Invalid token')
@app.websocket('/ws/{client_id}')
async def websocket_endpoint(
websocket: WebSocket,
client_id: int,
manager: ConnectionManager = Depends(get_connection_manager),
user = Depends(get_ws_current_user)
):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
message = f'User {user.name}: {data}'
await manager.broadcast(message)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f'User {user.name} left the chat')
WebSocket with database
WebSocket with database
@app.websocket('/ws/messages/{room_id}')
async def message_websocket(
websocket: WebSocket,
room_id: int,
db = Depends(get_db),
user = Depends(get_ws_current_user)
):
await websocket.accept()
# Send message history
messages = await db.get_room_messages(room_id)
await websocket.send_json(messages)
try:
while True:
data = await websocket.receive_text()
# Save to database
message = await db.create_message(
room_id=room_id,
user_id=user.id,
content=data
)
await websocket.send_json(message)
except WebSocketDisconnect:
pass
@app.websocket('/ws/messages/{room_id}')
async def message_websocket(
websocket: WebSocket,
room_id: int,
db = Depends(get_db),
user = Depends(get_ws_current_user)
):
await websocket.accept()
# Send message history
messages = await db.get_room_messages(room_id)
await websocket.send_json(messages)
try:
while True:
data = await websocket.receive_text()
# Save to database
message = await db.create_message(
room_id=room_id,
user_id=user.id,
content=data
)
await websocket.send_json(message)
except WebSocketDisconnect:
pass
Custom Dependency Providers
自定义依赖项提供器
Create custom dependency injection patterns.
python
from typing import Callable, Type, TypeVar, Generic
from fastapi import Depends
T = TypeVar('T')
创建自定义依赖注入模式。
python
from typing import Callable, Type, TypeVar, Generic
from fastapi import Depends
T = TypeVar('T')
Dependency factory
Dependency factory
class DependencyFactory(Generic[T]):
def init(self, dependency_class: Type[T], **kwargs):
self.dependency_class = dependency_class
self.kwargs = kwargs
def __call__(self) -> T:
return self.dependency_class(**self.kwargs)
class DependencyFactory(Generic[T]):
def init(self, dependency_class: Type[T], **kwargs):
self.dependency_class = dependency_class
self.kwargs = kwargs
def __call__(self) -> T:
return self.dependency_class(**self.kwargs)
Service locator pattern
Service locator pattern
class ServiceLocator:
def init(self):
self._services = {}
def register(self, name: str, service):
self._services[name] = service
def get(self, name: str):
return self._services.get(name)
def create_dependency(self, name: str):
def get_service():
service = self.get(name)
if service is None:
raise ValueError(f'Service {name} not registered')
return service
return get_service
class ServiceLocator:
def init(self):
self._services = {}
def register(self, name: str, service):
self._services[name] = service
def get(self, name: str):
return self._services.get(name)
def create_dependency(self, name: str):
def get_service():
service = self.get(name)
if service is None:
raise ValueError(f'Service {name} not registered')
return service
return get_service
Initialize service locator
Initialize service locator
locator = ServiceLocator()
locator.register('db', Database('postgresql://localhost/db'))
locator.register('cache', Redis('redis://localhost'))
locator = ServiceLocator()
locator.register('db', Database('postgresql://localhost/db'))
locator.register('cache', Redis('redis://localhost'))
Use in endpoint
Use in endpoint
get_db_from_locator = locator.create_dependency('db')
@app.get('/items')
async def get_items(db = Depends(get_db_from_locator)):
return await db.fetch_all('SELECT * FROM items')
get_db_from_locator = locator.create_dependency('db')
@app.get('/items')
async def get_items(db = Depends(get_db_from_locator)):
return await db.fetch_all('SELECT * FROM items')
Dependency provider with context
Dependency provider with context
class ContextualDependency:
def init(self, request: Request):
self.request = request
self.context = {}
def set(self, key: str, value):
self.context[key] = value
def get(self, key: str):
return self.context.get(key)
async def get_request_context(request: Request):
context = ContextualDependency(request)
context.set('request_id', str(uuid.uuid4()))
context.set('timestamp', datetime.utcnow())
return context
@app.get('/context-example')
async def context_example(context = Depends(get_request_context)):
return {
'request_id': context.get('request_id'),
'timestamp': context.get('timestamp')
}
class ContextualDependency:
def init(self, request: Request):
self.request = request
self.context = {}
def set(self, key: str, value):
self.context[key] = value
def get(self, key: str):
return self.context.get(key)
async def get_request_context(request: Request):
context = ContextualDependency(request)
context.set('request_id', str(uuid.uuid4()))
context.set('timestamp', datetime.utcnow())
return context
@app.get('/context-example')
async def context_example(context = Depends(get_request_context)):
return {
'request_id': context.get('request_id'),
'timestamp': context.get('timestamp')
}
Lazy dependency loading
Lazy dependency loading
class LazyDependency:
def init(self, factory: Callable):
self.factory = factory
self._instance = None
def __call__(self):
if self._instance is None:
self._instance = self.factory()
return self._instance
def create_expensive_service():
print('Creating expensive service...')
return ExpensiveService()
expensive_service = LazyDependency(create_expensive_service)
@app.get('/expensive')
async def use_expensive(service = Depends(expensive_service)):
return service.do_work()
class LazyDependency:
def init(self, factory: Callable):
self.factory = factory
self._instance = None
def __call__(self):
if self._instance is None:
self._instance = self.factory()
return self._instance
def create_expensive_service():
print('Creating expensive service...')
return ExpensiveService()
expensive_service = LazyDependency(create_expensive_service)
@app.get('/expensive')
async def use_expensive(service = Depends(expensive_service)):
return service.do_work()
Scoped Dependencies Per Request
每个请求的作用域依赖项
Manage request-scoped state and lifecycle.
python
from contextvars import ContextVar
from fastapi import Depends, Request
管理请求作用域的状态与生命周期。
python
from contextvars import ContextVar
from fastapi import Depends, Request
Request-scoped storage using context variables
Request-scoped storage using context variables
request_id_var: ContextVar[str] = ContextVar('request_id')
async def set_request_id(request: Request):
request_id = request.headers.get('X-Request-ID', str(uuid.uuid4()))
request_id_var.set(request_id)
return request_id
async def get_request_id():
return request_id_var.get()
request_id_var: ContextVar[str] = ContextVar('request_id')
async def set_request_id(request: Request):
request_id = request.headers.get('X-Request-ID', str(uuid.uuid4()))
request_id_var.set(request_id)
return request_id
async def get_request_id():
return request_id_var.get()
Request-scoped logger
Request-scoped logger
class RequestLogger:
def init(self, request_id: str = Depends(get_request_id)):
self.request_id = request_id
def info(self, message: str):
print(f'[{self.request_id}] INFO: {message}')
def error(self, message: str):
print(f'[{self.request_id}] ERROR: {message}')
@app.get('/scoped-logging')
async def scoped_logging(
request_id: str = Depends(set_request_id),
logger: RequestLogger = Depends()
):
logger.info('Processing request')
# Do work
logger.info('Request completed')
return {'request_id': request_id}
class RequestLogger:
def init(self, request_id: str = Depends(get_request_id)):
self.request_id = request_id
def info(self, message: str):
print(f'[{self.request_id}] INFO: {message}')
def error(self, message: str):
print(f'[{self.request_id}] ERROR: {message}')
@app.get('/scoped-logging')
async def scoped_logging(
request_id: str = Depends(set_request_id),
logger: RequestLogger = Depends()
):
logger.info('Processing request')
# Do work
logger.info('Request completed')
return {'request_id': request_id}
Request-scoped unit of work pattern
Request-scoped unit of work pattern
class UnitOfWork:
def init(self, db = Depends(get_db)):
self.db = db
self.repositories = {}
def get_repository(self, model_class):
if model_class not in self.repositories:
self.repositories[model_class] = Repository(self.db, model_class)
return self.repositories[model_class]
async def commit(self):
await self.db.commit()
async def rollback(self):
await self.db.rollback()
@app.post('/complex-transaction')
async def complex_transaction(
data: TransactionData,
uow: UnitOfWork = Depends()
):
try:
user_repo = uow.get_repository(User)
order_repo = uow.get_repository(Order)
user = await user_repo.create(data.user)
order = await order_repo.create(data.order, user_id=user.id)
await uow.commit()
return {'user': user, 'order': order}
except Exception:
await uow.rollback()
raise
class UnitOfWork:
def init(self, db = Depends(get_db)):
self.db = db
self.repositories = {}
def get_repository(self, model_class):
if model_class not in self.repositories:
self.repositories[model_class] = Repository(self.db, model_class)
return self.repositories[model_class]
async def commit(self):
await self.db.commit()
async def rollback(self):
await self.db.rollback()
@app.post('/complex-transaction')
async def complex_transaction(
data: TransactionData,
uow: UnitOfWork = Depends()
):
try:
user_repo = uow.get_repository(User)
order_repo = uow.get_repository(Order)
user = await user_repo.create(data.user)
order = await order_repo.create(data.order, user_id=user.id)
await uow.commit()
return {'user': user, 'order': order}
except Exception:
await uow.rollback()
raise
Global Dependencies with app.dependency_overrides
使用app.dependency_overrides的全局依赖项
Use global dependency management and overrides.
python
from fastapi import FastAPI, Depends
app = FastAPI()
使用全局依赖项管理与覆盖机制。
python
from fastapi import FastAPI, Depends
app = FastAPI()
Original dependencies
Original dependencies
def get_production_db():
return ProductionDatabase('postgresql://prod/db')
def get_production_cache():
return RedisCache('redis://prod')
def get_production_db():
return ProductionDatabase('postgresql://prod/db')
def get_production_cache():
return RedisCache('redis://prod')
Default app setup
Default app setup
@app.get('/data')
async def get_data(
db = Depends(get_production_db),
cache = Depends(get_production_cache)
):
cached = await cache.get('data')
if cached:
return cached
data = await db.fetch_all('SELECT * FROM data')
await cache.set('data', data)
return data
@app.get('/data')
async def get_data(
db = Depends(get_production_db),
cache = Depends(get_production_cache)
):
cached = await cache.get('data')
if cached:
return cached
data = await db.fetch_all('SELECT * FROM data')
await cache.set('data', data)
return data
Override for testing environment
Override for testing environment
if os.getenv('ENVIRONMENT') == 'test':
def get_test_db():
return TestDatabase(':memory:')
def get_test_cache():
return InMemoryCache()
app.dependency_overrides[get_production_db] = get_test_db
app.dependency_overrides[get_production_cache] = get_test_cache
if os.getenv('ENVIRONMENT') == 'test':
def get_test_db():
return TestDatabase(':memory:')
def get_test_cache():
return InMemoryCache()
app.dependency_overrides[get_production_db] = get_test_db
app.dependency_overrides[get_production_cache] = get_test_cache
Override for development
Override for development
if os.getenv('ENVIRONMENT') == 'development':
def get_dev_db():
return DevDatabase('postgresql://localhost/dev')
app.dependency_overrides[get_production_db] = get_dev_db
if os.getenv('ENVIRONMENT') == 'development':
def get_dev_db():
return DevDatabase('postgresql://localhost/dev')
app.dependency_overrides[get_production_db] = get_dev_db
Dynamic override based on request
Dynamic override based on request
async def override_db_by_tenant(request: Request):
tenant_id = request.headers.get('X-Tenant-ID')
if tenant_id:
return TenantDatabase(tenant_id)
return get_production_db()
async def override_db_by_tenant(request: Request):
tenant_id = request.headers.get('X-Tenant-ID')
if tenant_id:
return TenantDatabase(tenant_id)
return get_production_db()
Conditional override
Conditional override
def setup_dependencies(app: FastAPI, config: dict):
if config.get('use_mock_db'):
app.dependency_overrides[get_production_db] = lambda: MockDB()
if config.get('use_local_cache'):
app.dependency_overrides[get_production_cache] = lambda: LocalCache()
def setup_dependencies(app: FastAPI, config: dict):
if config.get('use_mock_db'):
app.dependency_overrides[get_production_db] = lambda: MockDB()
if config.get('use_local_cache'):
app.dependency_overrides[get_production_cache] = lambda: LocalCache()
Clear overrides
Clear overrides
def clear_overrides():
app.dependency_overrides = {}
def clear_overrides():
app.dependency_overrides = {}
Testing with Dependencies
依赖项测试
Override dependencies for testing.
python
from fastapi.testclient import TestClient
覆盖依赖项以进行测试。
python
from fastapi.testclient import TestClient
Original dependency
Original dependency
def get_db():
db = ProductionDB()
try:
yield db
finally:
db.close()
app.dependency_overrides = {}
def get_db():
db = ProductionDB()
try:
yield db
finally:
db.close()
app.dependency_overrides = {}
Test with mock database
Test with mock database
def test_get_users():
def override_get_db():
return MockDB()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
response = client.get('/users')
assert response.status_code == 200
# Cleanup
app.dependency_overrides = {}
def test_get_users():
def override_get_db():
return MockDB()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
response = client.get('/users')
assert response.status_code == 200
# Cleanup
app.dependency_overrides = {}
Pytest fixture for dependency override
Pytest fixture for dependency override
import pytest
@pytest.fixture
def client():
def override_get_db():
return MockDB()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
yield client
app.dependency_overrides = {}
def test_with_fixture(client):
response = client.get('/users')
assert response.status_code == 200
import pytest
@pytest.fixture
def client():
def override_get_db():
return MockDB()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
yield client
app.dependency_overrides = {}
def test_with_fixture(client):
response = client.get('/users')
assert response.status_code == 200
Apply dependencies to all routes.
python
from fastapi import FastAPI, Depends
将依赖项应用于所有路由。
python
from fastapi import FastAPI, Depends
Logging dependency
Logging dependency
async def log_request(request: Request):
print(f'{request.method} {request.url}')
async def log_request(request: Request):
print(f'{request.method} {request.url}')
Rate limiting
Rate limiting
async def rate_limit(request: Request):
client_ip = request.client.host
# Check rate limit
if is_rate_limited(client_ip):
raise HTTPException(status_code=429, detail='Too many requests')
async def rate_limit(request: Request):
client_ip = request.client.host
# Check rate limit
if is_rate_limited(client_ip):
raise HTTPException(status_code=429, detail='Too many requests')
Apply globally
Apply globally
app = FastAPI(dependencies=[Depends(log_request), Depends(rate_limit)])
app = FastAPI(dependencies=[Depends(log_request), Depends(rate_limit)])
Apply to router
Apply to router
router = APIRouter(dependencies=[Depends(get_current_user)])
@router.get('/protected-resource')
async def protected_route():
return {'message': 'This requires authentication'}
app.include_router(router)
router = APIRouter(dependencies=[Depends(get_current_user)])
@router.get('/protected-resource')
async def protected_route():
return {'message': 'This requires authentication'}
app.include_router(router)
When to Use This Skill
何时使用该技能
Use fastapi-dependency-injection when building modern,
production-ready applications that require
advanced patterns, best practices, and optimal performance.
在构建需要高级模式、最佳实践与最优性能的现代生产级应用时,使用fastapi-dependency-injection。
FastAPI DI Best Practices
FastAPI依赖注入最佳实践
-
Use yield for cleanup - Always use yield for resources that
need cleanup like database connections, file handles, and network
connections to ensure proper resource management
-
Leverage caching - Use dependency caching (enabled by default)
to avoid redundant work within a request; multiple uses of the same
dependency in one request share the same instance
-
Chain dependencies - Build complex dependencies from simpler
sub-dependencies to create composable, testable, and maintainable code
structures
-
Class-based for state - Use classes for dependencies that
maintain state or configuration, leveraging
method for
callable instances
-
Type hints everywhere - Always add type hints for better editor
support, automatic validation, and improved documentation generation
-
Override for testing - Use dependency_overrides to inject mocks
during testing without modifying production code
-
Global for cross-cutting - Apply common dependencies (logging,
auth, rate limiting) globally or at router level to avoid repetition
-
Async when possible - Use async dependencies for I/O operations
to maximize performance and concurrency benefits
-
Separate concerns - Keep authentication, authorization, and
business logic in separate dependencies for better testability and
reusability
-
Document dependencies - Add docstrings to explain complex
dependency chains, especially when building multi-level hierarchies
-
Use Security utilities - Leverage FastAPI's security utilities
like OAuth2PasswordBearer and HTTPBearer for authentication patterns
-
Validate early - Place validation dependencies early in the
chain to fail fast and provide clear error messages
-
Keep dependencies pure - Dependencies should have minimal side
effects; use background tasks for non-critical operations
-
Use context managers - Wrap dependencies in context managers
when dealing with transactions or resource pools
-
Dependency composition - Compose larger dependencies from
smaller, focused ones rather than creating monolithic dependencies
-
使用yield进行清理 - 对于需要清理的资源(如数据库连接、文件句柄、网络连接),始终使用yield以确保资源被正确管理
-
利用缓存 - 使用依赖项缓存(默认启用)避免请求内的重复工作;同一请求中多次使用同一依赖项会共享同一实例
-
链式依赖项 - 从简单的子依赖项构建复杂依赖项,创建可组合、可测试、可维护的代码结构
-
基于类实现状态 - 对于需要维护状态或配置的依赖项,使用类并利用
方法实现可调用实例
-
处处使用类型提示 - 始终添加类型提示以获得更好的编辑器支持、自动验证与更完善的文档生成
-
测试时覆盖依赖项 - 使用dependency_overrides在测试时注入模拟对象,无需修改生产代码
-
全局应用横切关注点 - 将通用依赖项(日志、认证、限流)全局或路由级应用,避免重复代码
-
尽可能使用异步 - 对I/O操作使用异步依赖项以最大化性能与并发优势
-
分离关注点 - 将认证、授权与业务逻辑放在单独的依赖项中,提升可测试性与可复用性
-
文档化依赖项 - 为复杂的依赖项链添加文档字符串,尤其是在构建多层级结构时
-
使用安全工具 - 利用FastAPI的安全工具(如OAuth2PasswordBearer、HTTPBearer)实现认证模式
-
尽早验证 - 将验证依赖项放在链的早期,快速失败并提供清晰的错误信息
-
保持依赖项纯净 - 依赖项应尽量减少副作用;对于非关键操作使用后台任务
-
使用上下文管理器 - 处理事务或资源池时,将依赖项包装在上下文管理器中
-
依赖项组合 - 从更小、聚焦的依赖项组合出更大的依赖项,而非创建单体式依赖项
FastAPI DI Common Pitfalls
FastAPI依赖注入常见陷阱
-
Forgetting yield - Not using yield means resources won't be
cleaned up properly, leading to connection leaks and resource exhaustion
-
Circular dependencies - Creating dependency cycles causes
infinite loops and stack overflow errors; design dependencies in a
directed acyclic graph
-
Not using Depends() - Forgetting Depends() wrapper means
function is called directly instead of being injected, breaking the
dependency resolution
-
Overusing use_cache=False - Disabling cache unnecessarily hurts
performance by creating multiple instances of the same dependency per
request
-
Heavy dependencies - Putting too much logic in dependencies
instead of services makes them hard to test and violates single
responsibility
-
Not testing overrides - Forgetting to test with
dependency_overrides means tests may use production resources instead
of mocks
-
Mixing sync and async - Incorrectly mixing synchronous and
asynchronous dependencies can block the event loop or cause runtime
errors
-
Global state issues - Not properly managing singleton
dependencies leads to shared state bugs in concurrent requests
-
Exception handling - Not handling exceptions in dependencies
properly can leave resources in inconsistent states or leak connections
-
Type hint mistakes - Missing or incorrect type hints break
dependency injection and automatic validation
-
Ignoring dependency order - Dependencies are executed in the
order they appear in function signature; incorrect order can cause
issues
-
Not cleaning test overrides - Forgetting to reset
app.dependency_overrides after tests causes subsequent tests to fail
-
Overusing global dependencies - Applying too many dependencies
globally can hurt performance and make debugging difficult
-
Memory leaks with generators - Not properly closing resources
in finally block of generator dependencies causes memory leaks
-
Security misconfiguration - Using weak or missing security
dependencies exposes endpoints to unauthorized access
-
忘记使用yield - 不使用yield会导致资源无法被正确清理,引发连接泄漏与资源耗尽
-
循环依赖项 - 创建依赖项循环会导致无限循环与栈溢出错误;将依赖项设计为有向无环图
-
未使用Depends() - 忘记Depends()包装器会导致函数被直接调用而非注入,破坏依赖项解析
-
过度使用use_cache=False - 不必要地禁用缓存会因同一请求中创建多个同一依赖项实例而损害性能
-
过重的依赖项 - 在依赖项中放入过多逻辑而非服务中,会使其难以测试并违反单一职责原则
-
未测试覆盖项 - 忘记使用dependency_overrides测试会导致测试使用生产资源而非模拟对象
-
混合同步与异步 - 错误地混合同步与异步依赖项会阻塞事件循环或导致运行时错误
-
全局状态问题 - 未正确管理单例依赖项会导致并发请求中的共享状态bug
-
异常处理不当 - 未正确处理依赖项中的异常会导致资源处于不一致状态或连接泄漏
-
类型提示错误 - 缺失或错误的类型提示会破坏依赖注入与自动验证
-
忽略依赖项顺序 - 依赖项按函数签名中的顺序执行;错误的顺序会引发问题
-
未清理测试覆盖项 - 测试后忘记重置app.dependency_overrides会导致后续测试失败
-
过度使用全局依赖项 - 应用过多全局依赖项会损害性能并增加调试难度
-
生成器导致内存泄漏 - 未在生成器依赖项的finally块中正确关闭资源会导致内存泄漏
-
安全配置错误 - 使用弱或缺失的安全依赖项会使端点暴露于未授权访问风险
Advanced Caching Patterns
高级缓存模式
Implement sophisticated caching strategies with dependencies.
python
from fastapi import Depends
import hashlib
import json
使用依赖项实现复杂的缓存策略。
python
from fastapi import Depends
import hashlib
import json
Multi-layer cache with fallback
Multi-layer cache with fallback
class CacheLayer:
def init(
self,
memory_cache = Depends(get_memory_cache),
redis_cache = Depends(get_redis_cache)
):
self.memory = memory_cache
self.redis = redis_cache
async def get(self, key: str):
# Try memory first
value = self.memory.get(key)
if value:
return value
# Try Redis
value = await self.redis.get(key)
if value:
# Populate memory cache
self.memory.set(key, value)
return value
return None
async def set(self, key: str, value, ttl: int = 3600):
self.memory.set(key, value, ttl=min(ttl, 300))
await self.redis.set(key, value, ttl=ttl)
class CacheLayer:
def init(
self,
memory_cache = Depends(get_memory_cache),
redis_cache = Depends(get_redis_cache)
):
self.memory = memory_cache
self.redis = redis_cache
async def get(self, key: str):
# Try memory first
value = self.memory.get(key)
if value:
return value
# Try Redis
value = await self.redis.get(key)
if value:
# Populate memory cache
self.memory.set(key, value)
return value
return None
async def set(self, key: str, value, ttl: int = 3600):
self.memory.set(key, value, ttl=min(ttl, 300))
await self.redis.set(key, value, ttl=ttl)
Cache key generation
Cache key generation
def create_cache_key(*args, **kwargs):
key_data = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
return hashlib.md5(key_data.encode()).hexdigest()
def create_cache_key(*args, **kwargs):
key_data = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True)
return hashlib.md5(key_data.encode()).hexdigest()
Dependency with automatic caching
Dependency with automatic caching
def cached_dependency(ttl: int = 3600):
async def dependency(
params: dict,
cache: CacheLayer = Depends()
):
cache_key = create_cache_key(**params)
cached_value = await cache.get(cache_key)
if cached_value:
return cached_value
# Compute expensive value
value = await compute_expensive_value(params)
await cache.set(cache_key, value, ttl=ttl)
return value
return dependency
@app.get('/cached-endpoint')
async def cached_endpoint(
result = Depends(cached_dependency(ttl=1800))
):
return result
def cached_dependency(ttl: int = 3600):
async def dependency(
params: dict,
cache: CacheLayer = Depends()
):
cache_key = create_cache_key(**params)
cached_value = await cache.get(cache_key)
if cached_value:
return cached_value
# Compute expensive value
value = await compute_expensive_value(params)
await cache.set(cache_key, value, ttl=ttl)
return value
return dependency
@app.get('/cached-endpoint')
async def cached_endpoint(
result = Depends(cached_dependency(ttl=1800))
):
return result
Cache invalidation dependency
Cache invalidation dependency
class CacheInvalidator:
def init(self, cache: CacheLayer = Depends()):
self.cache = cache
self.invalidation_queue = []
def invalidate(self, pattern: str):
self.invalidation_queue.append(pattern)
async def flush(self):
for pattern in self.invalidation_queue:
await self.cache.redis.delete_pattern(pattern)
self.invalidation_queue.clear()
@app.post('/users')
async def create_user(
user: UserCreate,
db = Depends(get_db),
invalidator: CacheInvalidator = Depends()
):
new_user = await db.create_user(user)
invalidator.invalidate('users:*')
await invalidator.flush()
return new_user
class CacheInvalidator:
def init(self, cache: CacheLayer = Depends()):
self.cache = cache
self.invalidation_queue = []
def invalidate(self, pattern: str):
self.invalidation_queue.append(pattern)
async def flush(self):
for pattern in self.invalidation_queue:
await self.cache.redis.delete_pattern(pattern)
self.invalidation_queue.clear()
@app.post('/users')
async def create_user(
user: UserCreate,
db = Depends(get_db),
invalidator: CacheInvalidator = Depends()
):
new_user = await db.create_user(user)
invalidator.invalidate('users:*')
await invalidator.flush()
return new_user
Middleware-Style Dependencies
中间件风格的依赖项
Use dependencies for cross-cutting concerns.
python
from fastapi import Depends, Request
from time import time
使用依赖项处理横切关注点。
python
from fastapi import Depends, Request
from time import time
Request timing dependency
Request timing dependency
async def measure_request_time(request: Request):
start_time = time()
yield
duration = time() - start_time
print(f'{request.method} {request.url.path} took {duration:.3f}s')
async def measure_request_time(request: Request):
start_time = time()
yield
duration = time() - start_time
print(f'{request.method} {request.url.path} took {duration:.3f}s')
Request ID tracking
Request ID tracking
async def track_request_id(request: Request):
request_id = request.headers.get('X-Request-ID', str(uuid.uuid4()))
request.state.request_id = request_id
yield request_id
def get_request_id(request: Request):
return request.state.request_id
async def track_request_id(request: Request):
request_id = request.headers.get('X-Request-ID', str(uuid.uuid4()))
request.state.request_id = request_id
yield request_id
def get_request_id(request: Request):
return request.state.request_id
Rate limiting per user
Rate limiting per user
class RateLimiter:
def init(self):
self.requests = {}
async def check_rate_limit(
self,
user = Depends(get_current_user),
cache = Depends(get_cache)
):
key = f'rate_limit:{user.id}'
count = await cache.incr(key)
if count == 1:
await cache.expire(key, 60)
if count > 100:
raise HTTPException(
status_code=429,
detail='Rate limit exceeded'
)
return True
rate_limiter = RateLimiter()
@app.get('/protected')
async def protected_endpoint(
rate_limit_ok = Depends(rate_limiter.check_rate_limit)
):
return {'message': 'Success'}
class RateLimiter:
def init(self):
self.requests = {}
async def check_rate_limit(
self,
user = Depends(get_current_user),
cache = Depends(get_cache)
):
key = f'rate_limit:{user.id}'
count = await cache.incr(key)
if count == 1:
await cache.expire(key, 60)
if count > 100:
raise HTTPException(
status_code=429,
detail='Rate limit exceeded'
)
return True
rate_limiter = RateLimiter()
@app.get('/protected')
async def protected_endpoint(
rate_limit_ok = Depends(rate_limiter.check_rate_limit)
):
return {'message': 'Success'}
Request validation
Request validation
async def validate_content_type(request: Request):
content_type = request.headers.get('Content-Type')
if not content_type or 'application/json' not in content_type:
raise HTTPException(
status_code=415,
detail='Content-Type must be application/json'
)
return True
@app.post('/data', dependencies=[Depends(validate_content_type)])
async def post_data(data: dict):
return data
async def validate_content_type(request: Request):
content_type = request.headers.get('Content-Type')
if not content_type or 'application/json' not in content_type:
raise HTTPException(
status_code=415,
detail='Content-Type must be application/json'
)
return True
@app.post('/data', dependencies=[Depends(validate_content_type)])
async def post_data(data: dict):
return data