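"""Instance lifecycle services: provisioning, batch actions, cloud sync, and export.

This module backs the instances API: it launches EC2 instances, enqueues
start/stop/reboot/terminate jobs, reconciles local rows against the cloud
(per-scope debounced sync plus a global scheduler), and builds XLSX exports.
"""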
from __future__ import annotations
import asyncio
import logging
import re
import secrets
import string
import time
from datetime import datetime, timezone
from io import BytesIO
from typing import Any, Dict, List, Optional
from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import and_, func, or_, select
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from botocore.exceptions import ClientError
from openpyxl import Workbook
from backend.db.session import AsyncSessionLocal
from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType
from backend.modules.aws_accounts.models import AWSCredential, CredentialType, CustomerCredential
from backend.modules.instances import aws_ops
from backend.modules.instances.models import Instance, InstanceDesiredStatus, InstanceStatus
from backend.modules.instances.constants import DEFAULT_LOGIN_PASSWORD, DEFAULT_LOGIN_USERNAME
from backend.modules.instances.bootstrap_templates import build_user_data
from backend.modules.instances.events import (
broadcast_instance_removed,
broadcast_instance_update,
build_removed_payload,
)
from backend.modules.jobs.models import (
Job,
JobItem,
JobItemAction,
JobItemResourceType,
JobItemStatus,
JobStatus,
JobType,
)
from backend.modules.customers.models import Customer
from backend.modules.users.models import Role, RoleName, User
from backend.core.config import settings
from backend.modules.instances.schemas import (
BatchInstancesActionIn,
BatchInstancesActionOut,
BatchInstancesByIpIn,
)
STATE_MAP: Dict[str, InstanceStatus] = {
"pending": InstanceStatus.PENDING,
"running": InstanceStatus.RUNNING,
"stopping": InstanceStatus.STOPPING,
"stopped": InstanceStatus.STOPPED,
"shutting-down": InstanceStatus.SHUTTING_DOWN,
"terminated": InstanceStatus.TERMINATED,
}
JOB_TYPE_BY_ACTION = {
JobItemAction.START: JobType.START_INSTANCES,
JobItemAction.STOP: JobType.STOP_INSTANCES,
JobItemAction.REBOOT: JobType.REBOOT_INSTANCES,
JobItemAction.TERMINATE: JobType.TERMINATE_INSTANCES,
}
DESIRED_BY_ACTION = {
JobItemAction.START: InstanceDesiredStatus.RUNNING,
JobItemAction.STOP: InstanceDesiredStatus.STOPPED,
JobItemAction.REBOOT: InstanceDesiredStatus.RUNNING,
JobItemAction.TERMINATE: InstanceDesiredStatus.TERMINATED,
}
# cache image name lookups to avoid repeated describe_images calls within a process
_image_name_cache: dict[tuple[int, str, str], str | None] = {}
_instance_type_specs_cache: dict[tuple[str, int, str], dict[str, tuple[int | None, float | None, str | None]]] = {}
_instance_type_cache_ts: dict[tuple[str, int, str], float] = {}
INSTANCE_TYPE_CACHE_TTL = 60 * 60 # seconds
MAX_BATCH_COUNT = 200
MAX_EXPORT_ROWS = 5000
MAX_EXPORT_IDS = 2000
MIN_SCOPE_SYNC_INTERVAL = max(1, int(settings.scope_sync_min_interval_seconds))
GLOBAL_SYNC_INTERVAL_MINUTES = max(1, int(settings.global_sync_interval_minutes))
GLOBAL_SYNC_MAX_CONCURRENCY = max(1, int(settings.global_sync_max_concurrency))
logger = logging.getLogger(__name__)
_scope_sync_last: dict[tuple[str, int, str], float] = {}
_scope_sync_locks: dict[tuple[str, int, str], asyncio.Lock] = {}
_global_sync_task: Optional[asyncio.Task] = None
_scope_sync_semaphore = asyncio.Semaphore(GLOBAL_SYNC_MAX_CONCURRENCY)
async def _delete_local_instance(session: AsyncSession, instance: Instance) -> dict:
payload = build_removed_payload(instance)
await session.delete(instance)
return payload
def _scope_key(provider: str, credential_id: Optional[int], region: Optional[str]) -> tuple[str, int, str]:
return (provider, credential_id or 0, region or "*")
def _get_scope_lock(provider: str, credential_id: Optional[int], region: Optional[str]) -> asyncio.Lock:
key = _scope_key(provider, credential_id, region)
lock = _scope_sync_locks.get(key)
if lock is None:
lock = asyncio.Lock()
_scope_sync_locks[key] = lock
return lock
async def _pick_sync_actor(session: AsyncSession, actor_id: Optional[int]) -> Optional[User]:
if actor_id:
user = await session.get(User, actor_id)
if user:
return user
admin = await session.scalar(
select(User).join(Role).where(Role.name == RoleName.ADMIN.value).limit(1)
)
if admin:
return admin
return await session.scalar(select(User).limit(1))
async def enqueue_scope_sync(provider: str, credential_id: int, region: str, *, reason: str, actor_id: Optional[int] = None, customer_id: Optional[int] = None) -> bool:
"""
Enqueue a scope sync job with simple rate limiting to avoid hammering cloud APIs.
Returns True if enqueued, False if skipped due to debounce.
"""
key = _scope_key(provider, credential_id, region)
now = time.monotonic()
    last = _scope_sync_last.get(key)
    if last is not None and now - last < MIN_SCOPE_SYNC_INTERVAL:
        return False
_scope_sync_last[key] = now
async def _runner():
async with _scope_sync_semaphore:
try:
async with AsyncSessionLocal() as session:
actor = await _pick_sync_actor(session, actor_id)
if not actor:
logger.warning("scope sync skipped due to missing actor %s/%s/%s", provider, credential_id, region)
return
await sync_instances(
session,
credential_id,
region,
actor,
customer_id_override=customer_id,
provider=provider,
)
except Exception as exc: # pragma: no cover - best-effort sync
logger.warning("scope sync failed for %s/%s/%s: %s", provider, credential_id, region, exc)
asyncio.create_task(_runner())
logger.info("scope sync enqueued for %s/%s/%s reason=%s", provider, credential_id, region, reason)
return True
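# Debounce sketch (illustrative): calls for the same (provider, credential, region)
# scope within MIN_SCOPE_SYNC_INTERVAL seconds are dropped:
#
#   queued = await enqueue_scope_sync("aws", 1, "us-east-1", reason="manual")  # True
#   again = await enqueue_scope_sync("aws", 1, "us-east-1", reason="manual")   # False
#
# The sync itself runs in a fire-and-forget asyncio task (no reference is kept),
# so completion is observed via logs and instance broadcasts rather than awaits.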
async def _discover_sync_scopes(session: AsyncSession) -> list[tuple[int, str, Optional[int]]]:
scopes: set[tuple[int, str, Optional[int]]] = set()
rows = await session.execute(
select(Instance.credential_id, Instance.region, Instance.customer_id)
.where(Instance.credential_id.isnot(None))
.distinct()
)
for cred_id, region, customer_id in rows.all():
if cred_id and region:
scopes.add((cred_id, region, customer_id))
credentials = (await session.scalars(select(AWSCredential).where(AWSCredential.is_active == 1))).all()
for cred in credentials:
mappings = (
await session.scalars(
select(CustomerCredential.customer_id)
.where(CustomerCredential.credential_id == cred.id)
.where(CustomerCredential.is_allowed == 1)
)
).all()
uniq = sorted({cid for cid in mappings if cid})
if len(uniq) == 1:
scopes.add((cred.id, cred.default_region, uniq[0]))
return list(scopes)
async def _global_sync_loop() -> None:
await asyncio.sleep(5)
while True:
try:
async with AsyncSessionLocal() as session:
admin = await _pick_sync_actor(session, None)
admin_id = admin.id if admin else None
scopes = await _discover_sync_scopes(session)
for cred_id, region, customer_id in scopes:
await enqueue_scope_sync(
"aws",
cred_id,
region,
reason="global_cron",
actor_id=admin_id,
customer_id=customer_id,
)
except Exception as exc: # pragma: no cover - background loop resilience
logger.warning("global sync loop error: %s", exc)
await asyncio.sleep(GLOBAL_SYNC_INTERVAL_MINUTES * 60)
def start_global_sync_scheduler() -> None:
global _global_sync_task
if GLOBAL_SYNC_INTERVAL_MINUTES <= 0:
return
if _global_sync_task and not _global_sync_task.done():
return
_global_sync_task = asyncio.create_task(_global_sync_loop())
def _pretty_os_label(raw: str | None) -> str | None:
if not raw:
return None
text = str(raw).strip()
lower = text.lower()
def find_version(patterns: list[str]) -> str | None:
for pattern in patterns:
m = re.search(pattern, lower)
if m:
return m.group(1)
return None
if "ubuntu" in lower:
version = find_version([r"ubuntu[\w\- ]*(\d{2}\.\d{2})"])
codename_map = {"jammy": "22.04", "focal": "20.04", "bionic": "18.04"}
for code, ver in codename_map.items():
if code in lower:
version = version or ver
label = "Ubuntu"
if version:
label = f"Ubuntu {version}"
if version in {"22.04", "20.04", "18.04"}:
label = f"{label} LTS"
return label
if "debian" in lower:
version = find_version([r"debian[\w\- ]*(\d{1,2})", r"debian[\w\- ]*(\d{1,2}\.\d)"])
codename_map = {"bookworm": "12", "bullseye": "11"}
for code, ver in codename_map.items():
if code in lower:
version = version or ver
return f"Debian {version}" if version else "Debian"
if "centos" in lower:
version = find_version([r"centos[\w\- ]*(\d{1,2})"])
return f"CentOS {version}" if version else "CentOS"
if "amazon linux" in lower or lower.startswith("amzn") or "amzn" in lower:
version = find_version([r"amazon linux ?(\d+)", r"amzn(\d)"])
return f"Amazon Linux {version}" if version else "Amazon Linux"
if "windows" in lower:
version = find_version([r"windows[ _\-]*server[ _\-]*(\d{4})", r"windows[ _\-]*(\d{4})"])
if not version:
for candidate in ["2022", "2019", "2016", "2012"]:
if candidate in lower:
version = candidate
break
base = "Windows Server" if "server" in lower or version else "Windows"
return f"{base} {version}".strip() if version else base
return text
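# Illustrative outputs of _pretty_os_label, derived from the rules above:
#   "ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-20240301" -> "Ubuntu 22.04 LTS"
#   "debian-12-amd64-20240201"                                       -> "Debian 12"
#   "amzn2-ami-hvm-2.0-x86_64-gp2"                                   -> "Amazon Linux 2"
#   "Windows_Server-2022-English-Full-Base"                          -> "Windows Server 2022"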
def _chunked(seq: List[str], size: int = 100) -> List[List[str]]:
return [seq[i : i + size] for i in range(0, len(seq), size)]
async def get_instance_type_specs(
provider: str,
credential: AWSCredential | None,
region: str,
instance_types: List[str],
) -> dict[str, tuple[int | None, float | None, str | None]]:
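    """Resolve (vCPUs, memory GiB, network performance) per instance type.

    Results are cached per (provider, credential, region) for INSTANCE_TYPE_CACHE_TTL
    seconds; lookups are batched 100 types per describe_instance_types call, and a
    failed chunk is skipped (best-effort metadata).
    """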
if provider != "aws" or not credential or not instance_types:
return {}
region_use = region or credential.default_region
if not region_use:
return {}
cache_key = (provider, credential.id, region_use)
now = time.time()
cache_ts = _instance_type_cache_ts.get(cache_key, 0)
cache_items = _instance_type_specs_cache.get(cache_key, {})
if now - cache_ts > INSTANCE_TYPE_CACHE_TTL:
cache_items = {}
needed = [it for it in set(instance_types) if it and it not in cache_items]
if needed:
fetched = dict(cache_items)
for chunk in _chunked(sorted(needed), 100):
try:
resp = await asyncio.to_thread(
aws_ops.describe_instance_types,
credential,
region_use,
[{"Name": "instance-type", "Values": chunk}],
)
except Exception as exc: # pragma: no cover - best effort metadata
logger.debug("describe_instance_types failed for %s/%s: %s", credential.id, region_use, exc)
continue
for item in resp:
itype = item.get("InstanceType")
if not itype:
continue
vcpu = (item.get("VCpuInfo") or {}).get("DefaultVCpus")
mem_mib = (item.get("MemoryInfo") or {}).get("SizeInMiB")
mem_gib = round(mem_mib / 1024, 2) if mem_mib is not None else None
net_perf = (item.get("NetworkInfo") or {}).get("NetworkPerformance")
fetched[itype] = (vcpu, mem_gib, net_perf)
cache_items = fetched
_instance_type_specs_cache[cache_key] = cache_items
_instance_type_cache_ts[cache_key] = now
return {it: cache_items[it] for it in instance_types if it in cache_items}
async def _lookup_image_name(cred: AWSCredential, region: str, ami_id: str) -> str | None:
cache_key = (cred.id, region, ami_id)
if cache_key in _image_name_cache:
return _image_name_cache[cache_key]
name: str | None = None
try:
resp = await asyncio.to_thread(aws_ops.describe_images, cred, region, [ami_id])
images = resp.get("Images", [])
if images:
img = images[0]
name = img.get("Name") or img.get("Description")
except Exception:
name = None
_image_name_cache[cache_key] = name
return name
async def _resolve_os_name(
cred: AWSCredential, region: str, inst: dict, existing_os: str | None = None
) -> tuple[str | None, str | None]:
"""
Try to get a user-friendly OS name. Prefer existing value; otherwise use AWS metadata and AMI name.
Returns a tuple of (pretty_name, image_name).
"""
os_hint = existing_os or inst.get("PlatformDetails") or inst.get("Platform")
ami_id = inst.get("ImageId")
os_name = os_hint
image_name: str | None = None
lower_name = os_name.lower() if isinstance(os_name, str) else ""
needs_lookup = not os_name or lower_name.startswith("linux/unix") or not re.search(r"\d", lower_name)
if ami_id and needs_lookup:
image_name = await _lookup_image_name(cred, region, ami_id)
if image_name:
os_name = image_name
if not os_name:
os_name = ami_id
pretty = _pretty_os_label(os_name)
if not pretty and existing_os:
pretty = _pretty_os_label(existing_os)
fallback = os_name or existing_os
return pretty or fallback, image_name
def _derive_os_pretty_name(inst: Instance) -> str | None:
candidates: list[str | None] = []
candidates.append(inst.os_name)
if isinstance(inst.last_cloud_state, dict):
candidates.append(inst.last_cloud_state.get("os_family"))
candidates.append(inst.last_cloud_state.get("ami_name"))
for cand in candidates:
pretty = _pretty_os_label(cand)
if pretty:
return pretty
for cand in candidates:
if cand:
return str(cand)
return None
async def enrich_instances(instances: List[Instance]) -> None:
if not instances:
return
# OS pretty names first so they are present even if specs lookup fails
for inst in instances:
inst.os_pretty_name = _derive_os_pretty_name(inst)
scope_map: dict[tuple[int, str], list[Instance]] = {}
cred_map: dict[int, AWSCredential] = {}
for inst in instances:
if inst.credential and inst.credential_id:
cred_map[inst.credential_id] = inst.credential
if inst.credential_id and inst.region and inst.instance_type:
scope_map.setdefault((inst.credential_id, inst.region), []).append(inst)
for (cred_id, region), scoped_instances in scope_map.items():
cred = cred_map.get(cred_id)
if not cred:
continue
types = [i.instance_type for i in scoped_instances if i.instance_type]
specs_map = await get_instance_type_specs("aws", cred, region, types)
if not specs_map:
continue
for inst in scoped_instances:
spec = specs_map.get(inst.instance_type)
if not spec:
continue
vcpu, mem_gib, net_perf = spec
inst.instance_vcpus = vcpu
inst.instance_memory_gib = mem_gib
inst.instance_network_perf = net_perf
def _map_state(state: Optional[str]) -> InstanceStatus:
if not state:
return InstanceStatus.UNKNOWN
return STATE_MAP.get(state.lower(), InstanceStatus.UNKNOWN)
def _extract_name_tag(tags: Optional[List[Dict[str, str]]]) -> Optional[str]:
if not tags:
return None
for tag in tags:
if tag.get("Key") == "Name":
return tag.get("Value")
return None
async def ensure_credential_access(session: AsyncSession, credential_id: int, actor: User) -> AWSCredential:
cred = await session.get(AWSCredential, credential_id)
if not cred or not cred.is_active:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found")
if actor.role.name == RoleName.ADMIN.value:
return cred
mapping = await session.scalar(
select(CustomerCredential).where(
and_(
CustomerCredential.customer_id == actor.customer_id,
CustomerCredential.credential_id == credential_id,
CustomerCredential.is_allowed == 1,
)
)
)
if not mapping:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Credential not allowed")
return cred
def build_instances_query(filters: dict, actor: User):
filters = dict(filters or {})
query = select(Instance).options(
selectinload(Instance.credential),
selectinload(Instance.customer),
)
if actor.role.name != RoleName.ADMIN.value:
query = query.where(Instance.customer_id == actor.customer_id)
elif filters.get("customer_id"):
query = query.where(Instance.customer_id == filters["customer_id"])
if filters.get("credential_id"):
query = query.where(Instance.credential_id == filters["credential_id"])
if filters.get("account_id"):
query = query.where(Instance.account_id == filters["account_id"])
if filters.get("region"):
query = query.where(Instance.region == filters["region"])
if filters.get("status"):
query = query.where(Instance.status == filters["status"])
if filters.get("instance_ids"):
ids = list(dict.fromkeys(filters["instance_ids"]))
if ids:
query = query.where(Instance.id.in_(ids))
if filters.get("keyword"):
pattern = f"%{filters['keyword']}%"
query = query.where(
or_(
Instance.name_tag.ilike(pattern),
Instance.instance_id.ilike(pattern),
Instance.public_ip.ilike(pattern),
Instance.private_ip.ilike(pattern),
)
)
return query
async def list_instances(
session: AsyncSession,
filters: dict,
actor: User,
) -> tuple[List[Instance], int]:
query = build_instances_query(filters, actor)
total = await session.scalar(select(func.count()).select_from(query.subquery()))
rows = (
await session.scalars(
query.order_by(Instance.updated_at.desc()).offset(filters.get("offset", 0)).limit(filters.get("limit", 20))
)
).all()
await enrich_instances(rows)
return rows, total or 0
def _generate_random_password(length: int = 16) -> str:
alphabet = string.ascii_letters + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))
def _resolve_default_ami(cred: AWSCredential, region: str, os_family: str) -> str:
"""Fetch latest common AMI for a given OS family/version. Best-effort."""
session, cfg = aws_ops.build_session(cred, region)
client = session.client("ec2", region_name=region or cred.default_region, config=cfg)
family = (os_family or "ubuntu-22.04").lower()
owners: List[str] = []
name_values: List[str] = []
if family.startswith("ubuntu-22.04"):
owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"])
elif family.startswith("ubuntu-20.04"):
owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-focal-20.04-amd64-server-*"])
elif family.startswith("debian-12"):
owners, name_values = (["136693071363"], ["debian-12-amd64-*"])
elif family.startswith("debian-11"):
owners, name_values = (["136693071363"], ["debian-11-amd64-*"])
elif family.startswith("centos-9"):
owners, name_values = (["125523088429"], ["CentOS Stream 9*"])
elif family.startswith("amazonlinux-2"):
owners, name_values = (["137112412989"], ["amzn2-ami-hvm-*-x86_64-gp2"])
elif family.startswith("windows-2019"):
owners, name_values = (["801119661308"], ["Windows_Server-2019-English-Full-Base-*"])
elif family.startswith("windows-2022"):
owners, name_values = (["801119661308"], ["Windows_Server-2022-English-Full-Base-*"])
else:
owners, name_values = (["099720109477"], ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"])
try:
resp = client.describe_images(
Owners=owners,
Filters=[{"Name": "name", "Values": name_values}, {"Name": "architecture", "Values": ["x86_64"]}],
)
images = sorted(resp.get("Images", []), key=lambda x: x.get("CreationDate", ""), reverse=True)
if images:
return images[0].get("ImageId")
except ClientError as exc:
code = (exc.response.get("Error") or {}).get("Code")
msg = (exc.response.get("Error") or {}).get("Message")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to fetch AMI: {code} {msg}") from exc
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to fetch AMI: {exc}") from exc
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No usable AMI found for the selected OS; please specify an AMI ID manually")
def _get_root_device_name(cred: AWSCredential, region: str, ami_id: str) -> str:
session, cfg = aws_ops.build_session(cred, region)
client = session.client("ec2", region_name=region or cred.default_region, config=cfg)
try:
resp = client.describe_images(ImageIds=[ami_id])
images = resp.get("Images", [])
if images and images[0].get("RootDeviceName"):
return images[0]["RootDeviceName"]
except Exception:
return "/dev/xvda"
return "/dev/xvda"
async def create_instance(
session: AsyncSession,
payload: dict,
actor: User,
) -> dict:
# enforce instance type for customers
allowed_customer_types = {"t3.micro", "t3.small", "t3.medium"}
is_admin = actor.role.name == RoleName.ADMIN.value
if not payload.get("instance_type"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="instance_type required")
if not is_admin and payload.get("instance_type") not in allowed_customer_types:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="客户仅可使用 t3.micro / t3.small / t3.medium 实例类型",
)
count = payload.get("count") or 1
if count < 1:
count = 1
customer_id = payload.get("customer_id") or actor.customer_id
if not customer_id:
# derive from credential mapping if there is exactly one allowed customer
mappings = (
await session.scalars(
select(CustomerCredential.customer_id).where(
CustomerCredential.credential_id == payload["credential_id"],
CustomerCredential.is_allowed == 1,
)
)
).all()
unique_customer_ids = sorted({cid for cid in mappings if cid})
if len(unique_customer_ids) == 1:
customer_id = unique_customer_ids[0]
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required")
cred = await ensure_credential_access(session, payload["credential_id"], actor)
if cred.account_id != payload["account_id"]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="account_id mismatch")
region = payload["region"] or cred.default_region
# quota check: customer-level + region active count
customer = await session.get(Customer, customer_id)
if not customer:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Customer not found")
quota_limit = customer.quota_instances if customer.quota_instances is not None else 999999
active_count = await session.scalar(
select(func.count(Instance.id)).where(
Instance.customer_id == customer_id,
Instance.status != InstanceStatus.TERMINATED,
Instance.desired_status != InstanceDesiredStatus.TERMINATED,
Instance.region == region,
)
)
available = max(0, quota_limit - (active_count or 0))
if count > available:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"超出配额,当前区域可用数量 {available},请调整数量或联系管理员扩容",
)
session_boto, cfg = aws_ops.build_session(cred, region)
ec2_client = session_boto.client("ec2", region_name=region or cred.default_region, config=cfg)
login_username = payload.get("login_username") or DEFAULT_LOGIN_USERNAME
login_password = payload.get("login_password") or DEFAULT_LOGIN_PASSWORD
# allow client to request random password by passing "RANDOM" sentinel
if isinstance(login_password, str) and login_password.upper() == "RANDOM":
login_password = _generate_random_password()
user_data_mode = (payload.get("user_data_mode") or "auto_root").lower()
os_family = payload.get("os_family") or "ubuntu"
user_data_content: Optional[str] = None
if user_data_mode == "auto_root":
user_data_content = build_user_data(os_family, login_username, login_password)
elif user_data_mode == "custom":
user_data_content = payload.get("custom_user_data")
else:
user_data_content = None
sg_ids = payload.get("security_group_ids") or []
if not sg_ids:
if not settings.auto_open_sg_enabled:
            # Conservative policy: require the frontend to pick security groups explicitly to avoid unintended exposure
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="security_group_ids required when auto open SG is disabled",
)
if not payload.get("subnet_id"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="subnet_id required to derive VPC")
subnet_resp = await asyncio.to_thread(ec2_client.describe_subnets, SubnetIds=[payload["subnet_id"]])
subnets = subnet_resp.get("Subnets", [])
if not subnets:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid subnet_id")
vpc_id = subnets[0].get("VpcId")
if not vpc_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Subnet missing VPC")
# Risk note: open-all SG is only for testing environments.
sg_id = await asyncio.to_thread(aws_ops.ensure_open_all_sg_for_vpc, ec2_client, vpc_id)
sg_ids = [sg_id]
ami_id = payload.get("ami_id")
if not ami_id:
ami_id = _resolve_default_ami(cred, region, os_family)
volume_type = payload.get("volume_type") or "gp3"
volume_size = payload.get("volume_size") or 20
if volume_size <= 0:
volume_size = 20
if not is_admin:
volume_type = "gp3"
cpu_options: Optional[Dict[str, Any]] = None
if (payload.get("instance_type") or "").lower().startswith("t"):
cpu_options = {"CpuCredits": "standard"}
root_device = _get_root_device_name(cred, region, ami_id)
block_device_mappings = [
{
"DeviceName": root_device,
"Ebs": {
"VolumeSize": volume_size,
"VolumeType": volume_type,
"DeleteOnTermination": True,
},
}
]
try:
resp = await asyncio.to_thread(
aws_ops.run_instances,
cred,
region,
ami_id,
payload["instance_type"],
payload.get("key_name"),
sg_ids,
payload.get("subnet_id"),
block_device_mappings,
cpu_options,
count,
count,
payload.get("name_tag"),
user_data_content,
)
except ClientError as exc: # pragma: no cover
err = exc.response.get("Error", {})
code = err.get("Code")
msg = err.get("Message")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"AWS error: {code} {msg}") from exc
except Exception as exc: # pragma: no cover
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
instances = resp.get("Instances") or []
if not instances:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="AWS did not return instances")
instance_ids = [i.get("InstanceId") for i in instances if i.get("InstanceId")]
status_checks_map: Dict[str, dict] = {}
# refresh details to fetch public IP if not immediately available
if instance_ids:
try:
describe_resp = await asyncio.to_thread(
aws_ops.describe_instances, cred, region, instance_ids=instance_ids
)
reservations = describe_resp.get("Reservations", [])
tmp = []
for res in reservations:
tmp.extend(res.get("Instances", []))
if tmp:
instances = tmp
status_resp = await asyncio.to_thread(aws_ops.describe_instance_status, cred, region, instance_ids)
for st in status_resp.get("InstanceStatuses", []):
iid = st.get("InstanceId")
status_checks_map[iid] = {
"instance": st.get("InstanceStatus", {}).get("Status"),
"system": st.get("SystemStatus", {}).get("Status"),
}
except Exception:
pass
created_db_instances: List[Instance] = []
now = datetime.now(timezone.utc)
for data in instances:
name_tag = payload.get("name_tag") or _extract_name_tag(data.get("Tags"))
security_groups = [sg.get("GroupId") for sg in data.get("SecurityGroups", []) if sg.get("GroupId")]
db_inst = Instance(
customer_id=customer_id,
credential_id=cred.id,
account_id=cred.account_id,
region=region,
az=(data.get("Placement") or {}).get("AvailabilityZone"),
instance_id=data.get("InstanceId"),
name_tag=name_tag,
instance_type=payload["instance_type"],
ami_id=ami_id,
os_name=os_family,
key_name=payload.get("key_name"),
public_ip=data.get("PublicIpAddress"),
private_ip=data.get("PrivateIpAddress"),
status=_map_state((data.get("State") or {}).get("Name")),
desired_status=InstanceDesiredStatus.RUNNING,
security_groups=security_groups,
subnet_id=data.get("SubnetId"),
vpc_id=data.get("VpcId"),
launched_at=data.get("LaunchTime") or now,
last_sync=now,
last_cloud_state={
"state": data.get("State"),
"tags": data.get("Tags"),
"os_family": os_family,
"status_checks": status_checks_map.get(data.get("InstanceId")),
},
)
session.add(db_inst)
created_db_instances.append(db_inst)
await session.commit()
for db_inst in created_db_instances:
await session.refresh(db_inst)
await session.refresh(db_inst, attribute_names=["credential", "customer"])
db_inst = created_db_instances[0]
session.add(
AuditLog(
user_id=actor.id,
customer_id=customer_id,
action=AuditAction.INSTANCE_CREATE,
resource_type=AuditResourceType.INSTANCE,
resource_id=db_inst.id,
description=f"Create instance {db_inst.instance_id} x{len(created_db_instances)}",
payload=payload,
)
)
await session.commit()
await enrich_instances(created_db_instances)
for inst in created_db_instances:
await broadcast_instance_update(inst)
    # Trigger a scope sync after creation so cloud-side state converges with local records
try:
await enqueue_scope_sync(
"aws",
cred.id,
region,
reason="create_instance",
actor_id=actor.id,
customer_id=customer_id,
)
except Exception as exc: # pragma: no cover
logger.warning("enqueue scope sync after create failed: %s", exc)
return {"instance": db_inst, "login_username": login_username, "login_password": login_password}
async def _process_action(job_id: int, job_item_id: int, action: JobItemAction) -> None:
async with AsyncSessionLocal() as session:
job = await session.get(Job, job_id)
job_item = await session.scalar(
select(JobItem)
.where(JobItem.id == job_item_id)
.options(
selectinload(JobItem.instance).selectinload(Instance.credential),
selectinload(JobItem.instance).selectinload(Instance.customer),
)
)
if not job or not job_item or not job_item.instance or not job_item.instance.credential:
return
job.status = JobStatus.RUNNING
job.started_at = datetime.now(timezone.utc)
job_item.status = JobItemStatus.RUNNING
await session.commit()
cred = job_item.instance.credential
region = job_item.region or job_item.instance.region
instance_id = job_item.instance.instance_id
removed_payload = None
updated_instance: Optional[Instance] = None
try:
if action == JobItemAction.START:
resp = await asyncio.to_thread(aws_ops.start_instances, cred, region, [instance_id])
job_item.instance.status = InstanceStatus.RUNNING
elif action == JobItemAction.STOP:
resp = await asyncio.to_thread(aws_ops.stop_instances, cred, region, [instance_id])
job_item.instance.status = InstanceStatus.STOPPED
elif action == JobItemAction.REBOOT:
resp = await asyncio.to_thread(aws_ops.reboot_instances, cred, region, [instance_id])
job_item.instance.status = InstanceStatus.RUNNING
elif action == JobItemAction.TERMINATE:
resp = await asyncio.to_thread(aws_ops.terminate_instances, cred, region, [instance_id])
job_item.instance.status = InstanceStatus.TERMINATED
job_item.instance.terminated_at = datetime.now(timezone.utc)
removed_payload = await _delete_local_instance(session, job_item.instance)
job_item.resource_id = None
else: # pragma: no cover
resp = {}
if action != JobItemAction.TERMINATE:
updated_instance = job_item.instance
job_item.extra = resp
job_item.status = JobItemStatus.SUCCESS
job.success_count = 1
job.total_count = 1
job.progress = 100
job.status = JobStatus.SUCCESS
job.finished_at = datetime.now(timezone.utc)
await session.commit()
if updated_instance:
await session.refresh(updated_instance)
await broadcast_instance_update(updated_instance)
if removed_payload:
await broadcast_instance_removed(removed_payload, removed_payload.get("customer_id"))
try:
                # prefer the payload captured before deletion: after TERMINATE the
                # instance row is gone and must not be read via the ORM relationship
                scope_customer = (
                    removed_payload.get("customer_id")
                    if removed_payload
                    else (job_item.instance.customer_id if job_item.instance else None)
                )
await enqueue_scope_sync(
"aws",
cred.id,
region,
reason=f"job_{action.value}",
actor_id=job.created_by_user_id,
customer_id=scope_customer,
)
except Exception as exc: # pragma: no cover - do not fail job if sync enqueue fails
logger.warning("enqueue scope sync after job failed: %s", exc)
except Exception as exc: # pragma: no cover
job_item.status = JobItemStatus.FAILED
job_item.error_message = str(exc)
job.status = JobStatus.FAILED
job.error_message = str(exc)
job.fail_count = 1
job.progress = 100
job.finished_at = datetime.now(timezone.utc)
await session.commit()
async def enqueue_action(session: AsyncSession, instance: Instance, action: JobItemAction, actor: User) -> Job:
job = Job(
job_uuid=uuid4().hex,
job_type=JOB_TYPE_BY_ACTION[action],
status=JobStatus.PENDING,
progress=0,
total_count=1,
created_by_user_id=actor.id,
created_for_customer=instance.customer_id,
)
session.add(job)
await session.flush()
job_item = JobItem(
job_id=job.id,
resource_type=JobItemResourceType.INSTANCE,
resource_id=instance.id,
account_id=instance.account_id,
region=instance.region,
instance_id=instance.instance_id,
action=action,
status=JobItemStatus.PENDING,
)
instance.desired_status = DESIRED_BY_ACTION[action]
session.add(job_item)
await session.commit()
await session.refresh(job)
await session.refresh(instance)
await broadcast_instance_update(instance)
asyncio.create_task(_process_action(job.id, job_item.id, action))
return job
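# enqueue_action returns immediately with a PENDING Job; _process_action runs in a
# detached asyncio task, flips Job/JobItem to SUCCESS or FAILED, broadcasts the
# instance update (or removal for TERMINATE), then debounce-enqueues a scope sync.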
async def upsert_instance_from_cloud(
session: AsyncSession,
customer_id: int,
cred: AWSCredential,
region: str,
inst: dict,
now: datetime,
status_checks: Optional[dict] = None,
) -> Optional[Instance]:
instance_id = inst.get("InstanceId")
state_name = (inst.get("State") or {}).get("Name")
name_tag = _extract_name_tag(inst.get("Tags"))
security_groups = [sg.get("GroupId") for sg in inst.get("SecurityGroups", []) if sg.get("GroupId")]
existing_inst = await session.scalar(
select(Instance).where(
Instance.account_id == cred.account_id,
Instance.region == region,
Instance.instance_id == instance_id,
)
)
os_family, ami_name = await _resolve_os_name(
cred, region, inst, existing_inst.os_name if existing_inst else None
)
record = dict(
customer_id=customer_id,
credential_id=cred.id,
account_id=cred.account_id,
region=region,
az=(inst.get("Placement") or {}).get("AvailabilityZone"),
instance_id=instance_id,
name_tag=name_tag,
instance_type=inst.get("InstanceType"),
ami_id=inst.get("ImageId"),
os_name=os_family,
key_name=inst.get("KeyName"),
public_ip=inst.get("PublicIpAddress"),
private_ip=inst.get("PrivateIpAddress"),
status=_map_state(state_name),
desired_status=None,
security_groups=security_groups,
subnet_id=inst.get("SubnetId"),
vpc_id=inst.get("VpcId"),
launched_at=inst.get("LaunchTime"),
terminated_at=now if state_name == "terminated" else None,
last_sync=now,
last_cloud_state={
"state": inst.get("State"),
"tags": inst.get("Tags"),
"os_family": os_family,
"ami_name": ami_name,
"status_checks": status_checks,
},
)
stmt = insert(Instance).values(**record)
update_cols = {k: stmt.inserted[k] for k in record.keys() if k != "id"}
await session.execute(stmt.on_duplicate_key_update(**update_cols))
return await session.scalar(
select(Instance)
.options(selectinload(Instance.credential), selectinload(Instance.customer))
.where(
Instance.account_id == cred.account_id,
Instance.region == region,
Instance.instance_id == instance_id,
)
)
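# Note: the MySQL upsert above relies on a unique key covering instance identity
# (presumably account_id/region/instance_id, matching the follow-up SELECT), so
# repeated syncs update the same row in place.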
async def sync_instances(
session: AsyncSession,
credential_id: Optional[int],
region: Optional[str],
actor: User,
customer_id_override: Optional[int] = None,
provider: str = "aws",
) -> Job:
lock = _get_scope_lock(provider, credential_id, region)
async with lock:
base_customer_id = customer_id_override if actor.role.name == RoleName.ADMIN.value else actor.customer_id
credentials_query = select(AWSCredential).where(AWSCredential.is_active == 1)
if actor.role.name != RoleName.ADMIN.value:
credentials_query = (
credentials_query.join(CustomerCredential, CustomerCredential.credential_id == AWSCredential.id)
.where(CustomerCredential.customer_id == actor.customer_id)
.where(CustomerCredential.is_allowed == 1)
)
if credential_id:
credentials_query = credentials_query.where(AWSCredential.id == credential_id)
credentials = (await session.scalars(credentials_query)).all()
if not credentials:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No credentials to sync")
job = Job(
job_uuid=uuid4().hex,
job_type=JobType.SYNC_INSTANCES,
status=JobStatus.RUNNING,
progress=0,
total_count=0,
created_by_user_id=actor.id,
created_for_customer=base_customer_id,
payload={"credential_id": credential_id, "region": region},
started_at=datetime.now(timezone.utc),
)
session.add(job)
await session.commit()
await session.refresh(job)
synced = 0
changed_instances: dict[int, Instance] = {}
removed_payloads: list[dict] = []
now = datetime.now(timezone.utc)
for cred in credentials:
current_customer_id = base_customer_id
if actor.role.name == RoleName.ADMIN.value and not current_customer_id:
mappings = (
await session.scalars(
select(CustomerCredential.customer_id)
.where(CustomerCredential.credential_id == cred.id)
.where(CustomerCredential.is_allowed == 1)
)
).all()
uniq = sorted({cid for cid in mappings if cid})
if len(uniq) == 1:
current_customer_id = uniq[0]
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"customer_id required for credential {cred.id}",
)
if not current_customer_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="customer_id required",
)
target_region = region or cred.default_region
try:
resp = await asyncio.to_thread(aws_ops.describe_instances, cred, target_region)
except Exception as exc:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
cloud_instances: dict[str, dict] = {}
for res in resp.get("Reservations", []):
for inst in res.get("Instances", []):
iid = inst.get("InstanceId")
if iid:
cloud_instances[iid] = inst
status_resp: dict = {}
if cloud_instances:
status_resp = await asyncio.to_thread(
aws_ops.describe_instance_status, cred, target_region, list(cloud_instances.keys())
)
status_map: Dict[str, dict] = {}
for st in status_resp.get("InstanceStatuses", []):
iid = st.get("InstanceId")
status_map[iid] = {
"instance": st.get("InstanceStatus", {}).get("Status"),
"system": st.get("SystemStatus", {}).get("Status"),
}
local_instances = (
await session.scalars(
select(Instance)
.options(selectinload(Instance.credential), selectinload(Instance.customer))
.where(
Instance.credential_id == cred.id,
Instance.region == target_region,
Instance.customer_id == current_customer_id,
)
)
).all()
local_map: Dict[str, Instance] = {inst.instance_id: inst for inst in local_instances}
for inst_id, inst in cloud_instances.items():
state_name = (inst.get("State") or {}).get("Name")
state_lower = state_name.lower() if state_name else ""
if state_lower in {"terminated", "shutting-down"}:
existing = local_map.pop(inst_id, None)
if existing:
removed_payloads.append(await _delete_local_instance(session, existing))
continue
inst_obj = await upsert_instance_from_cloud(
session,
current_customer_id,
cred,
target_region,
inst,
now,
status_map.get(inst_id),
)
local_map.pop(inst_id, None)
if inst_obj:
changed_instances[inst_obj.id] = inst_obj
synced += 1
session.add(
JobItem(
job_id=job.id,
resource_type=JobItemResourceType.INSTANCE,
resource_id=None,
account_id=cred.account_id,
region=target_region,
instance_id=inst_id,
action=JobItemAction.SYNC,
status=JobItemStatus.SUCCESS,
)
)
            # Delete local instances that no longer exist in the cloud
for orphan in local_map.values():
removed_payloads.append(await _delete_local_instance(session, orphan))
job.total_count = synced
job.success_count = synced
job.status = JobStatus.SUCCESS
job.progress = 100
job.finished_at = datetime.now(timezone.utc)
session.add(
AuditLog(
user_id=actor.id,
customer_id=base_customer_id or actor.customer_id,
action=AuditAction.INSTANCE_SYNC,
resource_type=AuditResourceType.INSTANCE,
resource_id=None,
description=f"Sync instances {synced}",
payload={"credential_id": credential_id, "region": region},
)
)
await session.commit()
enriched_instances = list(changed_instances.values())
await enrich_instances(enriched_instances)
for inst_obj in enriched_instances:
await broadcast_instance_update(inst_obj)
for payload in removed_payloads:
await broadcast_instance_removed(payload, payload.get("customer_id"))
return job
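# Reconciliation summary for sync_instances: live cloud instances are upserted;
# cloud instances reported terminated/shutting-down, and local rows with no cloud
# counterpart, are deleted locally and broadcast as removals.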
async def batch_instances_action(
session: AsyncSession,
payload: BatchInstancesActionIn,
actor: User,
) -> BatchInstancesActionOut:
if len(payload.instance_ids) > MAX_BATCH_COUNT:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
requested_ids = list(dict.fromkeys(payload.instance_ids))
if not requested_ids:
return BatchInstancesActionOut(action=payload.action, requested=[], accepted=[], skipped=[], errors={})
instances = (await session.scalars(build_instances_query({"instance_ids": requested_ids}, actor))).all()
found_map = {inst.id: inst for inst in instances}
result = BatchInstancesActionOut(action=payload.action, requested=requested_ids, accepted=[], skipped=[], errors={})
accepted_ids: set[int] = set()
skipped_ids: set[int] = set()
for iid in requested_ids:
if iid not in found_map and iid not in skipped_ids:
result.skipped.append(iid)
result.errors[str(iid)] = "NOT_FOUND_OR_FORBIDDEN"
skipped_ids.add(iid)
job_action_map = {
"start": JobItemAction.START,
"stop": JobItemAction.STOP,
"reboot": JobItemAction.REBOOT,
"terminate": JobItemAction.TERMINATE,
}
sync_plan: dict[tuple[int, str, int | None], list[int]] = {}
for inst in instances:
if actor.role.name != RoleName.ADMIN.value and inst.customer_id != actor.customer_id:
if inst.id not in skipped_ids:
result.skipped.append(inst.id)
result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
skipped_ids.add(inst.id)
continue
if not inst.credential_id:
if inst.id not in skipped_ids:
result.skipped.append(inst.id)
result.errors[str(inst.id)] = "MISSING_CREDENTIAL"
skipped_ids.add(inst.id)
continue
try:
await ensure_credential_access(session, inst.credential_id, actor)
except HTTPException:
if inst.id not in skipped_ids:
result.skipped.append(inst.id)
result.errors[str(inst.id)] = "NOT_FOUND_OR_FORBIDDEN"
skipped_ids.add(inst.id)
continue
if payload.action == "sync":
key = (inst.credential_id, inst.region, inst.customer_id)
sync_plan.setdefault(key, []).append(inst.id)
continue
job_action = job_action_map.get(payload.action)
if not job_action:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported action")
try:
await enqueue_action(session, inst, job_action, actor)
if inst.id not in accepted_ids:
accepted_ids.add(inst.id)
result.accepted.append(inst.id)
except Exception as exc:
if inst.id not in skipped_ids:
skipped_ids.add(inst.id)
result.skipped.append(inst.id)
result.errors[str(inst.id)] = str(exc)
# handle sync actions per credential/region/customer
for key, inst_ids in sync_plan.items():
cred_id, region, customer_id = key
try:
await sync_instances(session, cred_id, region, actor, customer_id_override=customer_id)
for iid in inst_ids:
if iid not in accepted_ids:
accepted_ids.add(iid)
result.accepted.append(iid)
except Exception as exc:
for iid in inst_ids:
if iid not in skipped_ids:
skipped_ids.add(iid)
result.skipped.append(iid)
result.errors[str(iid)] = str(exc)
return result
async def batch_instances_by_ips(
session: AsyncSession,
payload: BatchInstancesByIpIn,
actor: User,
) -> dict:
ips = []
seen = set()
for raw in payload.ips:
for part in re.split(r"[\s,]+", (raw or "").strip()):
if not part or part in seen:
continue
seen.add(part)
ips.append(part)
if not ips:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ips required")
if len(ips) > MAX_BATCH_COUNT:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
ip_set = set(ips)
filters = {
"credential_id": payload.credential_id,
"region": payload.region,
}
query = build_instances_query(filters, actor).where(
or_(Instance.public_ip.in_(ip_set), Instance.private_ip.in_(ip_set))
)
instances = (await session.scalars(query)).all()
matched_ids = [inst.id for inst in instances]
matched_ips: set[str] = set()
for inst in instances:
if inst.public_ip and inst.public_ip in ip_set:
matched_ips.add(inst.public_ip)
if inst.private_ip and inst.private_ip in ip_set:
matched_ips.add(inst.private_ip)
result = await batch_instances_action(
session,
BatchInstancesActionIn(instance_ids=matched_ids, action=payload.action),
actor,
)
return {
"ips_requested": ips,
"ips_matched": list(matched_ips),
"ips_unmatched": [ip for ip in ips if ip not in matched_ips],
"result": result,
}
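# Input normalization above accepts comma- and whitespace-delimited entries, e.g.
# ["1.2.3.4, 5.6.7.8", "9.9.9.9\n1.2.3.4"] expands and dedupes to
# ["1.2.3.4", "5.6.7.8", "9.9.9.9"] before matching public/private IPs.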
def _fmt(dt: Optional[datetime]) -> str:
if not dt:
return ""
if dt.tzinfo:
return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
return dt.strftime("%Y-%m-%d %H:%M:%S")
def _fmt_gib(val: float | None) -> str:
if val is None:
return ""
try:
as_float = float(val)
except (TypeError, ValueError):
return str(val)
if as_float.is_integer():
return str(int(as_float))
return f"{as_float:.2f}".rstrip("0").rstrip(".")
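# _fmt_gib examples: 2.0 -> "2", 0.5 -> "0.5", 1.5 -> "1.5", None -> ""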
def _build_export_sheet(instances: List[Instance]) -> bytes:
wb = Workbook()
ws = wb.active
ws.title = "Instances"
headers = [
"Name",
"Credential",
"Owner",
"Instance ID",
"Type",
"Instance Type Specs",
"AMI / OS",
"OS Pretty",
"Public IP",
"Private IP",
"Region",
"AZ",
"Status",
"Desired Status",
"Account ID",
"Credential ID",
"Security Groups",
"Subnet ID",
"VPC ID",
"Launched At",
"Created At",
"Last Sync",
]
ws.append(headers)
for inst in instances:
os_pretty = inst.os_pretty_name or _derive_os_pretty_name(inst)
os_label = os_pretty or inst.os_name
if not os_label and isinstance(inst.last_cloud_state, dict):
os_label = inst.last_cloud_state.get("os_family") or inst.last_cloud_state.get("ami_name")
ami_os = " / ".join([x for x in [os_label, inst.ami_id] if x])
sg_val = ""
if isinstance(inst.security_groups, list):
sg_val = ", ".join([str(sg) for sg in inst.security_groups])
elif isinstance(inst.security_groups, dict):
sg_val = ", ".join([str(v) for v in inst.security_groups.values()])
spec_parts = []
if inst.instance_vcpus is not None:
spec_parts.append(f"{inst.instance_vcpus} vCPU")
if inst.instance_memory_gib is not None:
spec_parts.append(f"{_fmt_gib(inst.instance_memory_gib)} GiB")
if inst.instance_network_perf:
spec_parts.append(inst.instance_network_perf)
type_specs = inst.instance_type
if spec_parts:
type_specs = f"{inst.instance_type} ({', '.join(spec_parts)})" if inst.instance_type else ", ".join(spec_parts)
ws.append(
[
inst.name_tag or "",
getattr(inst, "credential_label", None)
or getattr(inst, "credential_name", None)
or (inst.credential.name if getattr(inst, "credential", None) else ""),
getattr(inst, "owner_name", None) or getattr(inst, "customer_name", None) or "",
inst.instance_id,
inst.instance_type,
type_specs,
ami_os,
os_pretty or "",
inst.public_ip or "",
inst.private_ip or "",
inst.region,
inst.az or "",
inst.status.value if isinstance(inst.status, InstanceStatus) else str(inst.status),
inst.desired_status.value if isinstance(inst.desired_status, InstanceDesiredStatus) else inst.desired_status or "",
inst.account_id,
inst.credential_id or "",
sg_val,
inst.subnet_id or "",
inst.vpc_id or "",
_fmt(inst.launched_at),
_fmt(inst.created_at),
_fmt(inst.last_sync),
]
)
buffer = BytesIO()
wb.save(buffer)
buffer.seek(0)
return buffer.getvalue()
async def export_instances(
session: AsyncSession,
filters: dict,
actor: User,
) -> bytes:
filters = dict(filters or {})
filters.pop("offset", None)
filters.pop("limit", None)
query = build_instances_query(filters, actor).order_by(Instance.updated_at.desc())
rows = (await session.scalars(query)).all()
if len(rows) > MAX_EXPORT_ROWS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
)
await enrich_instances(rows)
return _build_export_sheet(rows)
async def export_instances_by_ids(
session: AsyncSession,
instance_ids: List[int],
actor: User,
) -> bytes:
if len(instance_ids) > MAX_EXPORT_IDS:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOO_MANY_ITEMS")
ids = list(dict.fromkeys(instance_ids))
if not ids:
return _build_export_sheet([])
query = build_instances_query({}, actor).where(Instance.id.in_(ids)).order_by(Instance.updated_at.desc())
rows = (await session.scalars(query)).all()
if len(rows) > MAX_EXPORT_ROWS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"EXPORT_LIMIT_EXCEEDED:{MAX_EXPORT_ROWS}"
)
await enrich_instances(rows)
return _build_export_sheet(rows)