""" 基础仓库类 提供通用的 CRUD 操作封装。 """ from typing import Any, Generic, TypeVar from uuid import uuid4 from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.database import Base ModelT = TypeVar("ModelT", bound=Base) class BaseRepository(Generic[ModelT]): """ 基础仓库类 提供通用的数据库操作方法。 """ model: type[ModelT] def __init__(self, session: AsyncSession): """ 初始化仓库 Args: session: 异步数据库会话 """ self.session = session async def get_by_id(self, id: str) -> ModelT | None: """ 通过 ID 获取实体 Args: id: 实体 ID Returns: 实体对象或 None """ return await self.session.get(self.model, id) async def get_all( self, *, offset: int = 0, limit: int = 100, ) -> list[ModelT]: """ 获取所有实体 Args: offset: 偏移量 limit: 限制数量 Returns: 实体列表 """ stmt = select(self.model).offset(offset).limit(limit) result = await self.session.execute(stmt) return list(result.scalars().all()) async def count(self) -> int: """ 获取实体总数 Returns: 实体数量 """ stmt = select(func.count()).select_from(self.model) result = await self.session.execute(stmt) return result.scalar() or 0 async def create(self, **kwargs: Any) -> ModelT: """ 创建新实体 Args: **kwargs: 实体属性 Returns: 新创建的实体 """ if "id" not in kwargs: kwargs["id"] = str(uuid4()) entity = self.model(**kwargs) self.session.add(entity) await self.session.flush() await self.session.refresh(entity) return entity async def update( self, entity: ModelT, **kwargs: Any, ) -> ModelT: """ 更新实体 Args: entity: 要更新的实体 **kwargs: 要更新的属性 Returns: 更新后的实体 """ for key, value in kwargs.items(): if hasattr(entity, key): setattr(entity, key, value) await self.session.flush() await self.session.refresh(entity) return entity async def delete(self, entity: ModelT) -> None: """ 删除实体 Args: entity: 要删除的实体 """ await self.session.delete(entity) await self.session.flush() async def commit(self) -> None: """提交事务""" await self.session.commit() async def rollback(self) -> None: """回滚事务""" await self.session.rollback()