2025-12-10 12:02:17 +08:00

505 lines
19 KiB
Python

import asyncio
import logging
from datetime import datetime
from io import BytesIO
from fastapi import APIRouter, Body, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func, select
from sqlalchemy.orm import selectinload
from fastapi.responses import StreamingResponse
from backend.api.deps import AuthUser, get_current_user
from backend.core.security import decode_token
from backend.db.session import get_session
from backend.modules.customers.models import Customer
from backend.modules.instances.events import instance_event_manager
from backend.modules.instances.models import Instance
from backend.modules.instances.schemas import (
InstanceCreateRequest,
InstanceCreateResponse,
InstanceFilterParams,
InstanceListResponse,
InstanceOut,
InstanceSyncRequest,
BatchInstancesActionIn,
BatchInstancesActionOut,
BatchInstancesByIpIn,
BatchInstancesByIpOut,
InstanceIdsExportIn,
)
from backend.modules.instances.service import (
enqueue_action,
ensure_credential_access,
list_instances,
create_instance,
sync_instances,
batch_instances_action,
batch_instances_by_ips,
export_instances,
export_instances_by_ids,
)
from backend.modules.jobs.models import JobItemAction
from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType
from backend.modules.instances.constants import AWS_REGIONS
from backend.modules.instances import aws_ops
from backend.modules.users.models import RoleName, User
logger = logging.getLogger(__name__)
# Every route below is mounted under this prefix and grouped under the "instances" tag.
router = APIRouter(
    prefix="/api/v1/instances",
    tags=["instances"],
)
async def _auth_websocket(websocket: WebSocket, session: AsyncSession) -> AuthUser | None:
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    The token is decoded with ``decode_token``; its ``sub`` claim is the numeric user id,
    which is loaded together with the user's role and customer. On any failure the socket
    is closed and ``None`` is returned, so the caller must stop handling the connection:

    * 4401 -- token missing, undecodable, or with a missing/non-numeric ``sub``;
    * 4403 -- user not found or inactive.

    Returns:
        An ``AuthUser`` for the connection, or ``None`` when authentication failed (the
        socket has already been closed in that case).
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4401, reason="Missing token")
        return None
    try:
        payload = decode_token(token)
    except ValueError:
        await websocket.close(code=4401, reason="Invalid token")
        return None
    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=4401, reason="Invalid token payload")
        return None
    # A non-numeric ``sub`` used to escape as an unhandled ValueError/TypeError; treat it
    # as a malformed payload so the client gets the same close code as other bad tokens.
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        await websocket.close(code=4401, reason="Invalid token payload")
        return None
    user = await session.scalar(
        select(User)
        .where(User.id == user_pk)
        .options(selectinload(User.role), selectinload(User.customer))
    )
    if not user or not user.is_active:
        await websocket.close(code=4403, reason="User disabled or not found")
        return None
    # The token's "role" claim wins; otherwise use the DB role, defaulting to CUSTOMER_USER.
    # NOTE(review): a role changed after the token was issued is not reflected until the
    # token is reissued -- confirm this matches the HTTP auth dependency.
    role_name = payload.get("role") or (user.role.name if user.role else RoleName.CUSTOMER_USER.value)
    return AuthUser(user=user, role_name=role_name, customer_id=user.customer_id, token=token)
@router.websocket("/ws")
async def instance_events(
    websocket: WebSocket,
    session: AsyncSession = Depends(get_session),
):
    """Register an authenticated WebSocket with the instance event manager.

    After authentication (see ``_auth_websocket``) the socket is handed to
    ``instance_event_manager.connect``. Incoming text frames are read and discarded; the
    read loop exists only to notice when the connection ends. However the loop ends --
    client disconnect, unexpected error, or task cancellation -- the socket is
    unregistered from the manager.
    """
    auth_user = await _auth_websocket(websocket, session)
    if not auth_user:
        return
    # NOTE(review): the injected session stays open for the whole connection; if it still
    # holds a pooled DB connection after the auth query, long-lived sockets may pin
    # connections -- confirm against get_session.
    await instance_event_manager.connect(websocket, auth_user)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception:
        # Previously swallowed silently; keep the connection-closing behavior but record why.
        logger.exception("Instance event websocket terminated unexpectedly")
    finally:
        # ``finally`` also runs on task cancellation (e.g. server shutdown), which
        # ``except Exception`` does not catch, so the manager never keeps a dead socket.
        await instance_event_manager.disconnect(websocket)
