from __future__ import annotations import asyncio import logging import re import secrets import string import time from datetime import datetime, timezone from io import BytesIO from typing import Any, Dict, List, Optional from uuid import uuid4 from fastapi import HTTPException, status from sqlalchemy import and_, func, or_, select from sqlalchemy.dialects.mysql import insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from botocore.exceptions import ClientError from openpyxl import Workbook from backend.db.session import AsyncSessionLocal from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType from backend.modules.aws_accounts.models import AWSCredential, CredentialType, CustomerCredential from backend.modules.instances import aws_ops from backend.modules.instances.models import Instance, InstanceDesiredStatus, InstanceStatus from backend.modules.instances.constants import DEFAULT_LOGIN_PASSWORD, DEFAULT_LOGIN_USERNAME from backend.modules.instances.bootstrap_templates import build_user_data from backend.modules.instances.events import ( broadcast_instance_removed, broadcast_instance_update, build_removed_payload, ) from backend.modules.jobs.models import ( Job, JobItem, JobItemAction, JobItemResourceType, JobItemStatus, JobStatus, JobType, ) from backend.modules.customers.models import Customer from backend.modules.users.models import Role, RoleName, User from backend.core.config import settings from backend.modules.instances.schemas import ( BatchInstancesActionIn, BatchInstancesActionOut, BatchInstancesByIpIn, ) STATE_MAP: Dict[str, InstanceStatus] = { "pending": InstanceStatus.PENDING, "running": InstanceStatus.RUNNING, "stopping": InstanceStatus.STOPPING, "stopped": InstanceStatus.STOPPED, "shutting-down": InstanceStatus.SHUTTING_DOWN, "terminated": InstanceStatus.TERMINATED, } JOB_TYPE_BY_ACTION = { JobItemAction.START: JobType.START_INSTANCES, JobItemAction.STOP: JobType.STOP_INSTANCES, 
JobItemAction.REBOOT: JobType.REBOOT_INSTANCES, JobItemAction.TERMINATE: JobType.TERMINATE_INSTANCES, } DESIRED_BY_ACTION = { JobItemAction.START: InstanceDesiredStatus.RUNNING, JobItemAction.STOP: InstanceDesiredStatus.STOPPED, JobItemAction.REBOOT: InstanceDesiredStatus.RUNNING, JobItemAction.TERMINATE: InstanceDesiredStatus.TERMINATED, } # cache image name lookups to avoid repeated describe_images calls within a process _image_name_cache: dict[tuple[int, str, str], str | None] = {} _instance_type_specs_cache: dict[tuple[str, int, str], dict[str, tuple[int | None, float | None, str | None]]] = {} _instance_type_cache_ts: dict[tuple[str, int, str], float] = {} INSTANCE_TYPE_CACHE_TTL = 60 * 60 # seconds MAX_BATCH_COUNT = 200 MAX_EXPORT_ROWS = 5000 MAX_EXPORT_IDS = 2000 MIN_SCOPE_SYNC_INTERVAL = max(1, int(settings.scope_sync_min_interval_seconds)) GLOBAL_SYNC_INTERVAL_MINUTES = max(1, int(settings.global_sync_interval_minutes)) GLOBAL_SYNC_MAX_CONCURRENCY = max(1, int(settings.global_sync_max_concurrency)) logger = logging.getLogger(__name__) _scope_sync_last: dict[tuple[str, int, str], float] = {} _scope_sync_locks: dict[tuple[str, int, str], asyncio.Lock] = {} _global_sync_task: Optional[asyncio.Task] = None _scope_sync_semaphore = asyncio.Semaphore(GLOBAL_SYNC_MAX_CONCURRENCY) async def _delete_local_instance(session: AsyncSession, instance: Instance) -> dict: payload = build_removed_payload(instance) await session.delete(instance) return payload def _scope_key(provider: str, credential_id: Optional[int], region: Optional[str]) -> tuple[str, int, str]: return (provider, credential_id or 0, region or "*") def _get_scope_lock(provider: str, credential_id: Optional[int], region: Optional[str]) -> asyncio.Lock: key = _scope_key(provider, credential_id, region) lock = _scope_sync_locks.get(key) if lock is None: lock = asyncio.Lock() _scope_sync_locks[key] = lock return lock async def _pick_sync_actor(session: AsyncSession, actor_id: Optional[int]) -> Optional[User]: 
if actor_id: user = await session.get(User, actor_id) if user: return user admin = await session.scalar( select(User).join(Role).where(Role.name == RoleName.ADMIN.value).limit(1) ) if admin: return admin return await session.scalar(select(User).limit(1)) async def enqueue_scope_sync(provider: str, credential_id: int, region: str, *, reason: str, actor_id: Optional[int] = None, customer_id: Optional[int] = None) -> bool: """ Enqueue a scope sync job with simple rate limiting to avoid hammering cloud APIs. Returns True if enqueued, False if skipped due to debounce. """ key = _scope_key(provider, credential_id, region) now = time.monotonic() last = _scope_sync_last.get(key, 0) if now - last < MIN_SCOPE_SYNC_INTERVAL: return False _scope_sync_last[key] = now async def _runner(): async with _scope_sync_semaphore: try: async with AsyncSessionLocal() as session: actor = await _pick_sync_actor(session, actor_id) if not actor: logger.warning("scope sync skipped due to missing actor %s/%s/%s", provider, credential_id, region) return await sync_instances( session, credential_id, region, actor, customer_id_override=customer_id, provider=provider, ) except Exception as exc: # pragma: no cover - best-effort sync logger.warning("scope sync failed for %s/%s/%s: %s", provider, credential_id, region, exc) asyncio.create_task(_runner()) logger.info("scope sync enqueued for %s/%s/%s reason=%s", provider, credential_id, region, reason) return True async def _discover_sync_scopes(session: AsyncSession) -> list[tuple[int, str, Optional[int]]]: scopes: set[tuple[int, str, Optional[int]]] = set() rows = await session.execute( select(Instance.credential_id, Instance.region, Instance.customer_id) .where(Instance.credential_id.isnot(None)) .distinct() ) for cred_id, region, customer_id in rows.all(): if cred_id and region: scopes.add((cred_id, region, customer_id)) credentials = (await session.scalars(select(AWSCredential).where(AWSCredential.is_active == 1))).all() for cred in credentials: 
mappings = ( await session.scalars( select(CustomerCredential.customer_id) .where(CustomerCredential.credential_id == cred.id) .where(CustomerCredential.is_allowed == 1) ) ).all() uniq = sorted({cid for cid in mappings if cid}) if len(uniq) == 1: scopes.add((cred.id, cred.default_region, uniq[0])) return list(scopes) async def _global_sync_loop() -> None: await asyncio.sleep(5) while True: try: async with AsyncSessionLocal() as session: admin = await _pick_sync_actor(session, None) admin_id = admin.id if admin else None scopes = await _discover_sync_scopes(session) for cred_id, region, customer_id in scopes: await enqueue_scope_sync( "aws", cred_id, region, reason="global_cron", actor_id=admin_id, customer_id=customer_id, ) except Exception as exc: # pragma: no cover - background loop resilience logger.warning("global sync loop error: %s", exc) await asyncio.sleep(GLOBAL_SYNC_INTERVAL_MINUTES * 60) def start_global_sync_scheduler() -> None: global _global_sync_task if GLOBAL_SYNC_INTERVAL_MINUTES <= 0: return if _global_sync_task and not _global_sync_task.done(): return _global_sync_task = asyncio.create_task(_global_sync_loop()) def _pretty_os_label(raw: str | None) -> str | None: if not raw: return None text = str(raw).strip() lower = text.lower() def find_version(patterns: list[str]) -> str | None: for pattern in patterns: m = re.search(pattern, lower) if m: return m.group(1) return None if "ubuntu" in lower: version = find_version([r"ubuntu[\w\- ]*(\d{2}\.\d{2})"]) codename_map = {"jammy": "22.04", "focal": "20.04", "bionic": "18.04"} for code, ver in codename_map.items(): if code in lower: version = version or ver label = "Ubuntu" if version: label = f"Ubuntu {version}" if version in {"22.04", "20.04", "18.04"}: label = f"{label} LTS" return label if "debian" in lower: version = find_version([r"debian[\w\- ]*(\d{1,2})", r"debian[\w\- ]*(\d{1,2}\.\d)"]) codename_map = {"bookworm": "12", "bullseye": "11"} for code, ver in codename_map.items(): if code in lower: 
version = version or ver return f"Debian {version}" if version else "Debian" if "centos" in lower: version = find_version([r"centos[\w\- ]*(\d{1,2})"]) return f"CentOS {version}" if version else "CentOS" if "amazon linux" in lower or lower.startswith("amzn") or "amzn" in lower: version = find_version([r"amazon linux ?(\d+)", r"amzn(\d)"]) return f"Amazon Linux {version}" if version else "Amazon Linux" if "windows" in lower: version = find_version([r"windows[ _\-]*server[ _\-]*(\d{4})", r"windows[ _\-]*(\d{4})"]) if not version: for candidate in ["2022", "2019", "2016", "2012"]: if candidate in lower: version = candidate break base = "Windows Server" if "server" in lower or version else "Windows" return f"{base} {version}".strip() if version else base return text def _chunked(seq: List[str], size: int = 100) -> List[List[str]]: return [seq[i : i + size] for i in range(0, len(seq), size)] async def get_instance_type_specs( provider: str, credential: AWSCredential | None, region: str, instance_types: List[str], ) -> dict[str, tuple[int | None, float | None, str | None]]: if provider != "aws" or not credential or not instance_types: return {} region_use = region or credential.default_region if not region_use: return {} cache_key = (provider, credential.id, region_use) now = time.time() cache_ts = _instance_type_cache_ts.get(cache_key, 0) cache_items = _instance_type_specs_cache.get(cache_key, {}) if now - cache_ts > INSTANCE_TYPE_CACHE_TTL: cache_items = {} needed = [it for it in set(instance_types) if it and it not in cache_items] if needed: fetched = dict(cache_items) for chunk in _chunked(sorted(needed), 100): try: resp = await asyncio.to_thread( aws_ops.describe_instance_types, credential, region_use, [{"Name": "instance-type", "Values": chunk}], ) except Exception as exc: # pragma: no cover - best effort metadata logger.debug("describe_instance_types failed for %s/%s: %s", credential.id, region_use, exc) continue for item in resp: itype = item.get("InstanceType") 
if not itype: continue vcpu = (item.get("VCpuInfo") or {}).get("DefaultVCpus") mem_mib = (item.get("MemoryInfo") or {}).get("SizeInMiB") mem_gib = round(mem_mib / 1024, 2) if mem_mib is not None else None net_perf = (item.get("NetworkInfo") or {}).get("NetworkPerformance") fetched[itype] = (vcpu, mem_gib, net_perf) cache_items = fetched _instance_type_specs_cache[cache_key] = cache_items _instance_type_cache_ts[cache_key] = now return {it: cache_items[it] for it in instance_types if it in cache_items} async def _lookup_image_name(cred: AWSCredential, region: str, ami_id: str) -> str | None: cache_key = (cred.id, region, ami_id) if cache_key in _image_name_cache: return _image_name_cache[cache_key] name: str | None = None try: resp = await asyncio.to_thread(aws_ops.describe_images, cred, region, [ami_id]) images = resp.get("Images", []) if images: img = images[0] name = img.get("Name") or img.get("Description") except Exception: name = None _image_name_cache[cache_key] = name return name async def _resolve_os_name( cred: AWSCredential, region: str, inst: dict, existing_os: str | None = None ) -> tuple[str | None, str | None]: """ Try to get a user-friendly OS name. Prefer existing value; otherwise use AWS metadata and AMI name. Returns a tuple of (pretty_name, image_name). 
""" os_hint = existing_os or inst.get("PlatformDetails") or inst.get("Platform") ami_id = inst.get("ImageId") os_name = os_hint image_name: str | None = None lower_name = os_name.lower() if isinstance(os_name, str) else "" needs_lookup = not os_name or lower_name.startswith("linux/unix") or not re.search(r"\d", lower_name) if ami_id and needs_lookup: image_name = await _lookup_image_name(cred, region, ami_id) if image_name: os_name = image_name if not os_name: os_name = ami_id pretty = _pretty_os_label(os_name) if not pretty and existing_os: pretty = _pretty_os_label(existing_os) fallback = os_name or existing_os return pretty or fallback, image_name def _derive_os_pretty_name(inst: Instance) -> str | None: candidates: list[str | None] = [] candidates.append(inst.os_name) if isinstance(inst.last_cloud_state, dict): candidates.append(inst.last_cloud_state.get("os_family")) candidates.append(inst.last_cloud_state.get("ami_name")) for cand in candidates: pretty = _pretty_os_label(cand) if pretty: return pretty for cand in candidates: if cand: return str(cand) return None async def enrich_instances(instances: List[Instance]) -> None: if not instances: return # OS pretty names first so they are present even if specs lookup fails for inst in instances: inst.os_pretty_name = _derive_os_pretty_name(inst) scope_map: dict[tuple[int, str], list[Instance]] = {} cred_map: dict[int, AWSCredential] = {} for inst in instances: if inst.credential and inst.credential_id: cred_map[inst.credential_id] = inst.credential if inst.credential_id and inst.region and inst.instance_type: scope_map.setdefault((inst.credential_id, inst.region), []).append(inst) for (cred_id, region), scoped_instances in scope_map.items(): cred = cred_map.get(cred_id) if not cred: continue types = [i.instance_type for i in scoped_instances if i.instance_type] specs_map = await get_instance_type_specs("aws", cred, region, types) if not specs_map: continue for inst in scoped_instances: spec = 
specs_map.get(inst.instance_type) if not spec: continue vcpu, mem_gib, net_perf = spec inst.instance_vcpus = vcpu inst.instance_memory_gib = mem_gib inst.instance_network_perf = net_perf def _map_state(state: Optional[str]) -> InstanceStatus: if not state: return InstanceStatus.UNKNOWN return STATE_MAP.get(state.lower(), InstanceStatus.UNKNOWN) def _extract_name_tag(tags: Optional[List[Dict[str, str]]]) -> Optional[str]: if not tags: return None for tag in tags: if tag.get("Key") == "Name": return tag.get("Value") return None async def ensure_credential_access(session: AsyncSession, credential_id: int, actor: User) -> AWSCredential: cred = await session.get(AWSCredential, credential_id) if not cred or not cred.is_active: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found") if actor.role.name == RoleName.ADMIN.value: return cred mapping = await session.scalar( select(CustomerCredential).where( and_( CustomerCredential.customer_id == actor.customer_id, CustomerCredential.credential_id == credential_id, CustomerCredential.is_allowed == 1, ) ) ) if not mapping: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Credential not allowed") return cred def build_instances_query(filters: dict, actor: User): filters = dict(filters or {}) query = select(Instance).options( selectinload(Instance.credential), selectinload(Instance.customer), ) if actor.role.name != RoleName.ADMIN.value: query = query.where(Instance.customer_id == actor.customer_id) elif filters.get("customer_id"): query = query.where(Instance.customer_id == filters["customer_id"]) if filters.get("credential_id"): query = query.where(Instance.credential_id == filters["credential_id"]) if filters.get("account_id"): query = query.where(Instance.account_id == filters["account_id"]) if filters.get("region"): query = query.where(Instance.region == filters["region"]) if filters.get("status"): query = query.where(Instance.status == filters["status"]) if 
filters.get("instance_ids"): ids = list(dict.fromkeys(filters["instance_ids"])) if ids: query = query.where(Instance.id.in_(ids)) if filters.get("keyword"): pattern = f"%{filters['keyword']}%" query = query.where( or_( Instance.name_tag.ilike(pattern), Instance.instance_id.ilike(pattern), Instance.public_ip.ilike(pattern), Instance.private_ip.ilike(pattern), ) ) return query async def list_instances( session: AsyncSession, filters: dict, actor: User, ) -> tuple[List[Instance], int]: query = build_instances_query(filters, actor) total = await session.scalar(select(func.count()).select_from(query.subquery())) rows = ( await session.scalars( query.order_by(Instance.updated_at.desc()).offset(filters.get("offset", 0)).limit(filters.get("limit", 20)) ) ).all() await enrich_instances(rows) return rows, total or 0 def _generate_random_password(length: int = 16) -> str: alphabet = string.ascii_letters + string.digits return "".join(secrets.choice(alphabet) for _ in range(length)) def _resolve_default_ami(cred: AWSCredential, region: str, os_family: str) -> str: """Fetch latest common AMI for a given OS family/version. 
Best-effort.""" session, cfg = aws_ops.build_session(cred, region) client = session.client("ec2", region_name=region or cred.default_region, config=cfg) family = (os_family or "ubuntu-22.04").lower() owners: List[str] = [] name_values: List[str] = [] if family.startswith("ubuntu-22.04"): owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"]) elif family.startswith("ubuntu-20.04"): owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-focal-20.04-amd64-server-*"]) elif family.startswith("debian-12"): owners, name_values = (["136693071363"], ["debian-12-amd64-*"]) elif family.startswith("debian-11"): owners, name_values = (["136693071363"], ["debian-11-amd64-*"]) elif family.startswith("centos-9"): owners, name_values = (["125523088429"], ["CentOS Stream 9*"]) elif family.startswith("amazonlinux-2"): owners, name_values = (["137112412989"], ["amzn2-ami-hvm-*-x86_64-gp2"]) elif family.startswith("windows-2019"): owners, name_values = (["801119661308"], ["Windows_Server-2019-English-Full-Base-*"]) elif family.startswith("windows-2022"): owners, name_values = (["801119661308"], ["Windows_Server-2022-English-Full-Base-*"]) else: owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"]) try: resp = client.describe_images( Owners=owners, Filters=[{"Name": "name", "Values": name_values}, {"Name": "architecture", "Values": ["x86_64"]}], ) images = sorted(resp.get("Images", []), key=lambda x: x.get("CreationDate", ""), reverse=True) if images: return images[0].get("ImageId") except ClientError as exc: code = (exc.response.get("Error") or {}).get("Code") msg = (exc.response.get("Error") or {}).get("Message") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"获取镜像失败: {code} {msg}") except Exception as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"获取镜像失败: {exc}") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 
detail="无法为所选系统找到可用 AMI,请手动填写") def _get_root_device_name(cred: AWSCredential, region: str, ami_id: str) -> str: session, cfg = aws_ops.build_session(cred, region) client = session.client("ec2", region_name=region or cred.default_region, config=cfg) try: resp = client.describe_images(ImageIds=[ami_id]) images = resp.get("Images", []) if images and images[0].get("RootDeviceName"): return images[0]["RootDeviceName"] except Exception: return "/dev/xvda" return "/dev/xvda" async def create_instance( session: AsyncSession, payload: dict, actor: User, ) -> dict: # enforce instance type for customers allowed_customer_types = {"t3.micro", "t3.small", "t3.medium"} is_admin = actor.role.name == RoleName.ADMIN.value if not payload.get("instance_type"): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="instance_type required") if not is_admin and payload.get("instance_type") not in allowed_customer_types: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="客户仅可使用 t3.micro / t3.small / t3.medium 实例类型", ) count = payload.get("count") or 1 if count < 1: count = 1 customer_id = payload.get("customer_id") or actor.customer_id if not customer_id: # derive from credential mapping if there is exactly one allowed customer mappings = ( await session.scalars( select(CustomerCredential.customer_id).where( CustomerCredential.credential_id == payload["credential_id"], CustomerCredential.is_allowed == 1, ) ) ).all() unique_customer_ids = sorted({cid for cid in mappings if cid}) if len(unique_customer_ids) == 1: customer_id = unique_customer_ids[0] else: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required") cred = await ensure_credential_access(session, payload["credential_id"], actor) if cred.account_id != payload["account_id"]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="account_id mismatch") region = payload["region"] or cred.default_region # quota check: customer-level + region active count 
customer = await session.get(Customer, customer_id) if not customer: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found") quota_limit = customer.quota_instances if customer.quota_instances is not None else 999999 active_count = await session.scalar( select(func.count(Instance.id)).where( Instance.customer_id == customer_id, Instance.status != InstanceStatus.TERMINATED, Instance.desired_status != InstanceDesiredStatus.TERMINATED, Instance.region == region, ) ) available = max(0, quota_limit - (active_count or 0)) if count > available: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"超出配额,当前区域可用数量 {available},请调整数量或联系管理员扩容", ) session_boto, cfg = aws_ops.build_session(cred, region) ec2_client = session_boto.client("ec2", region_name=region or cred.default_region, config=cfg) login_username = payload.get("login_username") or DEFAULT_LOGIN_USERNAME login_password = payload.get("login_password") or DEFAULT_LOGIN_PASSWORD # allow client to request random password by passing "RANDOM" sentinel if isinstance(login_password, str) and login_password.upper() == "RANDOM": login_password = _generate_random_password() user_data_mode = (payload.get("user_data_mode") or "auto_root").lower() os_family = payload.get("os_family") or "ubuntu" user_data_content: Optional[str] = None if user_data_mode == "auto_root": user_data_content = build_user_data(os_family, login_username, login_password) elif user_data_mode == "custom": user_data_content = payload.get("custom_user_data") else: user_data_content = None sg_ids = payload.get("security_group_ids") or [] if not sg_ids: if not settings.auto_open_sg_enabled: # 保守策略:要求前端显式选择 SG,避免无意暴露 raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="security_group_ids required when auto open SG is disabled", ) if not payload.get("subnet_id"): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="subnet_id required to derive VPC") subnet_resp = await 
asyncio.to_thread(ec2_client.describe_subnets, SubnetIds=[payload["subnet_id"]]) subnets = subnet_resp.get("Subnets", []) if not subnets: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid subnet_id") vpc_id = subnets[0].get("VpcId") if not vpc_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Subnet missing VPC") # Risk note: open-all SG is only for testing environments. sg_id = await asyncio.to_thread(aws_ops.ensure_open_all_sg_for_vpc, ec2_client, vpc_id) sg_ids = [sg_id] ami_id = payload.get("ami_id") if not ami_id: ami_id = _resolve_default_ami(cred, region, os_family) volume_type = payload.get("volume_type") or "gp3" volume_size = payload.get("volume_size") or 20 if volume_size <= 0: volume_size = 20 if not is_admin: volume_type = "gp3" cpu_options: Optional[Dict[str, Any]] = None if (payload.get("instance_type") or "").lower().startswith("t"): cpu_options = {"CpuCredits": "standard"} root_device = _get_root_device_name(cred, region, ami_id) block_device_mappings = [ { "DeviceName": root_device, "Ebs": { "VolumeSize": volume_size, "VolumeType": volume_type, "DeleteOnTermination": True, }, } ] try: resp = await asyncio.to_thread( aws_ops.run_instances, cred, region, ami_id, payload["instance_type"], payload.get("key_name"), sg_ids, payload.get("subnet_id"), block_device_mappings, cpu_options, count, count, payload.get("name_tag"), user_data_content, ) except ClientError as exc: # pragma: no cover err = exc.response.get("Error", {}) code = err.get("Code") msg = err.get("Message") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"AWS 错误: {code} {msg}") from exc except Exception as exc: # pragma: no cover raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc instances = resp.get("Instances") or [] if not instances: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AWS did not return instances") instance_ids = 
[i.get("InstanceId") for i in instances if i.get("InstanceId")] status_checks_map: Dict[str, dict] = {} # refresh details to fetch public IP if not immediately available if instance_ids: try: describe_resp = await asyncio.to_thread( aws_ops.describe_instances, cred, region, instance_ids=instance_ids ) reservations = describe_resp.get("Reservations", []) tmp = [] for res in reservations: tmp.extend(res.get("Instances", [])) if tmp: instances = tmp status_resp = await asyncio.to_thread(aws_ops.describe_instance_status, cred, region, instance_ids) for st in status_resp.get("InstanceStatuses", []): iid = st.get("InstanceId") status_checks_map[iid] = { "instance": st.get("InstanceStatus", {}).get("Status"), "system": st.get("SystemStatus", {}).get("Status"), } except Exception: pass created_db_instances: List[Instance] = [] now = datetime.now(timezone.utc) for data in instances: name_tag = payload.get("name_tag") or _extract_name_tag(data.get("Tags")) security_groups = [sg.get("GroupId") for sg in data.get("SecurityGroups", []) if sg.get("GroupId")] db_inst = Instance( customer_id=customer_id, credential_id=cred.id, account_id=cred.account_id, region=region, az=(data.get("Placement") or {}).get("AvailabilityZone"), instance_id=data.get("InstanceId"), name_tag=name_tag, instance_type=payload["instance_type"], ami_id=ami_id, os_name=os_family, key_name=payload.get("key_name"), public_ip=data.get("PublicIpAddress"), private_ip=data.get("PrivateIpAddress"), status=_map_state((data.get("State") or {}).get("Name")), desired_status=InstanceDesiredStatus.RUNNING, security_groups=security_groups, subnet_id=data.get("SubnetId"), vpc_id=data.get("VpcId"), launched_at=data.get("LaunchTime") or now, last_sync=now, last_cloud_state={ "state": data.get("State"), "tags": data.get("Tags"), "os_family": os_family, "status_checks": status_checks_map.get(data.get("InstanceId")), }, ) session.add(db_inst) created_db_instances.append(db_inst) await session.commit() for db_inst in 
created_db_instances: await session.refresh(db_inst) await session.refresh(db_inst, attribute_names=["credential", "customer"]) db_inst = created_db_instances[0] session.add( AuditLog( user_id=actor.id, customer_id=customer_id, action=AuditAction.INSTANCE_CREATE, resource_type=AuditResourceType.INSTANCE, resource_id=db_inst.id, description=f"Create instance {db_inst.instance_id} x{len(created_db_instances)}", payload=payload, ) ) await session.commit() await enrich_instances(created_db_instances) for inst in created_db_instances: await broadcast_instance_update(inst) # 新实例创建后触发 scope 同步,确保云端状态与本地收敛 try: await enqueue_scope_sync( "aws", cred.id, region, reason="create_instance", actor_id=actor.id, customer_id=customer_id, ) except Exception as exc: # pragma: no cover logger.warning("enqueue scope sync after create failed: %s", exc) return {"instance": db_inst, "login_username": login_username, "login_password": login_password} async def _process_action(job_id: int, job_item_id: int, action: JobItemAction) -> None: async with AsyncSessionLocal() as session: job = await session.get(Job, job_id) job_item = await session.scalar( select(JobItem) .where(JobItem.id == job_item_id) .options( selectinload(JobItem.instance).selectinload(Instance.credential), selectinload(JobItem.instance).selectinload(Instance.customer), ) ) if not job or not job_item or not job_item.instance or not job_item.instance.credential: return job.status = JobStatus.RUNNING job.started_at = datetime.now(timezone.utc) job_item.status = JobItemStatus.RUNNING await session.commit() cred = job_item.instance.credential region = job_item.region or job_item.instance.region instance_id = job_item.instance.instance_id removed_payload = None updated_instance: Optional[Instance] = None try: if action == JobItemAction.START: resp = await asyncio.to_thread(aws_ops.start_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.RUNNING elif action == JobItemAction.STOP: resp = await 
asyncio.to_thread(aws_ops.stop_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.STOPPED elif action == JobItemAction.REBOOT: resp = await asyncio.to_thread(aws_ops.reboot_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.RUNNING elif action == JobItemAction.TERMINATE: resp = await asyncio.to_thread(aws_ops.terminate_instances, cred, region, [instance_id]) job_item.instance.status = InstanceStatus.TERMINATED job_item.instance.terminated_at = datetime.now(timezone.utc) removed_payload = await _delete_local_instance(session, job_item.instance) job_item.resource_id = None else: # pragma: no cover resp = {} if action != JobItemAction.TERMINATE: updated_instance = job_item.instance job_item.extra = resp job_item.status = JobItemStatus.SUCCESS job.success_count = 1 job.total_count = 1 job.progress = 100 job.status = JobStatus.SUCCESS job.finished_at = datetime.now(timezone.utc) await session.commit() if updated_instance: await session.refresh(updated_instance) await broadcast_instance_update(updated_instance) if removed_payload: await broadcast_instance_removed(removed_payload, removed_payload.get("customer_id")) try: scope_customer = ( job_item.instance.customer_id if job_item.instance else removed_payload.get("customer_id") if removed_payload else None ) await enqueue_scope_sync( "aws", cred.id, region, reason=f"job_{action.value}", actor_id=job.created_by_user_id, customer_id=scope_customer, ) except Exception as exc: # pragma: no cover - do not fail job if sync enqueue fails logger.warning("enqueue scope sync after job failed: %s", exc) except Exception as exc: # pragma: no cover job_item.status = JobItemStatus.FAILED job_item.error_message = str(exc) job.status = JobStatus.FAILED job.error_message = str(exc) job.fail_count = 1 job.progress = 100 job.finished_at = datetime.now(timezone.utc) await session.commit() async def enqueue_action(session: AsyncSession, instance: Instance, action: JobItemAction, 
actor: User) -> Job:
    """Create a one-item Job for *action* on *instance* and start async processing.

    Persists the Job plus its single JobItem, records the caller's intent on the
    instance (``desired_status``), broadcasts the updated instance, and schedules
    ``_process_action`` as a fire-and-forget asyncio task.  Returns the committed Job.
    """
    job = Job(
        job_uuid=uuid4().hex,
        job_type=JOB_TYPE_BY_ACTION[action],
        status=JobStatus.PENDING,
        progress=0,
        total_count=1,
        created_by_user_id=actor.id,
        created_for_customer=instance.customer_id,
    )
    session.add(job)
    # Flush so job.id exists for the JobItem foreign key before the commit below.
    await session.flush()
    job_item = JobItem(
        job_id=job.id,
        resource_type=JobItemResourceType.INSTANCE,
        resource_id=instance.id,
        account_id=instance.account_id,
        region=instance.region,
        instance_id=instance.instance_id,
        action=action,
        status=JobItemStatus.PENDING,
    )
    instance.desired_status = DESIRED_BY_ACTION[action]
    session.add(job_item)
    await session.commit()
    await session.refresh(job)
    await session.refresh(instance)
    await broadcast_instance_update(instance)
    # Fire-and-forget: no task handle is kept; _process_action finalizes the job itself.
    asyncio.create_task(_process_action(job.id, job_item.id, action))
    return job


async def upsert_instance_from_cloud(
    session: AsyncSession,
    customer_id: int,
    cred: AWSCredential,
    region: str,
    inst: dict,
    now: datetime,
    status_checks: Optional[dict] = None,
) -> Optional[Instance]:
    """Insert or refresh the local row for one EC2 instance description.

    ``inst`` is a single entry from ``describe_instances`` (Reservations[].Instances[]).
    Uses MySQL ``INSERT ... ON DUPLICATE KEY UPDATE`` so concurrent syncs do not race;
    assumes a unique key covering (account_id, region, instance_id) — TODO confirm on
    the Instance model.  Returns the row re-selected with credential/customer eagerly
    loaded, or None if the re-select finds nothing.
    """
    instance_id = inst.get("InstanceId")
    state_name = (inst.get("State") or {}).get("Name")
    name_tag = _extract_name_tag(inst.get("Tags"))
    security_groups = [sg.get("GroupId") for sg in inst.get("SecurityGroups", []) if sg.get("GroupId")]
    existing_inst = await session.scalar(
        select(Instance).where(
            Instance.account_id == cred.account_id,
            Instance.region == region,
            Instance.instance_id == instance_id,
        )
    )
    # Pass the previously resolved OS name as a fallback so it is not lost when
    # the AMI can no longer be described.
    os_family, ami_name = await _resolve_os_name(
        cred, region, inst, existing_inst.os_name if existing_inst else None
    )
    record = dict(
        customer_id=customer_id,
        credential_id=cred.id,
        account_id=cred.account_id,
        region=region,
        az=(inst.get("Placement") or {}).get("AvailabilityZone"),
        instance_id=instance_id,
        name_tag=name_tag,
        instance_type=inst.get("InstanceType"),
        ami_id=inst.get("ImageId"),
        os_name=os_family,
        key_name=inst.get("KeyName"),
        public_ip=inst.get("PublicIpAddress"),
        private_ip=inst.get("PrivateIpAddress"),
        status=_map_state(state_name),
        desired_status=None,  # a fresh sync clears any pending user intent
        security_groups=security_groups,
        subnet_id=inst.get("SubnetId"),
        vpc_id=inst.get("VpcId"),
        launched_at=inst.get("LaunchTime"),
        terminated_at=now if state_name == "terminated" else None,
        last_sync=now,
        # Raw cloud snapshot retained for debugging / later enrichment.
        last_cloud_state={
            "state": inst.get("State"),
            "tags": inst.get("Tags"),
            "os_family": os_family,
            "ami_name": ami_name,
            "status_checks": status_checks,
        },
    )
    stmt = insert(Instance).values(**record)
    # On key collision, overwrite every column we just computed.
    update_cols = {k: stmt.inserted[k] for k in record.keys() if k != "id"}
    await session.execute(stmt.on_duplicate_key_update(**update_cols))
    return await session.scalar(
        select(Instance)
        .options(selectinload(Instance.credential), selectinload(Instance.customer))
        .where(
            Instance.account_id == cred.account_id,
            Instance.region == region,
            Instance.instance_id == instance_id,
        )
    )


async def sync_instances(
    session: AsyncSession,
    credential_id: Optional[int],
    region: Optional[str],
    actor: User,
    customer_id_override: Optional[int] = None,
    provider: str = "aws",
) -> Job:
    """Reconcile local Instance rows with the cloud for the selected credentials.

    Admins may sync any active credential (optionally narrowed by *credential_id*
    and *region*); other roles are limited to credentials mapped to their customer.
    Runs under a per-scope lock so overlapping syncs of the same scope serialize.
    Returns the finished sync Job; raises 404 when no credential matches.
    """
    lock = _get_scope_lock(provider, credential_id, region)
    async with lock:
        base_customer_id = customer_id_override if actor.role.name == RoleName.ADMIN.value else actor.customer_id
        credentials_query = select(AWSCredential).where(AWSCredential.is_active == 1)
        if actor.role.name != RoleName.ADMIN.value:
            # Non-admins: only credentials explicitly allowed for their customer.
            credentials_query = (
                credentials_query.join(CustomerCredential, CustomerCredential.credential_id == AWSCredential.id)
                .where(CustomerCredential.customer_id == actor.customer_id)
                .where(CustomerCredential.is_allowed == 1)
            )
        if credential_id:
            credentials_query = credentials_query.where(AWSCredential.id == credential_id)
        credentials = (await session.scalars(credentials_query)).all()
        if not credentials:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No credentials to sync")
        job = Job(
            job_uuid=uuid4().hex,
            job_type=JobType.SYNC_INSTANCES,
            status=JobStatus.RUNNING,
            progress=0,
            total_count=0,
            created_by_user_id=actor.id,
            created_for_customer=base_customer_id,
            payload={"credential_id": credential_id, "region": region},
started_at=datetime.now(timezone.utc),
        )
        session.add(job)
        await session.commit()
        await session.refresh(job)
        synced = 0
        changed_instances: dict[int, Instance] = {}
        removed_payloads: list[dict] = []
        now = datetime.now(timezone.utc)
        for cred in credentials:
            current_customer_id = base_customer_id
            if actor.role.name == RoleName.ADMIN.value and not current_customer_id:
                # Admin gave no customer: infer it, but only when the credential
                # maps to exactly one allowed customer; otherwise the call is ambiguous.
                mappings = (
                    await session.scalars(
                        select(CustomerCredential.customer_id)
                        .where(CustomerCredential.credential_id == cred.id)
                        .where(CustomerCredential.is_allowed == 1)
                    )
                ).all()
                uniq = sorted({cid for cid in mappings if cid})
                if len(uniq) == 1:
                    current_customer_id = uniq[0]
                else:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"customer_id required for credential {cred.id}",
                    )
            if not current_customer_id:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="customer_id required",
                )
            target_region = region or cred.default_region
            try:
                # The AWS client is blocking; run it off the event loop.
                resp = await asyncio.to_thread(aws_ops.describe_instances, cred, target_region)
            except Exception as exc:
                raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
            cloud_instances: dict[str, dict] = {}
            for res in resp.get("Reservations", []):
                for inst in res.get("Instances", []):
                    iid = inst.get("InstanceId")
                    if iid:
                        cloud_instances[iid] = inst
            status_resp: dict = {}
            if cloud_instances:
                status_resp = await asyncio.to_thread(
                    aws_ops.describe_instance_status, cred, target_region, list(cloud_instances.keys())
                )
            # instance-id -> {"instance": ..., "system": ...} reachability checks.
            status_map: Dict[str, dict] = {}
            for st in status_resp.get("InstanceStatuses", []):
                iid = st.get("InstanceId")
                status_map[iid] = {
                    "instance": st.get("InstanceStatus", {}).get("Status"),
                    "system": st.get("SystemStatus", {}).get("Status"),
                }
            local_instances = (
                await session.scalars(
                    select(Instance)
                    .options(selectinload(Instance.credential), selectinload(Instance.customer))
                    .where(
                        Instance.credential_id == cred.id,
                        Instance.region == target_region,
                        Instance.customer_id == current_customer_id,
                    )
                )
            ).all()
            local_map: Dict[str, Instance] = {inst.instance_id: inst for inst in local_instances}
            for inst_id, inst in cloud_instances.items():
                state_name = (inst.get("State") or {}).get("Name")
                state_lower = state_name.lower() if state_name else ""
                if state_lower in {"terminated", "shutting-down"}:
                    # Dying instances are removed locally rather than upserted.
                    existing = local_map.pop(inst_id, None)
                    if existing:
                        removed_payloads.append(await _delete_local_instance(session, existing))
                    continue
                inst_obj = await upsert_instance_from_cloud(
                    session,
                    current_customer_id,
                    cred,
                    target_region,
                    inst,
                    now,
                    status_map.get(inst_id),
                )
                local_map.pop(inst_id, None)
                if inst_obj:
                    changed_instances[inst_obj.id] = inst_obj
                # NOTE(review): the flattened source is ambiguous about whether the
                # bookkeeping below sat inside the `if inst_obj:` guard; it is kept at
                # loop level (one counter tick + one JobItem per cloud instance) —
                # confirm against version history.
                synced += 1
                session.add(
                    JobItem(
                        job_id=job.id,
                        resource_type=JobItemResourceType.INSTANCE,
                        resource_id=None,
                        account_id=cred.account_id,
                        region=target_region,
                        instance_id=inst_id,
                        action=JobItemAction.SYNC,
                        status=JobItemStatus.SUCCESS,
                    )
                )
            # Local instances no longer present in the cloud must be deleted.
            for orphan in local_map.values():
                removed_payloads.append(await _delete_local_instance(session, orphan))
        job.total_count = synced
        job.success_count = synced
        job.status = JobStatus.SUCCESS
        job.progress = 100
        job.finished_at = datetime.now(timezone.utc)
        session.add(
            AuditLog(
                user_id=actor.id,
                customer_id=base_customer_id or actor.customer_id,
                action=AuditAction.INSTANCE_SYNC,
                resource_type=AuditResourceType.INSTANCE,
                resource_id=None,
                description=f"Sync instances {synced}",
                payload={"credential_id": credential_id, "region": region},
            )
        )
        await session.commit()
        # Broadcast only after the commit so listeners never observe uncommitted state.
        enriched_instances = list(changed_instances.values())
        await enrich_instances(enriched_instances)
        for inst_obj in enriched_instances:
            await broadcast_instance_update(inst_obj)
        for payload in removed_payloads:
            await broadcast_instance_removed(payload, payload.get("customer_id"))
        return job


