# app/api/dependencies/database.py
"""FastAPI dependencies that expose the asyncpg pool, connections and repositories."""
from typing import AsyncIterator, Callable, Type, TypeVar

from asyncpg import Connection, Pool
from fastapi import Depends
from starlette.requests import Request

from app.db.repositories.base import BaseRepository

# Bound type variable so ``get_repository(UserRepo)`` is typed as producing a
# ``UserRepo`` rather than the erased ``BaseRepository``.
RepoT = TypeVar("RepoT", bound=BaseRepository)


def _get_db_pool(request: Request) -> Pool:
    """Return the connection pool stored on ``app.state.pool``.

    Args:
        request: The incoming request; its ``app.state`` is expected to hold
            the pool under the attribute ``pool``.

    Returns:
        The asyncpg pool.

    Raises:
        RuntimeError: If ``app.state.pool`` is missing or ``None`` (i.e. the
            pool was never initialized), so the failure is explicit rather
            than a confusing ``AttributeError``.
    """
    pool = getattr(request.app.state, "pool", None)
    if pool is None:
        raise RuntimeError("Database pool not initialized on app.state.pool")
    return pool


async def _get_connection_from_pool(
    pool: Pool = Depends(_get_db_pool),
) -> AsyncIterator[Connection]:
    """Borrow a connection from the pool for the duration of the request.

    Private implementation behind ``get_connection``. The connection is
    acquired via ``pool.acquire()`` as an async context manager, so it is
    released back to the pool when the dependency's generator is closed.

    Yields:
        An acquired asyncpg connection.
    """
    async with pool.acquire() as conn:
        yield conn


# Public alias so routes can write ``Depends(get_connection)`` directly.
# It is the very same function object as ``_get_connection_from_pool``, so
# FastAPI treats both spellings as one dependency.
get_connection = _get_connection_from_pool


def get_repository(
    repo_type: Type[RepoT],
) -> Callable[[Connection], RepoT]:
    """Build a dependency that constructs ``repo_type`` around a connection.

    Keeps the legacy style ``Depends(get_repository(UserRepo))`` working.
    The returned dependency itself depends on ``get_connection``, so this
    style and ``Depends(get_connection)`` can coexist in the same app.

    Args:
        repo_type: A ``BaseRepository`` subclass whose constructor accepts a
            single connection argument.

    Returns:
        A dependency callable that returns ``repo_type(conn)``.
    """

    def _get_repo(conn: Connection = Depends(get_connection)) -> RepoT:
        return repo_type(conn)

    return _get_repo