2025-12-04 10:04:21 +08:00

46 lines
1.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/api/dependencies/database.py
from typing import AsyncIterator, Callable, Type, TypeVar

from asyncpg import Connection, Pool
from fastapi import Depends
from starlette.requests import Request

from app.db.repositories.base import BaseRepository
def _get_db_pool(request: Request) -> Pool:
    """Return the connection pool stored on ``request.app.state.pool``.

    Raises:
        RuntimeError: if ``app.state`` has no ``pool`` attribute or it is
            ``None``, giving a clear message that the pool was never set up.
    """
    state = request.app.state
    db_pool = getattr(state, "pool", None)
    if db_pool is not None:
        return db_pool
    raise RuntimeError("Database pool not initialized on app.state.pool")
async def _get_connection_from_pool(
    pool: Pool = Depends(_get_db_pool),
) -> AsyncIterator[Connection]:
    """Lend one connection from *pool* to the dependent for its lifetime.

    Private implementation: the connection is obtained through
    ``pool.acquire()`` used as an async context manager and yielded; leaving
    the ``async with`` block hands the connection back to the pool.
    Routes should depend on the public alias ``get_connection`` instead.
    """
    acquisition = pool.acquire()
    async with acquisition as connection:
        yield connection


# Public alias so routes can write ``Depends(get_connection)`` directly.
get_connection = _get_connection_from_pool
# Type variable so that ``get_repository(UserRepo)`` is typed as producing a
# ``UserRepo`` (not merely a ``BaseRepository``) for static type checkers.
_RepoT = TypeVar("_RepoT", bound=BaseRepository)


def get_repository(
    repo_type: Type[_RepoT],
) -> Callable[[Connection], _RepoT]:
    """Build a FastAPI dependency that wraps a connection in *repo_type*.

    Kept for backward compatibility with the legacy usage
    ``Depends(get_repository(UserRepo))``. The returned dependency itself
    depends on ``get_connection``, so this style and ``Depends(get_connection)``
    can coexist.

    Args:
        repo_type: Repository class to instantiate; it is called as
            ``repo_type(conn)`` with the injected connection as its only
            argument.

    Returns:
        A dependency callable whose ``conn`` parameter is resolved via
        ``Depends(get_connection)`` and which returns a new ``repo_type``
        instance built around that connection.
    """

    def _get_repo(conn: Connection = Depends(get_connection)) -> _RepoT:
        return repo_type(conn)

    return _get_repo