async def batch_instances_action(
    session: AsyncSession,
    payload: BatchInstancesActionIn,
    actor: User,
) -> BatchInstancesActionOut:
    """Apply one action (start/stop/reboot/terminate/sync) to many instances.

    Deduplicates the requested ids, enforces per-customer visibility and
    credential access, enqueues one Job per actionable instance, and batches
    "sync" requests per (credential, region, customer) scope.  Per-instance
    failures are reported in the result rather than raised.
    """
    if len(payload.instance_ids) > MAX_BATCH_COUNT:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail="TOO_MANY_ITEMS")
    # Dedupe while preserving request order.
    requested_ids = list(dict.fromkeys(payload.instance_ids))
    if not requested_ids:
        return BatchInstancesActionOut(action=payload.action, requested=[], accepted=[], skipped=[], errors={})
    instances = (await session.scalars(build_instances_query({"instance_ids": requested_ids}, actor))).all()
    found_map = {inst.id: inst for inst in instances}
    result = BatchInstancesActionOut(action=payload.action, requested=requested_ids, accepted=[], skipped=[], errors={})
    accepted_ids: set[int] = set()
    skipped_ids: set[int] = set()
    # Ids the (actor-scoped) query did not return are invisible to this actor.
    for iid in requested_ids:
        if iid not in found_map and iid not in skipped_ids:
            result.skipped.append(iid)
            result.errors[str(iid)] = "NOT_FOUND_OR_FORBIDDEN"
            skipped_ids.add(iid)
    job_action_map = {
        "start": JobItemAction.START,
        "stop": JobItemAction.STOP,
        "reboot": JobItemAction.REBOOT,
        "terminate": JobItemAction.TERMINATE,
    }
    # "sync" is not per-instance: group targets by (credential, region, customer).
    sync_plan: dict[tuple[int, str, int | None], list[int]] = {}
    for inst in instances:
        if actor.role.name != RoleName.ADMIN.value and inst.customer_id != actor.customer_id:
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
                skipped_ids.add(inst.id)
            continue
        if not inst.credential_id:
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "MISSING_CREDENTIAL"
                skipped_ids.add(inst.id)
            continue
        try:
            await ensure_credential_access(session, inst.credential_id, actor)
        except HTTPException:
            # Report access denial as not-found so ids are not enumerable.
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
                skipped_ids.add(inst.id)
            continue
        if payload.action == "sync":
            key = (inst.credential_id, inst.region, inst.customer_id)
            sync_plan.setdefault(key, []).append(inst.id)
            continue
        job_action = job_action_map.get(payload.action)
        if not job_action:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported action")
        try:
            await enqueue_action(session, inst, job_action, actor)
            if inst.id not in accepted_ids:
                accepted_ids.add(inst.id)
                result.accepted.append(inst.id)
        except Exception as exc:
            if inst.id not in skipped_ids:
                skipped_ids.add(inst.id)
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = str(exc)
    # handle sync actions per credential/region/customer
    for key, inst_ids in sync_plan.items():
        cred_id, region, customer_id = key
        try:
            await sync_instances(session, cred_id, region, actor, customer_id_override=customer_id)
            for iid in inst_ids:
                if iid not in accepted_ids:
                    accepted_ids.add(iid)
                    result.accepted.append(iid)
        except Exception as exc:
            # One failed scope fails every instance that was grouped into it.
            for iid in inst_ids:
                if iid not in skipped_ids:
                    skipped_ids.add(iid)
                    result.skipped.append(iid)
                    result.errors[str(iid)] = str(exc)
    return result


