from __future__ import annotations import asyncio from datetime import datetime, timezone from typing import Dict, List, Optional from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy import and_, func, or_, select from sqlalchemy.dialects.mysql import insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from .. import aws_ops from ..db import SessionLocal, get_session from ..dependencies import AuthUser, get_current_user from ..models import ( AWSCredential, AuditAction, AuditResourceType, CustomerCredential, Instance, InstanceDesiredStatus, InstanceStatus, Job, JobItem, JobItemAction, JobItemResourceType, JobItemStatus, JobStatus, JobType, ) from ..schemas import ( InstanceCreateRequest, InstanceFilterParams, InstanceListResponse, InstanceOut, InstanceSyncRequest, JobOut, ) from ..utils.audit import create_audit_log router = APIRouter(prefix="/api/v1/instances", tags=["instances"]) STATE_MAP: Dict[str, InstanceStatus] = { "pending": InstanceStatus.PENDING, "running": InstanceStatus.RUNNING, "stopping": InstanceStatus.STOPPING, "stopped": InstanceStatus.STOPPED, "shutting-down": InstanceStatus.SHUTTING_DOWN, "terminated": InstanceStatus.TERMINATED, } JOB_TYPE_BY_ACTION = { JobItemAction.START: JobType.START_INSTANCES, JobItemAction.STOP: JobType.STOP_INSTANCES, JobItemAction.REBOOT: JobType.REBOOT_INSTANCES, JobItemAction.TERMINATE: JobType.TERMINATE_INSTANCES, } DESIRED_BY_ACTION = { JobItemAction.START: InstanceDesiredStatus.RUNNING, JobItemAction.STOP: InstanceDesiredStatus.STOPPED, JobItemAction.REBOOT: InstanceDesiredStatus.RUNNING, JobItemAction.TERMINATE: InstanceDesiredStatus.TERMINATED, } def _map_state(state: Optional[str]) -> InstanceStatus: if not state: return InstanceStatus.UNKNOWN return STATE_MAP.get(state.lower(), InstanceStatus.UNKNOWN) def _extract_name_tag(tags: Optional[List[Dict[str, str]]]) -> Optional[str]: if not tags: return None for tag in tags: if tag.get("Key") == "Name": return tag.get("Value") return None async def _ensure_credential_access( credential_id: int, auth_user: AuthUser, session: AsyncSession ) -> AWSCredential: cred = await session.get(AWSCredential, credential_id) if not cred or not cred.is_active: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found or disabled") if auth_user.role_name == "ADMIN": return cred mapping = await session.scalar( select(CustomerCredential).where( and_( CustomerCredential.customer_id == auth_user.customer_id, CustomerCredential.credential_id == credential_id, CustomerCredential.is_allowed == 1, ) ) ) if not mapping: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Credential not allowed") return cred @router.get("", response_model=InstanceListResponse) async def list_instances( filters: InstanceFilterParams = Depends(), session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> InstanceListResponse: query = select(Instance) if auth_user.role_name != "ADMIN": query = query.where(Instance.customer_id == auth_user.customer_id) elif filters.customer_id: query = query.where(Instance.customer_id == filters.customer_id) if filters.credential_id: query = query.where(Instance.credential_id == filters.credential_id) if filters.account_id: query = query.where(Instance.account_id == filters.account_id) if filters.region: query = query.where(Instance.region == filters.region) if filters.status: query = query.where(Instance.status == filters.status) if filters.keyword: pattern = f"%{filters.keyword}%" query = query.where( or_( Instance.name_tag.ilike(pattern), Instance.instance_id.ilike(pattern), Instance.public_ip.ilike(pattern), Instance.private_ip.ilike(pattern), ) ) total = await session.scalar(select(func.count()).select_from(query.subquery())) instances = ( await session.scalars(query.order_by(Instance.updated_at.desc()).offset(filters.offset).limit(filters.limit)) ).all() return InstanceListResponse(items=[InstanceOut.model_validate(i) for i in instances], total=total or 0) @router.post("/create", response_model=InstanceOut, status_code=status.HTTP_201_CREATED) async def create_instance( payload: InstanceCreateRequest, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> InstanceOut: customer_id = payload.customer_id or auth_user.customer_id if not customer_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required") cred = await _ensure_credential_access(payload.credential_id, auth_user, session) if cred.account_id != payload.account_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="account_id mismatch") region = payload.region or cred.default_region try: resp = await asyncio.to_thread( aws_ops.run_instances, cred, region, payload.ami_id, payload.instance_type, payload.key_name, payload.security_groups, payload.subnet_id, 1, 1, payload.name_tag, ) except Exception as exc: # pragma: no cover - AWS failure path raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc instances = resp.get("Instances") or [] if not instances: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AWS did not return instances") data = instances[0] instance_id = data.get("InstanceId") name_tag = payload.name_tag or _extract_name_tag(data.get("Tags")) security_groups = [sg.get("GroupId") for sg in data.get("SecurityGroups", []) if sg.get("GroupId")] db_instance = Instance( customer_id=customer_id, credential_id=cred.id, account_id=cred.account_id, region=region, az=(data.get("Placement") or {}).get("AvailabilityZone"), instance_id=instance_id, name_tag=name_tag, instance_type=payload.instance_type, ami_id=payload.ami_id, key_name=payload.key_name, public_ip=data.get("PublicIpAddress"), private_ip=data.get("PrivateIpAddress"), status=_map_state((data.get("State") or {}).get("Name")), desired_status=InstanceDesiredStatus.RUNNING, security_groups=security_groups, subnet_id=data.get("SubnetId"), vpc_id=data.get("VpcId"), launched_at=data.get("LaunchTime"), last_sync=datetime.now(timezone.utc), last_cloud_state={"state": data.get("State"), "tags": data.get("Tags")}, ) session.add(db_instance) await session.commit() await session.refresh(db_instance) await create_audit_log( session, user_id=auth_user.user.id, customer_id=customer_id, action=AuditAction.INSTANCE_CREATE, resource_type=AuditResourceType.INSTANCE, resource_id=db_instance.id, description=f"Create instance {db_instance.instance_id}", payload=payload.model_dump(), request=request, ) await session.commit() return InstanceOut.model_validate(db_instance) async def _enqueue_instance_action( instance: Instance, action: JobItemAction, auth_user: AuthUser, session: AsyncSession ) -> Job: job = Job( job_uuid=uuid4().hex, job_type=JOB_TYPE_BY_ACTION[action], status=JobStatus.PENDING, progress=0, total_count=1, created_by_user_id=auth_user.user.id, created_for_customer=instance.customer_id, ) session.add(job) await session.flush() job_item = JobItem( job_id=job.id, resource_type=JobItemResourceType.INSTANCE, resource_id=instance.id, account_id=instance.account_id, region=instance.region, instance_id=instance.instance_id, action=action, status=JobItemStatus.PENDING, ) instance.desired_status = DESIRED_BY_ACTION[action] session.add(job_item) await session.commit() await session.refresh(job) await session.refresh(job_item) asyncio.create_task(_process_instance_action(job.id, job_item.id, action)) return job async def _process_instance_action(job_id: int, job_item_id: int, action: JobItemAction) -> None: async with SessionLocal() as session: job = await session.get(Job, job_id) job_item = await session.scalar( select(JobItem) .where(JobItem.id == job_item_id) .options(selectinload(JobItem.instance).selectinload(Instance.credential)) ) if not job or not job_item or not job_item.instance or not job_item.instance.credential: return job.status = JobStatus.RUNNING job.started_at = datetime.now(timezone.utc) job_item.status = JobItemStatus.RUNNING await session.commit() cred = job_item.instance.credential region = job_item.region or job_item.instance.region instance_id = job_item.instance.instance_id try: if action == JobItemAction.START: resp = await asyncio.to_thread(aws_ops.start_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.RUNNING elif action == JobItemAction.STOP: resp = await asyncio.to_thread(aws_ops.stop_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.STOPPED elif action == JobItemAction.REBOOT: resp = await asyncio.to_thread(aws_ops.reboot_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.RUNNING elif action == JobItemAction.TERMINATE: resp = await asyncio.to_thread(aws_ops.terminate_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.TERMINATED job_item.instance.terminated_at = datetime.now(timezone.utc) else: # pragma: no cover resp = {} job_item.extra = resp job_item.status = JobItemStatus.SUCCESS job.success_count = 1 job.total_count = 1 job.progress = 100 job.status = JobStatus.SUCCESS job.finished_at = datetime.now(timezone.utc) await session.commit() except Exception as exc: # pragma: no cover - AWS failure path job_item.status = JobItemStatus.FAILED job_item.error_message = str(exc) job.status = JobStatus.FAILED job.error_message = str(exc) job.fail_count = 1 job.progress = 100 job.finished_at = datetime.now(timezone.utc) await session.commit() async def _get_instance_or_404(instance_id: int, session: AsyncSession, auth_user: AuthUser) -> Instance: instance = await session.get(Instance, instance_id) if not instance: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Instance not found") if auth_user.role_name != "ADMIN" and instance.customer_id != auth_user.customer_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") if not instance.credential_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instance missing credential") return instance async def _action_endpoint( instance_id: int, action: JobItemAction, request: Request, session: AsyncSession, auth_user: AuthUser, ) -> JobOut: instance = await _get_instance_or_404(instance_id, session, auth_user) await _ensure_credential_access(instance.credential_id, auth_user, session) job = await _enqueue_instance_action(instance, action, auth_user, session) await create_audit_log( session, user_id=auth_user.user.id, customer_id=instance.customer_id, action={ JobItemAction.START: AuditAction.INSTANCE_START, JobItemAction.STOP: AuditAction.INSTANCE_STOP, JobItemAction.REBOOT: AuditAction.INSTANCE_REBOOT, JobItemAction.TERMINATE: AuditAction.INSTANCE_TERMINATE, }[action], resource_type=AuditResourceType.INSTANCE, resource_id=instance.id, description=f"{action.value} instance {instance.instance_id}", request=request, ) await session.commit() return JobOut.model_validate(job) @router.post("/{instance_id}/start", response_model=JobOut) async def start_instance( instance_id: int, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> JobOut: return await _action_endpoint(instance_id, JobItemAction.START, request, session, auth_user) @router.post("/{instance_id}/stop", response_model=JobOut) async def stop_instance( instance_id: int, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> JobOut: return await _action_endpoint(instance_id, JobItemAction.STOP, request, session, auth_user) @router.post("/{instance_id}/reboot", response_model=JobOut) async def reboot_instance( instance_id: int, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> JobOut: return await _action_endpoint(instance_id, JobItemAction.REBOOT, request, session, auth_user) @router.post("/{instance_id}/terminate", response_model=JobOut) async def terminate_instance( instance_id: int, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> JobOut: return await _action_endpoint(instance_id, JobItemAction.TERMINATE, request, session, auth_user) @router.post("/sync", response_model=JobOut) async def sync_instances( payload: InstanceSyncRequest, request: Request, session: AsyncSession = Depends(get_session), auth_user: AuthUser = Depends(get_current_user), ) -> JobOut: target_customer_id = payload.customer_id or auth_user.customer_id if not target_customer_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required") credentials_query = ( select(AWSCredential) .join(CustomerCredential, CustomerCredential.credential_id == AWSCredential.id) .where(CustomerCredential.customer_id == target_customer_id) .where(CustomerCredential.is_allowed == 1) .where(AWSCredential.is_active == 1) ) if payload.credential_id: credentials_query = credentials_query.where(AWSCredential.id == payload.credential_id) credentials = (await session.scalars(credentials_query)).all() if not credentials: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No credentials to sync") job = Job( job_uuid=uuid4().hex, job_type=JobType.SYNC_INSTANCES, status=JobStatus.RUNNING, progress=0, total_count=0, created_by_user_id=auth_user.user.id, created_for_customer=target_customer_id, payload=payload.model_dump(), started_at=datetime.now(timezone.utc), ) session.add(job) await session.commit() await session.refresh(job) synced_count = 0 now = datetime.now(timezone.utc) try: for cred in credentials: region = payload.region or cred.default_region try: resp = await asyncio.to_thread(aws_ops.describe_instances, cred, region) except Exception as exc: # pragma: no cover continue reservations = resp.get("Reservations") or [] for res in reservations: for inst in res.get("Instances", []): instance_id = inst.get("InstanceId") state_name = (inst.get("State") or {}).get("Name") name_tag = _extract_name_tag(inst.get("Tags")) security_groups = [sg.get("GroupId") for sg in inst.get("SecurityGroups", []) if sg.get("GroupId")] record = dict( customer_id=target_customer_id, credential_id=cred.id, account_id=cred.account_id, region=region, az=(inst.get("Placement") or {}).get("AvailabilityZone"), instance_id=instance_id, name_tag=name_tag, instance_type=inst.get("InstanceType"), ami_id=inst.get("ImageId"), key_name=inst.get("KeyName"), public_ip=inst.get("PublicIpAddress"), private_ip=inst.get("PrivateIpAddress"), status=_map_state(state_name), desired_status=None, security_groups=security_groups, subnet_id=inst.get("SubnetId"), vpc_id=inst.get("VpcId"), launched_at=inst.get("LaunchTime"), terminated_at=now if state_name == "terminated" else None, last_sync=now, last_cloud_state={"state": inst.get("State"), "tags": inst.get("Tags")}, ) stmt = insert(Instance).values(**record) update_cols = {k: stmt.inserted[k] for k in record.keys() if k not in ("id",)} await session.execute(stmt.on_duplicate_key_update(**update_cols)) db_inst = await session.scalar( select(Instance).where( Instance.account_id == cred.account_id, Instance.region == region, Instance.instance_id == instance_id, ) ) session.add( JobItem( job_id=job.id, resource_type=JobItemResourceType.INSTANCE, resource_id=db_inst.id if db_inst else None, account_id=cred.account_id, region=region, instance_id=instance_id, action=JobItemAction.SYNC, status=JobItemStatus.SUCCESS, ) ) synced_count += 1 job.total_count = synced_count job.success_count = synced_count job.status = JobStatus.SUCCESS job.progress = 100 job.finished_at = datetime.now(timezone.utc) await create_audit_log( session, user_id=auth_user.user.id, customer_id=target_customer_id, action=AuditAction.INSTANCE_SYNC, resource_type=AuditResourceType.INSTANCE, resource_id=None, description=f"Sync instances with {synced_count} records", payload=payload.model_dump(), request=request, ) await session.commit() except Exception as exc: # pragma: no cover - sync failure path job.status = JobStatus.FAILED job.error_message = str(exc) job.progress = 100 job.finished_at = datetime.now(timezone.utc) await session.commit() return JobOut.model_validate(job)