from __future__ import annotations

import asyncio
import logging
import re
import secrets
import string
import time
from datetime import datetime, timezone
from io import BytesIO
from typing import Any, Dict, List, Optional
from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import and_, func, or_, select
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from botocore.exceptions import ClientError

from openpyxl import Workbook

from backend.db.session import AsyncSessionLocal
from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType
from backend.modules.aws_accounts.models import AWSCredential, CredentialType, CustomerCredential
from backend.modules.instances import aws_ops
from backend.modules.instances.models import Instance, InstanceDesiredStatus, InstanceStatus
from backend.modules.instances.constants import DEFAULT_LOGIN_PASSWORD, DEFAULT_LOGIN_USERNAME
from backend.modules.instances.bootstrap_templates import build_user_data
from backend.modules.instances.events import (
    broadcast_instance_removed,
    broadcast_instance_update,
    build_removed_payload,
)
from backend.modules.jobs.models import (
    Job,
    JobItem,
    JobItemAction,
    JobItemResourceType,
    JobItemStatus,
    JobStatus,
    JobType,
)
from backend.modules.customers.models import Customer
from backend.modules.users.models import Role, RoleName, User
from backend.core.config import settings
from backend.modules.instances.schemas import (
    BatchInstancesActionIn,
    BatchInstancesActionOut,
    BatchInstancesByIpIn,
)


STATE_MAP: Dict[str, InstanceStatus] = {
    "pending": InstanceStatus.PENDING,
    "running": InstanceStatus.RUNNING,
    "stopping": InstanceStatus.STOPPING,
    "stopped": InstanceStatus.STOPPED,
    "shutting-down": InstanceStatus.SHUTTING_DOWN,
    "terminated": InstanceStatus.TERMINATED,
}

JOB_TYPE_BY_ACTION = {
    JobItemAction.START: JobType.START_INSTANCES,
    JobItemAction.STOP: JobType.STOP_INSTANCES,
    JobItemAction.REBOOT: JobType.REBOOT_INSTANCES,
    JobItemAction.TERMINATE: JobType.TERMINATE_INSTANCES,
}

DESIRED_BY_ACTION = {
    JobItemAction.START: InstanceDesiredStatus.RUNNING,
    JobItemAction.STOP: InstanceDesiredStatus.STOPPED,
    JobItemAction.REBOOT: InstanceDesiredStatus.RUNNING,
    JobItemAction.TERMINATE: InstanceDesiredStatus.TERMINATED,
}

# cache image name lookups to avoid repeated describe_images calls within a process
_image_name_cache: dict[tuple[int, str, str], str | None] = {}
_instance_type_specs_cache: dict[tuple[str, int, str], dict[str, tuple[int | None, float | None, str | None]]] = {}
_instance_type_cache_ts: dict[tuple[str, int, str], float] = {}

INSTANCE_TYPE_CACHE_TTL = 60 * 60  # seconds

MAX_BATCH_COUNT = 200
MAX_EXPORT_ROWS = 5000
MAX_EXPORT_IDS = 2000
MIN_SCOPE_SYNC_INTERVAL = max(1, int(settings.scope_sync_min_interval_seconds))
GLOBAL_SYNC_INTERVAL_MINUTES = max(1, int(settings.global_sync_interval_minutes))
GLOBAL_SYNC_MAX_CONCURRENCY = max(1, int(settings.global_sync_max_concurrency))

logger = logging.getLogger(__name__)

_scope_sync_last: dict[tuple[str, int, str], float] = {}
_scope_sync_locks: dict[tuple[str, int, str], asyncio.Lock] = {}
_global_sync_task: Optional[asyncio.Task] = None
_scope_sync_semaphore = asyncio.Semaphore(GLOBAL_SYNC_MAX_CONCURRENCY)


async def _delete_local_instance(session: AsyncSession, instance: Instance) -> dict:
    payload = build_removed_payload(instance)
    await session.delete(instance)
    return payload


def _scope_key(provider: str, credential_id: Optional[int], region: Optional[str]) -> tuple[str, int, str]:
    return (provider, credential_id or 0, region or "*")


def _get_scope_lock(provider: str, credential_id: Optional[int], region: Optional[str]) -> asyncio.Lock:
    key = _scope_key(provider, credential_id, region)
    lock = _scope_sync_locks.get(key)
    if lock is None:
        lock = asyncio.Lock()
        _scope_sync_locks[key] = lock
    return lock


async def _pick_sync_actor(session: AsyncSession, actor_id: Optional[int]) -> Optional[User]:
    if actor_id:
        user = await session.get(User, actor_id)
        if user:
            return user
    admin = await session.scalar(
        select(User).join(Role).where(Role.name == RoleName.ADMIN.value).limit(1)
    )
    if admin:
        return admin
    return await session.scalar(select(User).limit(1))


async def enqueue_scope_sync(provider: str, credential_id: int, region: str, *, reason: str, actor_id: Optional[int] = None, customer_id: Optional[int] = None) -> bool:
    """
    Enqueue a scope sync job with simple rate limiting to avoid hammering cloud APIs.
    Returns True if enqueued, False if skipped due to debounce.
    """
    key = _scope_key(provider, credential_id, region)
    now = time.monotonic()
    last = _scope_sync_last.get(key, 0)
    if now - last < MIN_SCOPE_SYNC_INTERVAL:
        return False
    _scope_sync_last[key] = now

    async def _runner():
        async with _scope_sync_semaphore:
            try:
                async with AsyncSessionLocal() as session:
                    actor = await _pick_sync_actor(session, actor_id)
                    if not actor:
                        logger.warning("scope sync skipped due to missing actor %s/%s/%s", provider, credential_id, region)
                        return
                    await sync_instances(
                        session,
                        credential_id,
                        region,
                        actor,
                        customer_id_override=customer_id,
                        provider=provider,
                    )
            except Exception as exc:  # pragma: no cover - best-effort sync
                logger.warning("scope sync failed for %s/%s/%s: %s", provider, credential_id, region, exc)

    asyncio.create_task(_runner())
    logger.info("scope sync enqueued for %s/%s/%s reason=%s", provider, credential_id, region, reason)
    return True


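# Usage sketch (illustrative only, not executed from this module): a route handler or
# another service can fire-and-forget a scope sync after a mutating operation. Calls
# made within the debounce window (MIN_SCOPE_SYNC_INTERVAL seconds) return False and
# spawn no task. The names and IDs below are placeholders, not real callers:
#
#   queued = await enqueue_scope_sync(
#       "aws", credential_id=42, region="ap-southeast-1",
#       reason="manual_refresh", actor_id=current_user.id, customer_id=7,
#   )
#   # queued is True if a background sync task was scheduled, False if debounced.

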
async def _discover_sync_scopes(session: AsyncSession) -> list[tuple[int, str, Optional[int]]]:
    scopes: set[tuple[int, str, Optional[int]]] = set()
    rows = await session.execute(
        select(Instance.credential_id, Instance.region, Instance.customer_id)
        .where(Instance.credential_id.isnot(None))
        .distinct()
    )
    for cred_id, region, customer_id in rows.all():
        if cred_id and region:
            scopes.add((cred_id, region, customer_id))
    credentials = (await session.scalars(select(AWSCredential).where(AWSCredential.is_active == 1))).all()
    for cred in credentials:
        mappings = (
            await session.scalars(
                select(CustomerCredential.customer_id)
                .where(CustomerCredential.credential_id == cred.id)
                .where(CustomerCredential.is_allowed == 1)
            )
        ).all()
        uniq = sorted({cid for cid in mappings if cid})
        if len(uniq) == 1:
            scopes.add((cred.id, cred.default_region, uniq[0]))
    return list(scopes)


async def _global_sync_loop() -> None:
    await asyncio.sleep(5)
    while True:
        try:
            async with AsyncSessionLocal() as session:
                admin = await _pick_sync_actor(session, None)
                admin_id = admin.id if admin else None
                scopes = await _discover_sync_scopes(session)
                for cred_id, region, customer_id in scopes:
                    await enqueue_scope_sync(
                        "aws",
                        cred_id,
                        region,
                        reason="global_cron",
                        actor_id=admin_id,
                        customer_id=customer_id,
                    )
        except Exception as exc:  # pragma: no cover - background loop resilience
            logger.warning("global sync loop error: %s", exc)
        await asyncio.sleep(GLOBAL_SYNC_INTERVAL_MINUTES * 60)


def start_global_sync_scheduler() -> None:
    global _global_sync_task
    if GLOBAL_SYNC_INTERVAL_MINUTES <= 0:
        return
    if _global_sync_task and not _global_sync_task.done():
        return
    _global_sync_task = asyncio.create_task(_global_sync_loop())


def _pretty_os_label(raw: str | None) -> str | None:
    if not raw:
        return None
    text = str(raw).strip()
    lower = text.lower()

    def find_version(patterns: list[str]) -> str | None:
        for pattern in patterns:
            m = re.search(pattern, lower)
            if m:
                return m.group(1)
        return None

    if "ubuntu" in lower:
        version = find_version([r"ubuntu[\w\- ]*(\d{2}\.\d{2})"])
        codename_map = {"jammy": "22.04", "focal": "20.04", "bionic": "18.04"}
        for code, ver in codename_map.items():
            if code in lower:
                version = version or ver
        label = "Ubuntu"
        if version:
            label = f"Ubuntu {version}"
            if version in {"22.04", "20.04", "18.04"}:
                label = f"{label} LTS"
        return label
    if "debian" in lower:
        # anchor the version right after the distro name so trailing build dates are not picked up
        version = find_version([r"debian[^0-9]*(\d{1,2}\.\d)", r"debian[^0-9]*(\d{1,2})"])
        codename_map = {"bookworm": "12", "bullseye": "11"}
        for code, ver in codename_map.items():
            if code in lower:
                version = version or ver
        return f"Debian {version}" if version else "Debian"
    if "centos" in lower:
        version = find_version([r"centos[^0-9]*(\d{1,2})"])
        return f"CentOS {version}" if version else "CentOS"
    if "amazon linux" in lower or lower.startswith("amzn") or "amzn" in lower:
        version = find_version([r"amazon linux ?(\d+)", r"amzn(\d)"])
        return f"Amazon Linux {version}" if version else "Amazon Linux"
    if "windows" in lower:
        version = find_version([r"windows[ _\-]*server[ _\-]*(\d{4})", r"windows[ _\-]*(\d{4})"])
        if not version:
            for candidate in ["2022", "2019", "2016", "2012"]:
                if candidate in lower:
                    version = candidate
                    break
        base = "Windows Server" if "server" in lower or version else "Windows"
        return f"{base} {version}".strip() if version else base
    return text


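# Expected behaviour of _pretty_os_label on a few representative AMI/OS strings
# (derived from the branches above; documentation only, not executed):
#
#   "ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-20240101" -> "Ubuntu 22.04 LTS"
#   "debian-12-amd64-20240201"                                       -> "Debian 12"
#   "amzn2-ami-hvm-2.0.20240131-x86_64-gp2"                          -> "Amazon Linux 2"
#   "Windows_Server-2022-English-Full-Base-2024.01.10"               -> "Windows Server 2022"
#   "SomethingCustom"                                                 -> "SomethingCustom" (fallback)

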
def _chunked(seq: List[str], size: int = 100) -> List[List[str]]:
    return [seq[i : i + size] for i in range(0, len(seq), size)]


async def get_instance_type_specs(
    provider: str,
    credential: AWSCredential | None,
    region: str,
    instance_types: List[str],
) -> dict[str, tuple[int | None, float | None, str | None]]:
    if provider != "aws" or not credential or not instance_types:
        return {}
    region_use = region or credential.default_region
    if not region_use:
        return {}
    cache_key = (provider, credential.id, region_use)
    now = time.time()
    cache_ts = _instance_type_cache_ts.get(cache_key, 0)
    cache_items = _instance_type_specs_cache.get(cache_key, {})
    if now - cache_ts > INSTANCE_TYPE_CACHE_TTL:
        cache_items = {}
    needed = [it for it in set(instance_types) if it and it not in cache_items]
    if needed:
        fetched = dict(cache_items)
        for chunk in _chunked(sorted(needed), 100):
            try:
                resp = await asyncio.to_thread(
                    aws_ops.describe_instance_types,
                    credential,
                    region_use,
                    [{"Name": "instance-type", "Values": chunk}],
                )
            except Exception as exc:  # pragma: no cover - best effort metadata
                logger.debug("describe_instance_types failed for %s/%s: %s", credential.id, region_use, exc)
                continue
            for item in resp:
                itype = item.get("InstanceType")
                if not itype:
                    continue
                vcpu = (item.get("VCpuInfo") or {}).get("DefaultVCpus")
                mem_mib = (item.get("MemoryInfo") or {}).get("SizeInMiB")
                mem_gib = round(mem_mib / 1024, 2) if mem_mib is not None else None
                net_perf = (item.get("NetworkInfo") or {}).get("NetworkPerformance")
                fetched[itype] = (vcpu, mem_gib, net_perf)
        cache_items = fetched
        _instance_type_specs_cache[cache_key] = cache_items
        _instance_type_cache_ts[cache_key] = now
    return {it: cache_items[it] for it in instance_types if it in cache_items}


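# Illustrative return shape (not executed here): get_instance_type_specs maps each
# requested type to (vCPUs, memory in GiB, network performance), e.g.
#
#   await get_instance_type_specs("aws", cred, "ap-southeast-1", ["t3.micro", "m5.large"])
#   # -> {"t3.micro": (2, 1.0, "Up to 5 Gigabit"), "m5.large": (2, 8.0, "Up to 10 Gigabit")}
#
# The numbers shown are typical EC2 values, not guaranteed; types that the API does
# not return are simply omitted from the result.

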
async def _lookup_image_name(cred: AWSCredential, region: str, ami_id: str) -> str | None:
    cache_key = (cred.id, region, ami_id)
    if cache_key in _image_name_cache:
        return _image_name_cache[cache_key]
    name: str | None = None
    try:
        resp = await asyncio.to_thread(aws_ops.describe_images, cred, region, [ami_id])
        images = resp.get("Images", [])
        if images:
            img = images[0]
            name = img.get("Name") or img.get("Description")
    except Exception:
        name = None
    _image_name_cache[cache_key] = name
    return name


async def _resolve_os_name(
    cred: AWSCredential, region: str, inst: dict, existing_os: str | None = None
) -> tuple[str | None, str | None]:
    """
    Try to get a user-friendly OS name. Prefer the existing value; otherwise use AWS metadata and the AMI name.
    Returns a tuple of (pretty_name, image_name).
    """
    os_hint = existing_os or inst.get("PlatformDetails") or inst.get("Platform")
    ami_id = inst.get("ImageId")
    os_name = os_hint
    image_name: str | None = None
    lower_name = os_name.lower() if isinstance(os_name, str) else ""
    needs_lookup = not os_name or lower_name.startswith("linux/unix") or not re.search(r"\d", lower_name)

    if ami_id and needs_lookup:
        image_name = await _lookup_image_name(cred, region, ami_id)
        if image_name:
            os_name = image_name
    if not os_name:
        os_name = ami_id

    pretty = _pretty_os_label(os_name)
    if not pretty and existing_os:
        pretty = _pretty_os_label(existing_os)
    fallback = os_name or existing_os
    return pretty or fallback, image_name


def _derive_os_pretty_name(inst: Instance) -> str | None:
    candidates: list[str | None] = []
    candidates.append(inst.os_name)
    if isinstance(inst.last_cloud_state, dict):
        candidates.append(inst.last_cloud_state.get("os_family"))
        candidates.append(inst.last_cloud_state.get("ami_name"))
    for cand in candidates:
        pretty = _pretty_os_label(cand)
        if pretty:
            return pretty
    for cand in candidates:
        if cand:
            return str(cand)
    return None


async def enrich_instances(instances: List[Instance]) -> None:
    if not instances:
        return
    # OS pretty names first so they are present even if specs lookup fails
    for inst in instances:
        inst.os_pretty_name = _derive_os_pretty_name(inst)

    scope_map: dict[tuple[int, str], list[Instance]] = {}
    cred_map: dict[int, AWSCredential] = {}
    for inst in instances:
        if inst.credential and inst.credential_id:
            cred_map[inst.credential_id] = inst.credential
        if inst.credential_id and inst.region and inst.instance_type:
            scope_map.setdefault((inst.credential_id, inst.region), []).append(inst)

    for (cred_id, region), scoped_instances in scope_map.items():
        cred = cred_map.get(cred_id)
        if not cred:
            continue
        types = [i.instance_type for i in scoped_instances if i.instance_type]
        specs_map = await get_instance_type_specs("aws", cred, region, types)
        if not specs_map:
            continue
        for inst in scoped_instances:
            spec = specs_map.get(inst.instance_type)
            if not spec:
                continue
            vcpu, mem_gib, net_perf = spec
            inst.instance_vcpus = vcpu
            inst.instance_memory_gib = mem_gib
            inst.instance_network_perf = net_perf


def _map_state(state: Optional[str]) -> InstanceStatus:
    if not state:
        return InstanceStatus.UNKNOWN
    return STATE_MAP.get(state.lower(), InstanceStatus.UNKNOWN)


def _extract_name_tag(tags: Optional[List[Dict[str, str]]]) -> Optional[str]:
    if not tags:
        return None
    for tag in tags:
        if tag.get("Key") == "Name":
            return tag.get("Value")
    return None


async def ensure_credential_access(session: AsyncSession, credential_id: int, actor: User) -> AWSCredential:
    cred = await session.get(AWSCredential, credential_id)
    if not cred or not cred.is_active:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found")
    if actor.role.name == RoleName.ADMIN.value:
        return cred
    mapping = await session.scalar(
        select(CustomerCredential).where(
            and_(
                CustomerCredential.customer_id == actor.customer_id,
                CustomerCredential.credential_id == credential_id,
                CustomerCredential.is_allowed == 1,
            )
        )
    )
    if not mapping:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Credential not allowed")
    return cred


def build_instances_query(filters: dict, actor: User):
    filters = dict(filters or {})
    query = select(Instance).options(
        selectinload(Instance.credential),
        selectinload(Instance.customer),
    )
    if actor.role.name != RoleName.ADMIN.value:
        query = query.where(Instance.customer_id == actor.customer_id)
    elif filters.get("customer_id"):
        query = query.where(Instance.customer_id == filters["customer_id"])
    if filters.get("credential_id"):
        query = query.where(Instance.credential_id == filters["credential_id"])
    if filters.get("account_id"):
        query = query.where(Instance.account_id == filters["account_id"])
    if filters.get("region"):
        query = query.where(Instance.region == filters["region"])
    if filters.get("status"):
        query = query.where(Instance.status == filters["status"])
    if filters.get("instance_ids"):
        ids = list(dict.fromkeys(filters["instance_ids"]))
        if ids:
            query = query.where(Instance.id.in_(ids))
    if filters.get("keyword"):
        pattern = f"%{filters['keyword']}%"
        query = query.where(
            or_(
                Instance.name_tag.ilike(pattern),
                Instance.instance_id.ilike(pattern),
                Instance.public_ip.ilike(pattern),
                Instance.private_ip.ilike(pattern),
            )
        )
    return query


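# Illustrative filters dict for build_instances_query / list_instances (a subset of
# the supported keys shown; "offset"/"limit" are consumed by list_instances only,
# and values here are placeholders):
#
#   filters = {
#       "customer_id": 7,          # admins only; non-admins are always scoped to their own customer
#       "credential_id": 42,
#       "region": "ap-southeast-1",
#       "status": InstanceStatus.RUNNING,
#       "keyword": "web",          # matches name tag, instance id, public or private IP
#       "offset": 0,
#       "limit": 20,
#   }

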
async def list_instances(
    session: AsyncSession,
    filters: dict,
    actor: User,
) -> tuple[List[Instance], int]:
    query = build_instances_query(filters, actor)
    total = await session.scalar(select(func.count()).select_from(query.subquery()))
    rows = (
        await session.scalars(
            query.order_by(Instance.updated_at.desc()).offset(filters.get("offset", 0)).limit(filters.get("limit", 20))
        )
    ).all()
    await enrich_instances(rows)
    return rows, total or 0


def _generate_random_password(length: int = 16) -> str:
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def _resolve_default_ami(cred: AWSCredential, region: str, os_family: str) -> str:
    """Fetch the latest common AMI for a given OS family/version. Best-effort."""
    session, cfg = aws_ops.build_session(cred, region)
    client = session.client("ec2", region_name=region or cred.default_region, config=cfg)
    family = (os_family or "ubuntu-22.04").lower()
    owners: List[str] = []
    name_values: List[str] = []
    if family.startswith("ubuntu-22.04"):
        owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"])
    elif family.startswith("ubuntu-20.04"):
        owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-focal-20.04-amd64-server-*"])
    elif family.startswith("debian-12"):
        owners, name_values = (["136693071363"], ["debian-12-amd64-*"])
    elif family.startswith("debian-11"):
        owners, name_values = (["136693071363"], ["debian-11-amd64-*"])
    elif family.startswith("centos-9"):
        owners, name_values = (["125523088429"], ["CentOS Stream 9*"])
    elif family.startswith("amazonlinux-2"):
        owners, name_values = (["137112412989"], ["amzn2-ami-hvm-*-x86_64-gp2"])
    elif family.startswith("windows-2019"):
        owners, name_values = (["801119661308"], ["Windows_Server-2019-English-Full-Base-*"])
    elif family.startswith("windows-2022"):
        owners, name_values = (["801119661308"], ["Windows_Server-2022-English-Full-Base-*"])
    else:
        owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"])

    try:
        resp = client.describe_images(
            Owners=owners,
            Filters=[{"Name": "name", "Values": name_values}, {"Name": "architecture", "Values": ["x86_64"]}],
        )
        images = sorted(resp.get("Images", []), key=lambda x: x.get("CreationDate", ""), reverse=True)
        if images:
            return images[0].get("ImageId")
    except ClientError as exc:
        code = (exc.response.get("Error") or {}).get("Code")
        msg = (exc.response.get("Error") or {}).get("Message")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"获取镜像失败: {code} {msg}")
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"获取镜像失败: {exc}")
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无法为所选系统找到可用 AMI,请手动填写")


def _get_root_device_name(cred: AWSCredential, region: str, ami_id: str) -> str:
    session, cfg = aws_ops.build_session(cred, region)
    client = session.client("ec2", region_name=region or cred.default_region, config=cfg)
    try:
        resp = client.describe_images(ImageIds=[ami_id])
        images = resp.get("Images", [])
        if images and images[0].get("RootDeviceName"):
            return images[0]["RootDeviceName"]
    except Exception:
        return "/dev/xvda"
    return "/dev/xvda"


async def create_instance(
    session: AsyncSession,
    payload: dict,
    actor: User,
) -> dict:
    # enforce instance type for customers
    allowed_customer_types = {"t3.micro", "t3.small", "t3.medium"}
    is_admin = actor.role.name == RoleName.ADMIN.value
    if not payload.get("instance_type"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="instance_type required")
    if not is_admin and payload.get("instance_type") not in allowed_customer_types:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="客户仅可使用 t3.micro / t3.small / t3.medium 实例类型",
        )
    count = payload.get("count") or 1
    if count < 1:
        count = 1

    customer_id = payload.get("customer_id") or actor.customer_id
    if not customer_id:
        # derive from credential mapping if there is exactly one allowed customer
        mappings = (
            await session.scalars(
                select(CustomerCredential.customer_id).where(
                    CustomerCredential.credential_id == payload["credential_id"],
                    CustomerCredential.is_allowed == 1,
                )
            )
        ).all()
        unique_customer_ids = sorted({cid for cid in mappings if cid})
        if len(unique_customer_ids) == 1:
            customer_id = unique_customer_ids[0]
        else:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required")
    cred = await ensure_credential_access(session, payload["credential_id"], actor)
    if cred.account_id != payload["account_id"]:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="account_id mismatch")

    region = payload["region"] or cred.default_region

    # quota check: customer-level limit against the active count in this region
    customer = await session.get(Customer, customer_id)
    if not customer:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
    quota_limit = customer.quota_instances if customer.quota_instances is not None else 999999
    active_count = await session.scalar(
        select(func.count(Instance.id)).where(
            Instance.customer_id == customer_id,
            Instance.status != InstanceStatus.TERMINATED,
            Instance.desired_status != InstanceDesiredStatus.TERMINATED,
            Instance.region == region,
        )
    )
    available = max(0, quota_limit - (active_count or 0))
    if count > available:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"超出配额,当前区域可用数量 {available},请调整数量或联系管理员扩容",
        )

    session_boto, cfg = aws_ops.build_session(cred, region)
    ec2_client = session_boto.client("ec2", region_name=region or cred.default_region, config=cfg)

    login_username = payload.get("login_username") or DEFAULT_LOGIN_USERNAME
    login_password = payload.get("login_password") or DEFAULT_LOGIN_PASSWORD
    # allow client to request random password by passing "RANDOM" sentinel
    if isinstance(login_password, str) and login_password.upper() == "RANDOM":
        login_password = _generate_random_password()

    user_data_mode = (payload.get("user_data_mode") or "auto_root").lower()
    os_family = payload.get("os_family") or "ubuntu"
    user_data_content: Optional[str] = None
    if user_data_mode == "auto_root":
        user_data_content = build_user_data(os_family, login_username, login_password)
    elif user_data_mode == "custom":
        user_data_content = payload.get("custom_user_data")
    else:
        user_data_content = None

    sg_ids = payload.get("security_group_ids") or []
    if not sg_ids:
        if not settings.auto_open_sg_enabled:
            # Conservative policy: require the frontend to pick security groups explicitly to avoid accidental exposure.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="security_group_ids required when auto open SG is disabled",
            )
        if not payload.get("subnet_id"):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="subnet_id required to derive VPC")
        subnet_resp = await asyncio.to_thread(ec2_client.describe_subnets, SubnetIds=[payload["subnet_id"]])
        subnets = subnet_resp.get("Subnets", [])
        if not subnets:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid subnet_id")
        vpc_id = subnets[0].get("VpcId")
        if not vpc_id:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Subnet missing VPC")
        # Risk note: open-all SG is only for testing environments.
        sg_id = await asyncio.to_thread(aws_ops.ensure_open_all_sg_for_vpc, ec2_client, vpc_id)
        sg_ids = [sg_id]

    ami_id = payload.get("ami_id")
    if not ami_id:
        ami_id = _resolve_default_ami(cred, region, os_family)

    volume_type = payload.get("volume_type") or "gp3"
    volume_size = payload.get("volume_size") or 20
    if volume_size <= 0:
        volume_size = 20
    if not is_admin:
        volume_type = "gp3"
    cpu_options: Optional[Dict[str, Any]] = None
    if (payload.get("instance_type") or "").lower().startswith("t"):
        cpu_options = {"CpuCredits": "standard"}
    root_device = _get_root_device_name(cred, region, ami_id)
    block_device_mappings = [
        {
            "DeviceName": root_device,
            "Ebs": {
                "VolumeSize": volume_size,
                "VolumeType": volume_type,
                "DeleteOnTermination": True,
            },
        }
    ]

    try:
        resp = await asyncio.to_thread(
            aws_ops.run_instances,
            cred,
            region,
            ami_id,
            payload["instance_type"],
            payload.get("key_name"),
            sg_ids,
            payload.get("subnet_id"),
            block_device_mappings,
            cpu_options,
            count,
            count,
            payload.get("name_tag"),
            user_data_content,
        )
    except ClientError as exc:  # pragma: no cover
        err = exc.response.get("Error", {})
        code = err.get("Code")
        msg = err.get("Message")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"AWS 错误: {code} {msg}") from exc
    except Exception as exc:  # pragma: no cover
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc

    instances = resp.get("Instances") or []
    if not instances:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AWS did not return instances")
    instance_ids = [i.get("InstanceId") for i in instances if i.get("InstanceId")]
    status_checks_map: Dict[str, dict] = {}
    # refresh details to fetch public IP if not immediately available
    if instance_ids:
        try:
            describe_resp = await asyncio.to_thread(
                aws_ops.describe_instances, cred, region, instance_ids=instance_ids
            )
            reservations = describe_resp.get("Reservations", [])
            tmp = []
            for res in reservations:
                tmp.extend(res.get("Instances", []))
            if tmp:
                instances = tmp
            status_resp = await asyncio.to_thread(aws_ops.describe_instance_status, cred, region, instance_ids)
            for st in status_resp.get("InstanceStatuses", []):
                iid = st.get("InstanceId")
                status_checks_map[iid] = {
                    "instance": st.get("InstanceStatus", {}).get("Status"),
                    "system": st.get("SystemStatus", {}).get("Status"),
                }
        except Exception:
            pass

    created_db_instances: List[Instance] = []
    now = datetime.now(timezone.utc)
    for data in instances:
        name_tag = payload.get("name_tag") or _extract_name_tag(data.get("Tags"))
        security_groups = [sg.get("GroupId") for sg in data.get("SecurityGroups", []) if sg.get("GroupId")]
        db_inst = Instance(
            customer_id=customer_id,
            credential_id=cred.id,
            account_id=cred.account_id,
            region=region,
            az=(data.get("Placement") or {}).get("AvailabilityZone"),
            instance_id=data.get("InstanceId"),
            name_tag=name_tag,
            instance_type=payload["instance_type"],
            ami_id=ami_id,
            os_name=os_family,
            key_name=payload.get("key_name"),
            public_ip=data.get("PublicIpAddress"),
            private_ip=data.get("PrivateIpAddress"),
            status=_map_state((data.get("State") or {}).get("Name")),
            desired_status=InstanceDesiredStatus.RUNNING,
            security_groups=security_groups,
            subnet_id=data.get("SubnetId"),
            vpc_id=data.get("VpcId"),
            launched_at=data.get("LaunchTime") or now,
            last_sync=now,
            last_cloud_state={
                "state": data.get("State"),
                "tags": data.get("Tags"),
                "os_family": os_family,
                "status_checks": status_checks_map.get(data.get("InstanceId")),
            },
        )
        session.add(db_inst)
        created_db_instances.append(db_inst)
    await session.commit()
    for db_inst in created_db_instances:
        await session.refresh(db_inst)
        await session.refresh(db_inst, attribute_names=["credential", "customer"])
    db_inst = created_db_instances[0]
    session.add(
        AuditLog(
            user_id=actor.id,
            customer_id=customer_id,
            action=AuditAction.INSTANCE_CREATE,
            resource_type=AuditResourceType.INSTANCE,
            resource_id=db_inst.id,
            description=f"Create instance {db_inst.instance_id} x{len(created_db_instances)}",
            payload=payload,
        )
    )
    await session.commit()
    await enrich_instances(created_db_instances)
    for inst in created_db_instances:
        await broadcast_instance_update(inst)
    # Trigger a scope sync after creation so the cloud state and local records converge.
    try:
        await enqueue_scope_sync(
            "aws",
            cred.id,
            region,
            reason="create_instance",
            actor_id=actor.id,
            customer_id=customer_id,
        )
    except Exception as exc:  # pragma: no cover
        logger.warning("enqueue scope sync after create failed: %s", exc)
    return {"instance": db_inst, "login_username": login_username, "login_password": login_password}


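# Illustrative payload for create_instance (field names as read above; all values
# are placeholders only):
#
#   payload = {
#       "credential_id": 42,
#       "account_id": "123456789012",
#       "region": "ap-southeast-1",
#       "instance_type": "t3.small",
#       "count": 2,
#       "os_family": "ubuntu-22.04",      # drives _resolve_default_ami when ami_id is omitted
#       "name_tag": "web-node",
#       "subnet_id": "subnet-0abc...",
#       "security_group_ids": [],          # empty -> auto-open SG path when enabled in settings
#       "login_password": "RANDOM",        # sentinel -> generate a random password
#       "user_data_mode": "auto_root",
#   }

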
async def _process_action(job_id: int, job_item_id: int, action: JobItemAction) -> None:
    async with AsyncSessionLocal() as session:
        job = await session.get(Job, job_id)
        job_item = await session.scalar(
            select(JobItem)
            .where(JobItem.id == job_item_id)
            .options(
                selectinload(JobItem.instance).selectinload(Instance.credential),
                selectinload(JobItem.instance).selectinload(Instance.customer),
            )
        )
        if not job or not job_item or not job_item.instance or not job_item.instance.credential:
            return
        job.status = JobStatus.RUNNING
        job.started_at = datetime.now(timezone.utc)
        job_item.status = JobItemStatus.RUNNING
        await session.commit()

        cred = job_item.instance.credential
        region = job_item.region or job_item.instance.region
        instance_id = job_item.instance.instance_id
        removed_payload = None
        updated_instance: Optional[Instance] = None
        try:
            if action == JobItemAction.START:
                resp = await asyncio.to_thread(aws_ops.start_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.RUNNING
            elif action == JobItemAction.STOP:
                resp = await asyncio.to_thread(aws_ops.stop_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.STOPPED
            elif action == JobItemAction.REBOOT:
                resp = await asyncio.to_thread(aws_ops.reboot_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.RUNNING
            elif action == JobItemAction.TERMINATE:
                resp = await asyncio.to_thread(aws_ops.terminate_instances, cred, region, [instance_id])
                job_item.instance.status = InstanceStatus.TERMINATED
                job_item.instance.terminated_at = datetime.now(timezone.utc)
                removed_payload = await _delete_local_instance(session, job_item.instance)
                job_item.resource_id = None
            else:  # pragma: no cover
                resp = {}

            if action != JobItemAction.TERMINATE:
                updated_instance = job_item.instance

            job_item.extra = resp
            job_item.status = JobItemStatus.SUCCESS
            job.success_count = 1
            job.total_count = 1
            job.progress = 100
            job.status = JobStatus.SUCCESS
            job.finished_at = datetime.now(timezone.utc)
            await session.commit()
            if updated_instance:
                await session.refresh(updated_instance)
                await broadcast_instance_update(updated_instance)
            if removed_payload:
                await broadcast_instance_removed(removed_payload, removed_payload.get("customer_id"))
            try:
                scope_customer = (
                    job_item.instance.customer_id
                    if job_item.instance
                    else removed_payload.get("customer_id") if removed_payload else None
                )
                await enqueue_scope_sync(
                    "aws",
                    cred.id,
                    region,
                    reason=f"job_{action.value}",
                    actor_id=job.created_by_user_id,
                    customer_id=scope_customer,
                )
            except Exception as exc:  # pragma: no cover - do not fail job if sync enqueue fails
                logger.warning("enqueue scope sync after job failed: %s", exc)
        except Exception as exc:  # pragma: no cover
            job_item.status = JobItemStatus.FAILED
            job_item.error_message = str(exc)
            job.status = JobStatus.FAILED
            job.error_message = str(exc)
            job.fail_count = 1
            job.progress = 100
            job.finished_at = datetime.now(timezone.utc)
            await session.commit()


async def enqueue_action(session: AsyncSession, instance: Instance, action: JobItemAction, actor: User) -> Job:
    job = Job(
        job_uuid=uuid4().hex,
        job_type=JOB_TYPE_BY_ACTION[action],
        status=JobStatus.PENDING,
        progress=0,
        total_count=1,
        created_by_user_id=actor.id,
        created_for_customer=instance.customer_id,
    )
    session.add(job)
    await session.flush()
    job_item = JobItem(
        job_id=job.id,
        resource_type=JobItemResourceType.INSTANCE,
        resource_id=instance.id,
        account_id=instance.account_id,
        region=instance.region,
        instance_id=instance.instance_id,
        action=action,
        status=JobItemStatus.PENDING,
    )
    instance.desired_status = DESIRED_BY_ACTION[action]
    session.add(job_item)
    await session.commit()
    await session.refresh(job)
    await session.refresh(instance)
    await broadcast_instance_update(instance)
    asyncio.create_task(_process_action(job.id, job_item.id, action))
    return job


async def upsert_instance_from_cloud(
    session: AsyncSession,
    customer_id: int,
    cred: AWSCredential,
    region: str,
    inst: dict,
    now: datetime,
    status_checks: Optional[dict] = None,
) -> Optional[Instance]:
    instance_id = inst.get("InstanceId")
    state_name = (inst.get("State") or {}).get("Name")
    name_tag = _extract_name_tag(inst.get("Tags"))
    security_groups = [sg.get("GroupId") for sg in inst.get("SecurityGroups", []) if sg.get("GroupId")]
    existing_inst = await session.scalar(
        select(Instance).where(
            Instance.account_id == cred.account_id,
            Instance.region == region,
            Instance.instance_id == instance_id,
        )
    )
    os_family, ami_name = await _resolve_os_name(
        cred, region, inst, existing_inst.os_name if existing_inst else None
    )
    record = dict(
        customer_id=customer_id,
        credential_id=cred.id,
        account_id=cred.account_id,
        region=region,
        az=(inst.get("Placement") or {}).get("AvailabilityZone"),
        instance_id=instance_id,
        name_tag=name_tag,
        instance_type=inst.get("InstanceType"),
        ami_id=inst.get("ImageId"),
        os_name=os_family,
        key_name=inst.get("KeyName"),
        public_ip=inst.get("PublicIpAddress"),
        private_ip=inst.get("PrivateIpAddress"),
        status=_map_state(state_name),
        desired_status=None,
        security_groups=security_groups,
        subnet_id=inst.get("SubnetId"),
        vpc_id=inst.get("VpcId"),
        launched_at=inst.get("LaunchTime"),
        terminated_at=now if state_name == "terminated" else None,
        last_sync=now,
        last_cloud_state={
            "state": inst.get("State"),
            "tags": inst.get("Tags"),
            "os_family": os_family,
            "ami_name": ami_name,
            "status_checks": status_checks,
        },
    )
    stmt = insert(Instance).values(**record)
    update_cols = {k: stmt.inserted[k] for k in record.keys() if k != "id"}
    await session.execute(stmt.on_duplicate_key_update(**update_cols))
    return await session.scalar(
        select(Instance)
        .options(selectinload(Instance.credential), selectinload(Instance.customer))
        .where(
            Instance.account_id == cred.account_id,
            Instance.region == region,
            Instance.instance_id == instance_id,
        )
    )


async def sync_instances(
    session: AsyncSession,
    credential_id: Optional[int],
    region: Optional[str],
    actor: User,
    customer_id_override: Optional[int] = None,
    provider: str = "aws",
) -> Job:
    lock = _get_scope_lock(provider, credential_id, region)
    async with lock:
        base_customer_id = customer_id_override if actor.role.name == RoleName.ADMIN.value else actor.customer_id
        credentials_query = select(AWSCredential).where(AWSCredential.is_active == 1)
        if actor.role.name != RoleName.ADMIN.value:
            credentials_query = (
                credentials_query.join(CustomerCredential, CustomerCredential.credential_id == AWSCredential.id)
                .where(CustomerCredential.customer_id == actor.customer_id)
                .where(CustomerCredential.is_allowed == 1)
            )
        if credential_id:
            credentials_query = credentials_query.where(AWSCredential.id == credential_id)
        credentials = (await session.scalars(credentials_query)).all()
        if not credentials:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No credentials to sync")

        job = Job(
            job_uuid=uuid4().hex,
            job_type=JobType.SYNC_INSTANCES,
            status=JobStatus.RUNNING,
            progress=0,
            total_count=0,
            created_by_user_id=actor.id,
            created_for_customer=base_customer_id,
            payload={"credential_id": credential_id, "region": region},
            started_at=datetime.now(timezone.utc),
        )
        session.add(job)
        await session.commit()
        await session.refresh(job)

        synced = 0
        changed_instances: dict[int, Instance] = {}
        removed_payloads: list[dict] = []
        now = datetime.now(timezone.utc)

        for cred in credentials:
            current_customer_id = base_customer_id
            if actor.role.name == RoleName.ADMIN.value and not current_customer_id:
                mappings = (
                    await session.scalars(
                        select(CustomerCredential.customer_id)
                        .where(CustomerCredential.credential_id == cred.id)
                        .where(CustomerCredential.is_allowed == 1)
                    )
                ).all()
                uniq = sorted({cid for cid in mappings if cid})
                if len(uniq) == 1:
                    current_customer_id = uniq[0]
                else:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"customer_id required for credential {cred.id}",
                    )
            if not current_customer_id:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="customer_id required",
                )

            target_region = region or cred.default_region
            try:
                resp = await asyncio.to_thread(aws_ops.describe_instances, cred, target_region)
            except Exception as exc:
                raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))

            cloud_instances: dict[str, dict] = {}
            for res in resp.get("Reservations", []):
                for inst in res.get("Instances", []):
                    iid = inst.get("InstanceId")
                    if iid:
                        cloud_instances[iid] = inst

            status_resp: dict = {}
            if cloud_instances:
                status_resp = await asyncio.to_thread(
                    aws_ops.describe_instance_status, cred, target_region, list(cloud_instances.keys())
                )
            status_map: Dict[str, dict] = {}
            for st in status_resp.get("InstanceStatuses", []):
                iid = st.get("InstanceId")
                status_map[iid] = {
                    "instance": st.get("InstanceStatus", {}).get("Status"),
                    "system": st.get("SystemStatus", {}).get("Status"),
                }

            local_instances = (
                await session.scalars(
                    select(Instance)
                    .options(selectinload(Instance.credential), selectinload(Instance.customer))
                    .where(
                        Instance.credential_id == cred.id,
                        Instance.region == target_region,
                        Instance.customer_id == current_customer_id,
                    )
                )
            ).all()
            local_map: Dict[str, Instance] = {inst.instance_id: inst for inst in local_instances}

            for inst_id, inst in cloud_instances.items():
                state_name = (inst.get("State") or {}).get("Name")
                state_lower = state_name.lower() if state_name else ""
                if state_lower in {"terminated", "shutting-down"}:
                    existing = local_map.pop(inst_id, None)
                    if existing:
                        removed_payloads.append(await _delete_local_instance(session, existing))
                    continue
                inst_obj = await upsert_instance_from_cloud(
                    session,
                    current_customer_id,
                    cred,
                    target_region,
                    inst,
                    now,
                    status_map.get(inst_id),
                )
                local_map.pop(inst_id, None)
                if inst_obj:
                    changed_instances[inst_obj.id] = inst_obj
                synced += 1
                session.add(
                    JobItem(
                        job_id=job.id,
                        resource_type=JobItemResourceType.INSTANCE,
                        resource_id=None,
                        account_id=cred.account_id,
                        region=target_region,
                        instance_id=inst_id,
                        action=JobItemAction.SYNC,
                        status=JobItemStatus.SUCCESS,
                    )
                )

            # Local instances that no longer exist in the cloud must be removed.
            for orphan in local_map.values():
                removed_payloads.append(await _delete_local_instance(session, orphan))

        job.total_count = synced
        job.success_count = synced
        job.status = JobStatus.SUCCESS
        job.progress = 100
        job.finished_at = datetime.now(timezone.utc)
        session.add(
            AuditLog(
                user_id=actor.id,
                customer_id=base_customer_id or actor.customer_id,
                action=AuditAction.INSTANCE_SYNC,
                resource_type=AuditResourceType.INSTANCE,
                resource_id=None,
                description=f"Sync instances {synced}",
                payload={"credential_id": credential_id, "region": region},
            )
        )
        await session.commit()
        enriched_instances = list(changed_instances.values())
        await enrich_instances(enriched_instances)
        for inst_obj in enriched_instances:
            await broadcast_instance_update(inst_obj)
        for payload in removed_payloads:
            await broadcast_instance_removed(payload, payload.get("customer_id"))
        return job


async def batch_instances_action(
    session: AsyncSession,
    payload: BatchInstancesActionIn,
    actor: User,
) -> BatchInstancesActionOut:
    if len(payload.instance_ids) > MAX_BATCH_COUNT:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
    requested_ids = list(dict.fromkeys(payload.instance_ids))
    if not requested_ids:
        return BatchInstancesActionOut(action=payload.action, requested=[], accepted=[], skipped=[], errors={})

    instances = (await session.scalars(build_instances_query({"instance_ids": requested_ids}, actor))).all()
    found_map = {inst.id: inst for inst in instances}
    result = BatchInstancesActionOut(action=payload.action, requested=requested_ids, accepted=[], skipped=[], errors={})
    accepted_ids: set[int] = set()
    skipped_ids: set[int] = set()
    for iid in requested_ids:
        if iid not in found_map and iid not in skipped_ids:
            result.skipped.append(iid)
            result.errors[str(iid)] = "NOT_FOUND_OR_FORBIDDEN"
            skipped_ids.add(iid)

    job_action_map = {
        "start": JobItemAction.START,
        "stop": JobItemAction.STOP,
        "reboot": JobItemAction.REBOOT,
        "terminate": JobItemAction.TERMINATE,
    }
    sync_plan: dict[tuple[int, str, int | None], list[int]] = {}

    for inst in instances:
        if actor.role.name != RoleName.ADMIN.value and inst.customer_id != actor.customer_id:
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
                skipped_ids.add(inst.id)
            continue
        if not inst.credential_id:
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "MISSING_CREDENTIAL"
                skipped_ids.add(inst.id)
            continue
        try:
            await ensure_credential_access(session, inst.credential_id, actor)
        except HTTPException:
            if inst.id not in skipped_ids:
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
                skipped_ids.add(inst.id)
            continue

        if payload.action == "sync":
            key = (inst.credential_id, inst.region, inst.customer_id)
            sync_plan.setdefault(key, []).append(inst.id)
            continue

        job_action = job_action_map.get(payload.action)
        if not job_action:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported action")
        try:
            await enqueue_action(session, inst, job_action, actor)
            if inst.id not in accepted_ids:
                accepted_ids.add(inst.id)
                result.accepted.append(inst.id)
        except Exception as exc:
            if inst.id not in skipped_ids:
                skipped_ids.add(inst.id)
                result.skipped.append(inst.id)
                result.errors[str(inst.id)] = str(exc)

    # handle sync actions per credential/region/customer
    for key, inst_ids in sync_plan.items():
        cred_id, region, customer_id = key
        try:
            await sync_instances(session, cred_id, region, actor, customer_id_override=customer_id)
            for iid in inst_ids:
                if iid not in accepted_ids:
                    accepted_ids.add(iid)
                    result.accepted.append(iid)
        except Exception as exc:
            for iid in inst_ids:
                if iid not in skipped_ids:
                    skipped_ids.add(iid)
                    result.skipped.append(iid)
                    result.errors[str(iid)] = str(exc)

    return result


async def batch_instances_by_ips(
    session: AsyncSession,
    payload: BatchInstancesByIpIn,
    actor: User,
) -> dict:
    ips = []
    seen = set()
    for raw in payload.ips:
        for part in re.split(r"[\s,]+", (raw or "").strip()):
            if not part or part in seen:
                continue
            seen.add(part)
            ips.append(part)
    if not ips:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ips required")
    if len(ips) > MAX_BATCH_COUNT:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
    ip_set = set(ips)
    filters = {
        "credential_id": payload.credential_id,
        "region": payload.region,
    }
    query = build_instances_query(filters, actor).where(
        or_(Instance.public_ip.in_(ip_set), Instance.private_ip.in_(ip_set))
    )
    instances = (await session.scalars(query)).all()
    matched_ids = [inst.id for inst in instances]
    matched_ips: set[str] = set()
    for inst in instances:
        if inst.public_ip and inst.public_ip in ip_set:
            matched_ips.add(inst.public_ip)
        if inst.private_ip and inst.private_ip in ip_set:
            matched_ips.add(inst.private_ip)
    result = await batch_instances_action(
        session,
        BatchInstancesActionIn(instance_ids=matched_ids, action=payload.action),
        actor,
    )
    return {
        "ips_requested": ips,
        "ips_matched": list(matched_ips),
        "ips_unmatched": [ip for ip in ips if ip not in matched_ips],
        "result": result,
    }


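# Illustrative request/response shape for batch_instances_by_ips (values are
# placeholders). Raw entries may mix comma-, space-, or newline-separated IPs; they
# are split and de-duplicated before matching on public or private IP:
#
#   payload.ips    = ["1.2.3.4, 1.2.3.5", "10.0.0.8"]
#   returned value = {
#       "ips_requested": ["1.2.3.4", "1.2.3.5", "10.0.0.8"],
#       "ips_matched": ["1.2.3.4", "10.0.0.8"],
#       "ips_unmatched": ["1.2.3.5"],
#       "result": <BatchInstancesActionOut>,
#   }

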
def _fmt(dt: Optional[datetime]) -> str:
    if not dt:
        return ""
    if dt.tzinfo:
        return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def _fmt_gib(val: float | None) -> str:
    if val is None:
        return ""
    try:
        as_float = float(val)
    except (TypeError, ValueError):
        return str(val)
    if as_float.is_integer():
        return str(int(as_float))
    return f"{as_float:.2f}".rstrip("0").rstrip(".")


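# Formatting behaviour used by the export sheet (documentation only):
#
#   _fmt(datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc)) -> "2024-01-02 03:04:05 UTC"
#   _fmt(None)                                               -> ""
#   _fmt_gib(8.0)   -> "8"
#   _fmt_gib(0.5)   -> "0.5"
#   _fmt_gib(None)  -> ""

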
def _build_export_sheet(instances: List[Instance]) -> bytes:
    wb = Workbook()
    ws = wb.active
    ws.title = "Instances"
    headers = [
        "Name",
        "Credential",
        "Owner",
        "Instance ID",
        "Type",
        "Instance Type Specs",
        "AMI / OS",
        "OS Pretty",
        "Public IP",
        "Private IP",
        "Region",
        "AZ",
        "Status",
        "Desired Status",
        "Account ID",
        "Credential ID",
        "Security Groups",
        "Subnet ID",
        "VPC ID",
        "Launched At",
        "Created At",
        "Last Sync",
    ]
    ws.append(headers)
    for inst in instances:
        os_pretty = inst.os_pretty_name or _derive_os_pretty_name(inst)
        os_label = os_pretty or inst.os_name
        if not os_label and isinstance(inst.last_cloud_state, dict):
            os_label = inst.last_cloud_state.get("os_family") or inst.last_cloud_state.get("ami_name")
        ami_os = " / ".join([x for x in [os_label, inst.ami_id] if x])
        sg_val = ""
        if isinstance(inst.security_groups, list):
            sg_val = ", ".join([str(sg) for sg in inst.security_groups])
        elif isinstance(inst.security_groups, dict):
            sg_val = ", ".join([str(v) for v in inst.security_groups.values()])
        spec_parts = []
        if inst.instance_vcpus is not None:
            spec_parts.append(f"{inst.instance_vcpus} vCPU")
        if inst.instance_memory_gib is not None:
            spec_parts.append(f"{_fmt_gib(inst.instance_memory_gib)} GiB")
        if inst.instance_network_perf:
            spec_parts.append(inst.instance_network_perf)
        type_specs = inst.instance_type
        if spec_parts:
            type_specs = f"{inst.instance_type} ({', '.join(spec_parts)})" if inst.instance_type else ", ".join(spec_parts)
        ws.append(
            [
                inst.name_tag or "",
                getattr(inst, "credential_label", None)
                or getattr(inst, "credential_name", None)
                or (inst.credential.name if getattr(inst, "credential", None) else ""),
                getattr(inst, "owner_name", None) or getattr(inst, "customer_name", None) or "",
                inst.instance_id,
                inst.instance_type,
                type_specs,
                ami_os,
                os_pretty or "",
                inst.public_ip or "",
                inst.private_ip or "",
                inst.region,
                inst.az or "",
                inst.status.value if isinstance(inst.status, InstanceStatus) else str(inst.status),
                inst.desired_status.value if isinstance(inst.desired_status, InstanceDesiredStatus) else inst.desired_status or "",
                inst.account_id,
                inst.credential_id or "",
                sg_val,
                inst.subnet_id or "",
                inst.vpc_id or "",
                _fmt(inst.launched_at),
                _fmt(inst.created_at),
                _fmt(inst.last_sync),
            ]
        )
    buffer = BytesIO()
    wb.save(buffer)
    buffer.seek(0)
    return buffer.getvalue()


async def export_instances(
    session: AsyncSession,
    filters: dict,
    actor: User,
) -> bytes:
    filters = dict(filters or {})
    filters.pop("offset", None)
    filters.pop("limit", None)
    query = build_instances_query(filters, actor).order_by(Instance.updated_at.desc())
    rows = (await session.scalars(query)).all()
    if len(rows) > MAX_EXPORT_ROWS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
        )
    await enrich_instances(rows)
    return _build_export_sheet(rows)


async def export_instances_by_ids(
    session: AsyncSession,
    instance_ids: List[int],
    actor: User,
) -> bytes:
    if len(instance_ids) > MAX_EXPORT_IDS:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
    ids = list(dict.fromkeys(instance_ids))
    if not ids:
        return _build_export_sheet([])
    query = build_instances_query({}, actor).where(Instance.id.in_(ids)).order_by(Instance.updated_at.desc())
    rows = (await session.scalars(query)).all()
    if len(rows) > MAX_EXPORT_ROWS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
        )
    await enrich_instances(rows)
    return _build_export_sheet(rows)