async def batch_instances_by_ips(
    session: AsyncSession,
    payload: BatchInstancesByIpIn,
    actor: User,
) -> dict:
    """Resolve a list of IPs to visible instances and run a batch action on them.

    Each raw entry may bundle several IPs separated by whitespace or commas.
    Matches against both public and private addresses, then delegates to
    :func:`batch_instances_action`.  Raises 400 on empty or oversized input.
    """
    ips = []
    seen = set()
    for raw in payload.ips:
        # Split whitespace/comma-separated bundles; dedupe while keeping order.
        for part in re.split(r"[\s,]+", (raw or "").strip()):
            if not part or part in seen:
                continue
            seen.add(part)
            ips.append(part)
    if not ips:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ips required")
    if len(ips) > MAX_BATCH_COUNT:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
    ip_set = set(ips)
    filters = {
        "credential_id": payload.credential_id,
        "region": payload.region,
    }
    query = build_instances_query(filters, actor).where(
        or_(Instance.public_ip.in_(ip_set), Instance.private_ip.in_(ip_set))
    )
    instances = (await session.scalars(query)).all()
    matched_ids = [inst.id for inst in instances]
    matched_ips: set[str] = set()
    for inst in instances:
        if inst.public_ip and inst.public_ip in ip_set:
            matched_ips.add(inst.public_ip)
        if inst.private_ip and inst.private_ip in ip_set:
            matched_ips.add(inst.private_ip)
    result = await batch_instances_action(
        session,
        BatchInstancesActionIn(instance_ids=matched_ids, action=payload.action),
        actor,
    )
    return {
        "ips_requested": ips,
        "ips_matched": list(matched_ips),
        "ips_unmatched": [ip for ip in ips if ip not in matched_ips],
        "result":
result,
    }