@router.get("", response_model=InstanceListResponse)
async def list_instances_endpoint(
    filters: InstanceFilterParams = Depends(),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceListResponse:
    """Return the instances matching ``filters`` together with the total match count.

    Only filters that were actually supplied (non-``None``) are forwarded. The current user
    is passed to ``list_instances``; any per-user scoping happens there.
    """
    supplied_filters = filters.model_dump(exclude_none=True)
    rows, total = await list_instances(session, supplied_filters, auth_user.user)
    serialized = list(map(InstanceOut.model_validate, rows))
    return InstanceListResponse(items=serialized, total=total)
@router.post("/create", response_model=InstanceCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_instance_endpoint(
    payload: InstanceCreateRequest,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> InstanceCreateResponse:
    """Create an instance on behalf of the current user.

    ``HTTPException``s raised by ``create_instance`` propagate unchanged. Any other error
    is logged and re-raised as a 500 whose detail is the exception text, chained to the
    original exception so tracebacks keep the root cause.

    Returns:
        The serialized instance merged with the optional ``login_username`` and
        ``login_password`` values from the service result.
    """
    try:
        result = await create_instance(session, payload.model_dump(), auth_user.user)
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover
        logger.exception("create_instance failed")
        # NOTE(review): ``str(exc)`` is sent to the client verbatim and may expose internal
        # or cloud-provider error details -- confirm this is intended.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc
    instance_fields = InstanceOut.model_validate(result["instance"]).model_dump()
    return InstanceCreateResponse(
        **instance_fields,
        login_username=result.get("login_username"),
        login_password=result.get("login_password"),
    )
@router.post("/batch/action", response_model=BatchInstancesActionOut)
async def batch_instances_action_endpoint(
    payload: BatchInstancesActionIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> BatchInstancesActionOut:
    """Apply one action to a batch of instances by delegating to ``batch_instances_action``."""
    acting_user = auth_user.user
    outcome = await batch_instances_action(session, payload, acting_user)
    return outcome
@router.post("/batch/by-ips", response_model=BatchInstancesByIpOut)
async def batch_instances_by_ips_endpoint(
    payload: BatchInstancesByIpIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
) -> BatchInstancesByIpOut:
    """Run an IP-addressed batch request and wrap the service's mapping result."""
    acting_user = auth_user.user
    summary = await batch_instances_by_ips(session, payload, acting_user)
    response = BatchInstancesByIpOut(**summary)
    return response
# Audit action recorded for each single-instance lifecycle job action.
_AUDIT_ACTION_BY_JOB_ACTION = {
    JobItemAction.START: AuditAction.INSTANCE_START,
    JobItemAction.STOP: AuditAction.INSTANCE_STOP,
    JobItemAction.REBOOT: AuditAction.INSTANCE_REBOOT,
    JobItemAction.TERMINATE: AuditAction.INSTANCE_TERMINATE,
}


async def _action_endpoint(
    instance_id: int,
    action: JobItemAction,
    session: AsyncSession,
    auth_user: AuthUser,
) -> dict:
    """Enqueue a lifecycle action for one instance and record an audit log entry.

    Shared implementation of the start/stop/reboot/terminate routes. Both the enqueued job
    and the audit entry are committed together at the end.

    Raises:
        HTTPException: 404 if the instance does not exist, 403 if a non-admin targets an
            instance of another customer, 400 if the instance has no credential.
            ``ensure_credential_access`` may raise its own errors.
        KeyError: if ``action`` has no audit mapping (programming error). Raised before
            any job is enqueued.

    Returns:
        ``{"job_uuid": ..., "job_id": ...}`` for the enqueued job.
    """
    instance = await session.get(Instance, instance_id)
    if not instance:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Instance not found")
    # NOTE(review): a non-admin whose customer_id is None passes this check for instances
    # whose customer_id is also None -- confirm that is intended.
    if auth_user.role_name != "ADMIN" and instance.customer_id != auth_user.customer_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
    if not instance.credential_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instance missing credential")
    await ensure_credential_access(session, instance.credential_id, auth_user.user)
    # Resolve the audit action before any side effect: the lookup used to run after
    # enqueue_action, so an unmapped action would enqueue a job and then crash unaudited.
    audit_action = _AUDIT_ACTION_BY_JOB_ACTION[action]
    job = await enqueue_action(session, instance, action, auth_user.user)
    session.add(
        AuditLog(
            user_id=auth_user.user.id,
            customer_id=instance.customer_id,
            action=audit_action,
            resource_type=AuditResourceType.INSTANCE,
            resource_id=instance.id,
            description=f"{action.value} instance {instance.instance_id}",
        )
    )
    await session.commit()
    return {"job_uuid": job.job_uuid, "job_id": job.id}
@router.post("/{instance_id}/start")
async def start_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a START job for a single instance (see ``_action_endpoint``)."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.START,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/stop")
async def stop_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a STOP job for a single instance (see ``_action_endpoint``)."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.STOP,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/reboot")
async def reboot_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a REBOOT job for a single instance (see ``_action_endpoint``)."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.REBOOT,
        session=session,
        auth_user=auth_user,
    )
@router.post("/{instance_id}/terminate")
async def terminate_instance(
    instance_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue a TERMINATE job for a single instance (see ``_action_endpoint``)."""
    return await _action_endpoint(
        instance_id=instance_id,
        action=JobItemAction.TERMINATE,
        session=session,
        auth_user=auth_user,
    )
@router.post("/sync")
async def sync_instances_endpoint(
    payload: InstanceSyncRequest,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Enqueue an instance sync job for the requested credential and region.

    Returns the job's UUID, database id and ``total_count`` (as ``total``).
    """
    job = await sync_instances(
        session,
        payload.credential_id,
        payload.region,
        auth_user.user,
        payload.customer_id,
    )
    return dict(job_uuid=job.job_uuid, job_id=job.id, total=job.total_count)
@router.get("/quota/available")
async def available_instance_quota(
    customer_id: int | None = None,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Report the instance quota, current usage and remaining headroom.

    Admins may pass ``customer_id``; when it is omitted the admin's own customer is used.
    Non-admins always get their own customer. When no customer applies, usage covers all
    instances and the quota is the "unlimited" sentinel 999999 (also used for customers
    without ``quota_instances``).

    An instance is "in use" unless its ``status`` or ``desired_status`` is "TERMINATED".

    Raises:
        HTTPException: 404 if the target customer does not exist.
    """
    unlimited = 999999  # sentinel: effectively no quota
    # admin can omit customer_id and fall back to global view
    if auth_user.role_name == "ADMIN":
        target_customer_id = customer_id or auth_user.customer_id
    else:
        # NOTE(review): a non-admin without a customer_id falls through to the global view.
        target_customer_id = auth_user.customer_id

    # IS DISTINCT FROM keeps rows whose status columns are NULL. A plain ``!=`` evaluates to
    # NULL for them, silently dropping them from the count and inflating "available".
    conditions = [
        Instance.status.is_distinct_from("TERMINATED"),
        Instance.desired_status.is_distinct_from("TERMINATED"),
    ]
    if target_customer_id is None:
        quota = unlimited
    else:
        customer = await session.get(Customer, target_customer_id)
        if not customer:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
        quota = customer.quota_instances if customer.quota_instances is not None else unlimited
        conditions.append(Instance.customer_id == target_customer_id)

    in_use = await session.scalar(select(func.count(Instance.id)).where(*conditions)) or 0
    return {"quota": quota, "in_use": in_use, "available": max(0, quota - in_use)}
@router.get("/quota/capacity")
async def capacity_by_region(
    credential_id: int,
    region: str,
    customer_id: int | None = None,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Estimate how many more instances can be launched in ``region``.

    Combines the platform quota for the target customer (chosen as in
    ``/quota/available``) with a best-effort AWS service-quota lookup for the credential.
    In-use instances are those in ``region`` whose ``status`` or ``desired_status`` is not
    "TERMINATED". If the AWS lookup succeeds, ``available`` is capped by it. If it fails,
    the failure is logged and ``aws_available`` is ``None``.

    Raises:
        HTTPException: 404 if the target customer does not exist. Credential access
            errors come from ``ensure_credential_access``.
    """
    # permission + credential check
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    unlimited = 999999  # sentinel: effectively no platform quota
    if auth_user.role_name == "ADMIN":
        target_customer_id = customer_id or auth_user.customer_id
    else:
        target_customer_id = auth_user.customer_id

    # IS DISTINCT FROM keeps rows whose status columns are NULL; ``!=`` would drop them.
    conditions = [
        Instance.status.is_distinct_from("TERMINATED"),
        Instance.desired_status.is_distinct_from("TERMINATED"),
        Instance.region == region,
    ]
    if target_customer_id is None:
        # admin without customer scope: treat as global (only constrained by AWS quota)
        quota = unlimited
    else:
        customer = await session.get(Customer, target_customer_id)
        if not customer:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
        quota = customer.quota_instances if customer.quota_instances is not None else unlimited
        conditions.append(Instance.customer_id == target_customer_id)

    in_use = await session.scalar(select(func.count(Instance.id)).where(*conditions)) or 0
    available = max(0, quota - in_use)

    # Best-effort AWS regional quota ("L-1216C47A", on-demand instances per region).
    # NOTE(review): per AWS docs this quota is measured in vCPUs, while ``in_use`` counts
    # instances known to this platform (optionally per customer) -- the subtraction mixes
    # units; confirm the intended capacity model.
    aws_available = None
    try:
        resp = await asyncio.to_thread(aws_ops.get_service_quota, cred, region, "ec2", "L-1216C47A")
        quota_value = (resp or {}).get("Quota", {}).get("Value")
        if quota_value is not None:
            aws_available = max(0, int(quota_value) - in_use)
    except Exception as exc:
        # Still best-effort, but no longer silent.
        logger.warning(
            "EC2 service quota lookup failed (credential_id=%s, region=%s): %s",
            credential_id,
            region,
            exc,
        )
    if aws_available is not None:
        available = min(available, aws_available)

    return {
        "quota": quota,
        "in_use_region": in_use,
        "available": available,
        "aws_available": aws_available,
        "region": region,
        "credential_id": credential_id,
    }
@router.get("/export")
async def export_instances_endpoint(
    filters: InstanceFilterParams = Depends(),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Download the instances matching ``filters`` as an .xlsx attachment.

    Pagination (``offset``/``limit``) is dropped so the export covers every match. The
    download name embeds the credential id and region filters (``all`` when absent) and a
    local-time stamp. Both are reduced to ASCII letters, digits, ``-`` and ``_``.
    """
    payload = filters.model_dump(exclude_none=True)
    payload.pop("offset", None)
    payload.pop("limit", None)
    content = await export_instances(session, payload, auth_user.user)

    def filename_part(value: object) -> str:
        # Filter values come from the query string. Unsanitized, a '"' breaks the quoted
        # Content-Disposition value and non-latin-1 characters fail header encoding.
        return "".join(
            ch if ch.isascii() and (ch.isalnum() or ch in "-_") else "_" for ch in str(value)
        )

    credential = filename_part(payload.get("credential_id") or "all")
    region = filename_part(payload.get("region") or "all")
    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    filename = f"instances_{credential}_{region}_{stamp}.xlsx"
    return StreamingResponse(
        # Send the workbook as one chunk. Iterating a BytesIO yields "lines" split at
        # b"\n", i.e. many tiny chunks of binary data, each pulled via the threadpool.
        iter([bytes(content)]),
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
@router.post("/export-by-ids")
async def export_by_ids_endpoint(
    payload: InstanceIdsExportIn,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Download the explicitly selected instances as an .xlsx attachment.

    The download name is ``instances_selected_<YYYYmmdd_HHMM>.xlsx`` (local time).
    """
    content = await export_instances_by_ids(session, payload.instance_ids, auth_user.user)
    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    filename = f"instances_selected_{stamp}.xlsx"
    return StreamingResponse(
        # One chunk: iterating a BytesIO would split the binary workbook at every b"\n".
        iter([bytes(content)]),
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
@router.get("/meta/aws/regions")
async def list_regions(
    credential_id: int,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List the AWS regions available to a credential, with English/Chinese labels.

    Regions come from ``aws_ops.describe_regions``, labelled from ``AWS_REGIONS`` (unknown
    regions use their id as the English label). If the AWS call fails or returns no
    regions, the static ``AWS_REGIONS`` table is returned instead. Failures are logged.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    try:
        resp = await asyncio.to_thread(aws_ops.describe_regions, cred)
        items = []
        for r in resp.get("Regions", []):
            name = r.get("RegionName")
            if not name:
                continue
            meta = AWS_REGIONS.get(name, {"en": name, "zh": ""})
            items.append({"id": name, "label_en": meta["en"], "label_zh": meta["zh"]})
        if items:
            return items
    except Exception as exc:
        # Fallback is deliberate, but record why the live lookup failed instead of
        # swallowing it silently.
        logger.warning(
            "describe_regions failed for credential_id=%s; using static region list: %s",
            credential_id,
            exc,
        )
    return [
        {"id": region, "label_en": meta["en"], "label_zh": meta["zh"]}
        for region, meta in AWS_REGIONS.items()
    ]
@router.get("/meta/aws/network")
async def aws_network(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Return the VPCs, subnets and security groups of a credential in one region.

    An empty ``region`` falls back to the credential's ``default_region``. Each resource is
    reduced to a small dict; ``name`` is the value of its "Name" tag, if any.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    target_region = region if region else cred.default_region

    def name_tag(resource: dict):
        # Value of the resource's "Name" tag, or None when it has none.
        return next(
            (tag["Value"] for tag in resource.get("Tags", []) if tag.get("Key") == "Name"),
            None,
        )

    vpc_resp = await asyncio.to_thread(aws_ops.describe_vpcs, cred, target_region)
    subnet_resp = await asyncio.to_thread(aws_ops.describe_subnets, cred, target_region)
    group_resp = await asyncio.to_thread(aws_ops.describe_security_groups, cred, target_region)

    vpc_list = [
        {"vpc_id": v.get("VpcId"), "name": name_tag(v), "cidr": v.get("CidrBlock")}
        for v in vpc_resp.get("Vpcs", [])
    ]
    subnet_list = [
        {
            "subnet_id": s.get("SubnetId"),
            "vpc_id": s.get("VpcId"),
            "az": s.get("AvailabilityZone"),
            "cidr": s.get("CidrBlock"),
            "name": name_tag(s),
        }
        for s in subnet_resp.get("Subnets", [])
    ]
    group_list = [
        {
            "group_id": g.get("GroupId"),
            "name": g.get("GroupName"),
            "desc": g.get("Description"),
            "vpc_id": g.get("VpcId"),
        }
        for g in group_resp.get("SecurityGroups", [])
    ]
    return {"vpcs": vpc_list, "subnets": subnet_list, "security_groups": group_list}
@router.get("/meta/aws/keypairs")
async def aws_keypairs(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List the EC2 key pairs of a credential in a region (name, id, fingerprint).

    An empty ``region`` falls back to the credential's ``default_region``.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    effective_region = region if region else cred.default_region
    described = await asyncio.to_thread(aws_ops.describe_key_pairs, cred, effective_region)
    key_pairs = described.get("KeyPairs", [])
    return [
        dict(
            key_name=pair.get("KeyName"),
            key_pair_id=pair.get("KeyPairId"),
            fingerprint=pair.get("KeyFingerprint"),
        )
        for pair in key_pairs
    ]
@router.post("/meta/aws/keypairs", status_code=status.HTTP_201_CREATED)
async def aws_create_keypair(
    credential_id: int = Body(...),
    region: str = Body(...),
    key_name: str = Body(...),
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """Create an EC2 key pair and return its identifiers plus the ``KeyMaterial``.

    An empty ``region`` falls back to the credential's ``default_region``. The response
    includes ``material`` (the ``KeyMaterial`` field of the AWS response) so the caller can
    store it. NOTE(review): unlike the per-instance actions, no AuditLog entry is written here.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    effective_region = region if region else cred.default_region
    created = await asyncio.to_thread(aws_ops.create_key_pair, cred, effective_region, key_name)
    return dict(
        key_name=created.get("KeyName"),
        key_pair_id=created.get("KeyPairId"),
        fingerprint=created.get("KeyFingerprint"),
        material=created.get("KeyMaterial"),
    )
# Instance types non-admin users may see. Single source of truth for both the AWS-side
# filter and the local post-filter below, so the two can no longer drift apart.
_CUSTOMER_INSTANCE_TYPES = ("t3.micro", "t3.small", "t3.medium")


@router.get("/meta/aws/instance-types")
async def aws_instance_types(
    credential_id: int,
    region: str,
    session: AsyncSession = Depends(get_session),
    auth_user: AuthUser = Depends(get_current_user),
):
    """List EC2 instance types with vCPU, memory and network performance.

    Admins see every type returned by ``aws_ops.describe_instance_types``. Other users are
    limited to ``_CUSTOMER_INSTANCE_TYPES``, enforced both in the AWS request filter and
    again locally. An empty ``region`` falls back to the credential's ``default_region``.
    """
    cred = await ensure_credential_access(session, credential_id, auth_user.user)
    region_use = region or cred.default_region
    is_admin = auth_user.role_name == "ADMIN"
    filters = None if is_admin else [{"Name": "instance-type", "Values": list(_CUSTOMER_INSTANCE_TYPES)}]
    resp = await asyncio.to_thread(aws_ops.describe_instance_types, cred, region_use, filters)
    allowed_customer = frozenset(_CUSTOMER_INSTANCE_TYPES)
    items = []
    for it in resp:
        itype = it.get("InstanceType")
        if not itype:
            continue
        # Defensive re-check in case the AWS-side filter is not applied.
        if not is_admin and itype not in allowed_customer:
            continue
        items.append(
            {
                "instance_type": itype,
                "vcpu": (it.get("VCpuInfo") or {}).get("DefaultVCpus"),
                "memory_mib": (it.get("MemoryInfo") or {}).get("SizeInMiB"),
                "network_performance": (it.get("NetworkInfo") or {}).get("NetworkPerformance"),
            }
        )
    return items