Loading...
Loading...
Compare original and translation side by side
select()QueryMapped[T]mapped_column()Session.execute()AsyncSessionselect()QueryMapped[T]mapped_column()Session.execute()AsyncSessionundefinedundefinedundefinedundefinedfrom datetime import datetime
from typing import Optional
from sqlalchemy import String, DateTime, ForeignKey, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationshipfrom datetime import datetime
from typing import Optional
from sqlalchemy import String, DateTime, ForeignKey, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship# Primary key
id: Mapped[int] = mapped_column(primary_key=True)
# Required fields
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
username: Mapped[str] = mapped_column(String(50), unique=True)
hashed_password: Mapped[str] = mapped_column(String(255))
# Optional fields
full_name: Mapped[Optional[str]] = mapped_column(String(100))
is_active: Mapped[bool] = mapped_column(default=True)
# Timestamps with server defaults
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now()
)
# Relationships
posts: Mapped[list["Post"]] = relationship(back_populates="author")
def __repr__(self) -> str:
return f"User(id={self.id}, email={self.email})"undefined# 主键
id: Mapped[int] = mapped_column(primary_key=True)
# 必填字段
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
username: Mapped[str] = mapped_column(String(50), unique=True)
hashed_password: Mapped[str] = mapped_column(String(255))
# 可选字段
full_name: Mapped[Optional[str]] = mapped_column(String(100))
is_active: Mapped[bool] = mapped_column(default=True)
# 带服务器默认值的时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now()
)
# 关系映射
posts: Mapped[list["Post"]] = relationship(back_populates="author")
def __repr__(self) -> str:
return f"User(id={self.id}, email={self.email})"undefinedclass Post(Base):
__tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str] = mapped_column(String(200))
content: Mapped[str]
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
# Relationship with back_populates
author: Mapped["User"] = relationship(back_populates="posts")
tags: Mapped[list["Tag"]] = relationship(
secondary="post_tags",
back_populates="posts"
)from sqlalchemy import Table, Column, Integer, ForeignKeyclass Post(Base):
__tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str] = mapped_column(String(200))
content: Mapped[str]
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
# 反向关联关系
author: Mapped["User"] = relationship(back_populates="posts")
tags: Mapped[list["Tag"]] = relationship(
secondary="post_tags",
back_populates="posts"
)from sqlalchemy import Table, Column, Integer, ForeignKeyid: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(50), unique=True)
posts: Mapped[list["Post"]] = relationship(
secondary=post_tags,
back_populates="tags"
)undefinedid: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(50), unique=True)
posts: Mapped[list["Post"]] = relationship(
secondary=post_tags,
back_populates="tags"
)undefinedfrom sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePoolfrom sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePoolundefinedundefinedfrom typing import Generator
def get_db() -> Generator[Session, None, None]:
"""Database session dependency for FastAPI."""
db = SessionLocal()
try:
yield db
finally:
db.close()from typing import Generator
def get_db() -> Generator[Session, None, None]:
"""FastAPI的数据库会话依赖。"""
db = SessionLocal()
try:
yield db
finally:
db.close()undefinedundefinedfrom sqlalchemy import select, and_, or_, desc, funcfrom sqlalchemy import select, and_, or_, desc, funcundefinedundefinedundefinedundefinedundefinedundefinedfrom sqlalchemy.orm import selectinload, joinedloadfrom sqlalchemy.orm import selectinload, joinedloadundefinedundefineddef create_user(db: Session, email: str, username: str, password: str):
"""Create new user."""
user = User(
email=email,
username=username,
hashed_password=hash_password(password)
)
db.add(user)
db.commit()
db.refresh(user) # Get updated fields (id, timestamps)
return userdef create_user(db: Session, email: str, username: str, password: str):
"""创建新用户。"""
user = User(
email=email,
username=username,
hashed_password=hash_password(password)
)
db.add(user)
db.commit()
db.refresh(user) # 获取更新后的字段(如id、时间戳)
return userundefinedundefineddef get_user_by_email(db: Session, email: str) -> Optional[User]:
"""Get user by email."""
stmt = select(User).where(User.email == email)
return db.execute(stmt).scalar_one_or_none()
def get_users(
db: Session,
skip: int = 0,
limit: int = 100
) -> list[User]:
"""Get paginated users."""
stmt = (
select(User)
.where(User.is_active == True)
.order_by(User.created_at.desc())
.offset(skip)
.limit(limit)
)
return db.execute(stmt).scalars().all()def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""通过邮箱查询用户。"""
stmt = select(User).where(User.email == email)
return db.execute(stmt).scalar_one_or_none()
def get_users(
db: Session,
skip: int = 0,
limit: int = 100
) -> list[User]:
"""分页查询用户。"""
stmt = (
select(User)
.where(User.is_active == True)
.order_by(User.created_at.desc())
.offset(skip)
.limit(limit)
)
return db.execute(stmt).scalars().all()def update_user(db: Session, user_id: int, **kwargs):
"""Update user fields."""
stmt = select(User).where(User.id == user_id)
user = db.execute(stmt).scalar_one_or_none()
if not user:
return None
for key, value in kwargs.items():
setattr(user, key, value)
db.commit()
db.refresh(user)
return userdef update_user(db: Session, user_id: int, **kwargs):
"""更新用户字段。"""
stmt = select(User).where(User.id == user_id)
user = db.execute(stmt).scalar_one_or_none()
if not user:
return None
for key, value in kwargs.items():
setattr(user, key, value)
db.commit()
db.refresh(user)
return userundefinedundefineddef delete_user(db: Session, user_id: int) -> bool:
"""Delete user."""
stmt = select(User).where(User.id == user_id)
user = db.execute(stmt).scalar_one_or_none()
if not user:
return False
db.delete(user)
db.commit()
return Truedef delete_user(db: Session, user_id: int) -> bool:
"""删除用户。"""
stmt = select(User).where(User.id == user_id)
user = db.execute(stmt).scalar_one_or_none()
if not user:
return False
db.delete(user)
db.commit()
return Trueundefinedundefinedfrom contextlib import contextmanager
@contextmanager
def get_db_session():
"""Session context manager."""
session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()from contextlib import contextmanager
@contextmanager
def get_db_session():
"""会话上下文管理器。"""
session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()undefinedundefineddef transfer_money(db: Session, from_user_id: int, to_user_id: int, amount: float):
"""Transfer money between users with transaction."""
try:
# Begin nested transaction
with db.begin_nested():
# Deduct from sender
stmt = select(User).where(User.id == from_user_id).with_for_update()
sender = db.execute(stmt).scalar_one()
sender.balance -= amount
# Add to receiver
stmt = select(User).where(User.id == to_user_id).with_for_update()
receiver = db.execute(stmt).scalar_one()
receiver.balance += amount
db.commit()
except Exception as e:
db.rollback()
raisedef transfer_money(db: Session, from_user_id: int, to_user_id: int, amount: float):
"""通过事务实现用户间转账。"""
try:
# 开启嵌套事务
with db.begin_nested():
# 从转出方扣除金额
stmt = select(User).where(User.id == from_user_id).with_for_update()
sender = db.execute(stmt).scalar_one()
sender.balance -= amount
# 给转入方增加金额
stmt = select(User).where(User.id == to_user_id).with_for_update()
receiver = db.execute(stmt).scalar_one()
receiver.balance += amount
db.commit()
except Exception as e:
db.rollback()
raisefrom sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_sessionmaker
)from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_sessionmaker
)undefinedundefinedasync def get_user_async(user_id: int) -> Optional[User]:
"""Get user asynchronously."""
async with AsyncSessionLocal() as session:
stmt = select(User).where(User.id == user_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def create_user_async(email: str, username: str) -> User:
"""Create user asynchronously."""
async with AsyncSessionLocal() as session:
user = User(email=email, username=username)
session.add(user)
await session.commit()
await session.refresh(user)
return userasync def get_user_async(user_id: int) -> Optional[User]:
"""异步查询用户。"""
async with AsyncSessionLocal() as session:
stmt = select(User).where(User.id == user_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def create_user_async(email: str, username: str) -> User:
"""异步创建用户。"""
async with AsyncSessionLocal() as session:
user = User(email=email, username=username)
session.add(user)
await session.commit()
await session.refresh(user)
return userundefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedconnectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()undefinedconnectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()undefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedfrom fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel, EmailStr
from typing import List
app = FastAPI()from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel, EmailStr
from typing import List
app = FastAPI()class Config:
from_attributes = True # SQLAlchemy 2.0 (was orm_mode)class Config:
from_attributes = True # SQLAlchemy 2.0版本(原orm_mode)# Create user
db_user = User(
email=user.email,
username=user.username,
hashed_password=hash_password(user.password)
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_userif not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return userif not db_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
db_user.email = user_update.email
db_user.username = user_update.username
db.commit()
db.refresh(db_user)
return db_userif not db_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
db.delete(db_user)
db.commit()undefined# 创建用户
db_user = User(
email=user.email,
username=user.username,
hashed_password=hash_password(user.password)
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_userif not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return userif not db_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
db_user.email = user_update.email
db_user.username = user_update.username
db.commit()
db.refresh(db_user)
return db_userif not db_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
db.delete(db_user)
db.commit()undefinedimport pytest
from sqlalchemy import create_engine, StaticPool
from sqlalchemy.orm import sessionmakerimport pytest
from sqlalchemy import create_engine, StaticPool
from sqlalchemy.orm import sessionmaker# Create tables
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine
)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)undefined# 创建数据表
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine
)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)undefineddef test_create_user(db_session):
"""Test user creation."""
user = User(email="new@example.com", username="newuser")
db_session.add(user)
db_session.commit()
assert user.id is not None
assert user.email == "new@example.com"
assert user.created_at is not None
def test_query_user(db_session, test_user):
"""Test user query."""
stmt = select(User).where(User.email == "test@example.com")
found_user = db_session.execute(stmt).scalar_one()
assert found_user.id == test_user.id
assert found_user.username == test_user.username
def test_update_user(db_session, test_user):
"""Test user update."""
test_user.username = "updated"
db_session.commit()
stmt = select(User).where(User.id == test_user.id)
updated_user = db_session.execute(stmt).scalar_one()
assert updated_user.username == "updated"
def test_delete_user(db_session, test_user):
"""Test user deletion."""
user_id = test_user.id
db_session.delete(test_user)
db_session.commit()
stmt = select(User).where(User.id == user_id)
assert db_session.execute(stmt).scalar_one_or_none() is Nonedef test_create_user(db_session):
"""测试用户创建。"""
user = User(email="new@example.com", username="newuser")
db_session.add(user)
db_session.commit()
assert user.id is not None
assert user.email == "new@example.com"
assert user.created_at is not None
def test_query_user(db_session, test_user):
"""测试用户查询。"""
stmt = select(User).where(User.email == "test@example.com")
found_user = db_session.execute(stmt).scalar_one()
assert found_user.id == test_user.id
assert found_user.username == test_user.username
def test_update_user(db_session, test_user):
"""测试用户更新。"""
test_user.username = "updated"
db_session.commit()
stmt = select(User).where(User.id == test_user.id)
updated_user = db_session.execute(stmt).scalar_one()
assert updated_user.username == "updated"
def test_delete_user(db_session, test_user):
"""测试用户删除。"""
user_id = test_user.id
db_session.delete(test_user)
db_session.commit()
stmt = select(User).where(User.id == user_id)
assert db_session.execute(stmt).scalar_one_or_none() is Noneundefinedundefinedemail: Mapped[str] = mapped_column(String(255), index=True, unique=True)
created_at: Mapped[datetime] = mapped_column(index=True)
# Composite index
__table_args__ = (
Index('ix_user_email_active', 'email', 'is_active'),
)email: Mapped[str] = mapped_column(String(255), index=True, unique=True)
created_at: Mapped[datetime] = mapped_column(index=True)
# 复合索引
__table_args__ = (
Index('ix_user_email_active', 'email', 'is_active'),
)undefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedundefinedMapped[T]mapped_column()selectinloadjoinedloadAsyncSessionNoResultFoundMultipleResultsFoundMapped[T]mapped_column()selectinloadjoinedloadAsyncSessionNoResultFoundMultipleResultsFoundfrom typing import Generic, TypeVar, Type
from sqlalchemy.orm import Session
T = TypeVar('T', bound=Base)
class BaseRepository(Generic[T]):
def __init__(self, model: Type[T], db: Session):
self.model = model
self.db = db
def get(self, id: int) -> Optional[T]:
stmt = select(self.model).where(self.model.id == id)
return self.db.execute(stmt).scalar_one_or_none()
def get_all(self, skip: int = 0, limit: int = 100) -> list[T]:
stmt = select(self.model).offset(skip).limit(limit)
return self.db.execute(stmt).scalars().all()
def create(self, obj: T) -> T:
self.db.add(obj)
self.db.commit()
self.db.refresh(obj)
return obj
def delete(self, id: int) -> bool:
obj = self.get(id)
if obj:
self.db.delete(obj)
self.db.commit()
return True
return Falsefrom typing import Generic, TypeVar, Type
from sqlalchemy.orm import Session
T = TypeVar('T', bound=Base)
class BaseRepository(Generic[T]):
def __init__(self, model: Type[T], db: Session):
self.model = model
self.db = db
def get(self, id: int) -> Optional[T]:
stmt = select(self.model).where(self.model.id == id)
return self.db.execute(stmt).scalar_one_or_none()
def get_all(self, skip: int = 0, limit: int = 100) -> list[T]:
stmt = select(self.model).offset(skip).limit(limit)
return self.db.execute(stmt).scalars().all()
def create(self, obj: T) -> T:
self.db.add(obj)
self.db.commit()
self.db.refresh(obj)
return obj
def delete(self, id: int) -> bool:
obj = self.get(id)
if obj:
self.db.delete(obj)
self.db.commit()
return True
return Falseundefinedundefinedclass SoftDeleteMixin:
deleted_at: Mapped[Optional[datetime]] = mapped_column(default=None)
@property
def is_deleted(self) -> bool:
return self.deleted_at is not None
class User(Base, SoftDeleteMixin):
__tablename__ = "users"
# ... fieldsclass SoftDeleteMixin:
deleted_at: Mapped[Optional[datetime]] = mapped_column(default=None)
@property
def is_deleted(self) -> bool:
return self.deleted_at is not None
class User(Base, SoftDeleteMixin):
__tablename__ = "users"
# ... 字段定义undefinedundefinedclass AuditMixin:
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now()
)
created_by: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id"))
updated_by: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id"))
class Post(Base, AuditMixin):
__tablename__ = "posts"
# ... fieldsclass AuditMixin:
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now()
)
created_by: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id"))
updated_by: Mapped[Optional[int]] = mapped_column(ForeignKey("users.id"))
class Post(Base, AuditMixin):
__tablename__ = "posts"
# ... 字段定义