def _fmt(dt: Optional[datetime]) -> str:
    """Format a datetime for export; aware values are normalized to UTC."""
    if not dt:
        return ""
    if dt.tzinfo:
        return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    # Naive datetimes are printed as-is — presumably already UTC; verify upstream.
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def _fmt_gib(val: float | None) -> str:
    """Render a GiB figure compactly: '' for None, '4' for 4.0, '3.5' for 3.50."""
    if val is None:
        return ""
    try:
        as_float = float(val)
    except (TypeError, ValueError):
        # Non-numeric input is passed through unchanged.
        return str(val)
    if as_float.is_integer():
        return str(int(as_float))
    return f"{as_float:.2f}".rstrip("0").rstrip(".")


def _build_export_sheet(instances: List[Instance]) -> bytes:
    """Build an XLSX workbook ("Instances" sheet) for *instances*, returned as bytes."""
    wb = Workbook()
    ws = wb.active
    ws.title = "Instances"
    # Column order here must match the per-row append below.
    headers = [
        "Name",
        "Credential",
        "Owner",
        "Instance ID",
        "Type",
        "Instance Type Specs",
        "AMI / OS",
        "OS Pretty",
        "Public IP",
        "Private IP",
        "Region",
        "AZ",
        "Status",
        "Desired Status",
        "Account ID",
        "Credential ID",
        "Security Groups",
        "Subnet ID",
        "VPC ID",
        "Launched At",
        "Created At",
        "Last Sync",
    ]
    ws.append(headers)
    for inst in instances:
        # Best OS label: pretty name, then os_name, then the raw cloud snapshot.
        os_pretty = inst.os_pretty_name or _derive_os_pretty_name(inst)
        os_label = os_pretty or inst.os_name
        if not os_label and isinstance(inst.last_cloud_state, dict):
            os_label = inst.last_cloud_state.get("os_family") or inst.last_cloud_state.get("ami_name")
        ami_os = " / ".join([x for x in [os_label, inst.ami_id] if x])
        sg_val = ""
        if isinstance(inst.security_groups, list):
            sg_val = ", ".join([str(sg) for sg in inst.security_groups])
        elif isinstance(inst.security_groups, dict):
            sg_val = ", ".join([str(v) for v in inst.security_groups.values()])
        # Assemble "type (vCPU, memory, network)" when spec fields are populated.
        spec_parts = []
        if inst.instance_vcpus is not None:
            spec_parts.append(f"{inst.instance_vcpus} vCPU")
        if inst.instance_memory_gib is not None:
            spec_parts.append(f"{_fmt_gib(inst.instance_memory_gib)} GiB")
        if inst.instance_network_perf:
            spec_parts.append(inst.instance_network_perf)
        type_specs = inst.instance_type
        if spec_parts:
            type_specs = f"{inst.instance_type} ({', '.join(spec_parts)})" if inst.instance_type else ", ".join(spec_parts)
        ws.append(
            [
                inst.name_tag or "",
                # Prefer enrichment attributes; fall back to the ORM relationship.
                getattr(inst, "credential_label", None) or getattr(inst, "credential_name", None) or
                (inst.credential.name if getattr(inst, "credential", None) else ""),
                getattr(inst, "owner_name", None) or getattr(inst, "customer_name", None) or "",
                inst.instance_id,
                inst.instance_type,
                type_specs,
                ami_os,
                os_pretty or "",
                inst.public_ip or "",
                inst.private_ip or "",
                inst.region,
                inst.az or "",
                inst.status.value if isinstance(inst.status, InstanceStatus) else str(inst.status),
                inst.desired_status.value if isinstance(inst.desired_status, InstanceDesiredStatus) else inst.desired_status or "",
                inst.account_id,
                inst.credential_id or "",
                sg_val,
                inst.subnet_id or "",
                inst.vpc_id or "",
                _fmt(inst.launched_at),
                _fmt(inst.created_at),
                _fmt(inst.last_sync),
            ]
        )
    buffer = BytesIO()
    wb.save(buffer)
    buffer.seek(0)
    return buffer.getvalue()


async def export_instances(
    session: AsyncSession,
    filters: dict,
    actor: User,
) -> bytes:
    """Export all instances matching *filters* (actor-scoped) as an XLSX blob.

    Pagination keys are stripped so the export covers the full result set;
    raises 400 when the row count exceeds MAX_EXPORT_ROWS.
    """
    filters = dict(filters or {})
    filters.pop("offset", None)
    filters.pop("limit", None)
    query = build_instances_query(filters, actor).order_by(Instance.updated_at.desc())
    rows = (await session.scalars(query)).all()
    if len(rows) > MAX_EXPORT_ROWS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
        )
    await enrich_instances(rows)
    return _build_export_sheet(rows)


async def export_instances_by_ids(
    session: AsyncSession,
    instance_ids: List[int],
    actor: User,
) -> bytes:
    """Export a specific set of instance ids (actor-scoped) as an XLSX blob.

    Ids are deduplicated; an empty selection yields a header-only sheet.
    Raises 400 when the request or result set exceeds the export limits.
    """
    if len(instance_ids) > MAX_EXPORT_IDS:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
    ids = list(dict.fromkeys(instance_ids))
    if not ids:
        return _build_export_sheet([])
    query = build_instances_query({}, actor).where(Instance.id.in_(ids)).order_by(Instance.updated_at.desc())
    rows = (await session.scalars(query)).all()
    if len(rows) > MAX_EXPORT_ROWS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
        )
    await enrich_instances(rows)
    return _build_export_sheet(rows)