aws-mt5/aws_service.py
2026-01-05 15:33:08 +08:00

323 lines
11 KiB
Python

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, TypedDict
import boto3
from botocore.exceptions import BotoCoreError, ClientError
class ConfigError(Exception):
    """Configuration-related error.

    Not raised by the code visible in this module; presumably raised by whatever
    builds ``AccountConfig`` objects (TODO confirm with callers).
    """
class AWSOperationError(Exception):
    """Raised when an AWS operation performed by this module fails.

    Wraps botocore ``ClientError``/``BotoCoreError`` failures (the original
    exception is chained as ``__cause__``). It is also raised directly for
    conditions such as a missing instance or an exhausted IP retry budget.
    """
class InstanceSpec(TypedDict, total=False):
    """Parameters used to recreate an EC2 instance.

    Produced from a live instance by ``_build_spec_from_instance`` or passed in by
    the caller as ``fallback_spec``. ``total=False`` makes every key optional.
    ``availability_zone``, ``security_group_names`` and ``region`` are recorded
    for reference only; ``_provision_instance`` does not read them.
    """

    instance_type: Optional[str]  # EC2 instance type, e.g. the source instance's InstanceType
    instance_name: Optional[str]  # value for the "Name" tag
    root_device: Optional[str]  # root device name, e.g. from RootDeviceName
    root_size: Optional[int]  # root volume Size from DescribeVolumes (GiB)
    root_volume_type: Optional[str]  # root volume VolumeType from DescribeVolumes
    security_group_ids: List[str]  # sent as SecurityGroupIds when non-empty
    security_group_names: List[str]  # informational only
    subnet_id: Optional[str]  # sent as SubnetId when set
    availability_zone: Optional[str]  # informational only
    region: Optional[str]  # informational only
@dataclass
class AccountConfig:
    """Credentials and launch defaults for one AWS account/region.

    ``secret_access_key`` is excluded from the generated ``__repr__`` so that
    logging or printing an account does not leak the credential.
    """

    name: str  # label for the account
    region: str  # region used for every boto3 client built from this config
    access_key_id: str  # static access key id passed to boto3
    secret_access_key: str = field(repr=False)  # secret key; never shown in repr()
    ami_id: str  # AMI that _provision_instance launches from
    subnet_id: Optional[str] = None  # fallback subnet when the source instance has none
    # Not read by the code in this module (launches use the spec's groups) — TODO confirm intent.
    security_group_ids: List[str] = field(default_factory=list)
    # Optional key pair; launches retry without it if EC2 reports it missing.
    key_name: Optional[str] = None
def ec2_client(account: AccountConfig):
    """Build a boto3 EC2 client for *account*'s region using its static access keys."""
    connection_kwargs = {
        "region_name": account.region,
        "aws_access_key_id": account.access_key_id,
        "aws_secret_access_key": account.secret_access_key,
    }
    return boto3.client("ec2", **connection_kwargs)
def cloudwatch_client(account: AccountConfig):
    """Build a boto3 CloudWatch client for *account*'s region using its static access keys."""
    service_name = "cloudwatch"
    return boto3.client(
        service_name,
        aws_access_key_id=account.access_key_id,
        aws_secret_access_key=account.secret_access_key,
        region_name=account.region,
    )
def _get_instance_by_ip(client, ip: str) -> Optional[dict]:
filters = [
{"Name": "instance-state-name", "Values": ["pending", "running", "stopping", "stopped"]},
]
for field in ["ip-address", "private-ip-address"]:
try:
resp = client.describe_instances(Filters=filters + [{"Name": field, "Values": [ip]}])
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to describe instances: {exc}") from exc
for reservation in resp.get("Reservations", []):
for instance in reservation.get("Instances", []):
return instance
return None
def _wait_for_state(client, instance_id: str, waiter_name: str) -> None:
waiter = client.get_waiter(waiter_name)
waiter.wait(InstanceIds=[instance_id])
def _get_root_volume_spec(client, instance: dict) -> tuple[Optional[str], Optional[int], Optional[str]]:
"""Return (device_name, size_gb, volume_type) for root volume if available."""
root_device_name = instance.get("RootDeviceName")
if not root_device_name:
return None, None, None
for mapping in instance.get("BlockDeviceMappings", []):
if mapping.get("DeviceName") != root_device_name:
continue
ebs = mapping.get("Ebs")
if not ebs:
return root_device_name, None, None
volume_id = ebs.get("VolumeId")
if not volume_id:
return root_device_name, None, None
try:
vol_resp = client.describe_volumes(VolumeIds=[volume_id])
volumes = vol_resp.get("Volumes", [])
if volumes:
volume = volumes[0]
return root_device_name, volume.get("Size"), volume.get("VolumeType")
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to read volume info for {volume_id}: {exc}") from exc
return root_device_name, None, None
def _extract_security_group_ids(instance: dict) -> List[str]:
groups = []
for g in instance.get("SecurityGroups", []):
gid = g.get("GroupId")
if gid:
groups.append(gid)
return groups
def _extract_security_group_names(instance: dict) -> List[str]:
groups = []
for g in instance.get("SecurityGroups", []):
name = g.get("GroupName")
if name:
groups.append(name)
return groups
def _extract_name_tag(instance: dict) -> Optional[str]:
for tag in instance.get("Tags", []) or []:
if tag.get("Key") == "Name":
return tag.get("Value")
return None
def _terminate_instance(client, instance_id: str, wait_for_completion: bool = True) -> None:
try:
client.terminate_instances(InstanceIds=[instance_id])
if wait_for_completion:
_wait_for_state(client, instance_id, "instance_terminated")
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to terminate instance {instance_id}: {exc}") from exc
def _build_block_device_mappings(
device_name: Optional[str], volume_size: Optional[int], volume_type: Optional[str]
) -> Optional[list]:
if not device_name:
return None
ebs = {"DeleteOnTermination": True}
if volume_type:
ebs["VolumeType"] = volume_type
if volume_size:
ebs["VolumeSize"] = volume_size
return [{"DeviceName": device_name, "Ebs": ebs}]
def _provision_instance(
    client,
    account: AccountConfig,
    spec: InstanceSpec,
) -> str:
    """Launch one instance from ``account.ami_id`` configured by *spec* and wait until it runs.

    Args:
        client: boto3 EC2 client.
        account: Provides the AMI id and the optional key pair name.
        spec: Instance type, Name tag, subnet, security group ids and root volume
            layout. Missing or falsy keys are simply not sent to ``run_instances``.

    Returns:
        The id of the new instance, after the ``instance_running`` waiter succeeds.

    Raises:
        AWSOperationError: If the launch or the wait fails. A launch rejected with
            ``InvalidKeyPair.NotFound`` is retried once without ``KeyName`` first.
    """

    def _build_params(include_key: bool = True) -> dict:
        """Translate *spec* into ``run_instances`` keyword arguments."""
        params = {
            "ImageId": account.ami_id,
            "InstanceType": spec.get("instance_type"),
            "MinCount": 1,
            "MaxCount": 1,
        }
        if spec.get("instance_name"):
            params["TagSpecifications"] = [
                {
                    "ResourceType": "instance",
                    "Tags": [{"Key": "Name", "Value": spec["instance_name"]}],
                }
            ]
        subnet_id = spec.get("subnet_id")
        if subnet_id:
            params["SubnetId"] = subnet_id
        security_group_ids = spec.get("security_group_ids")
        if security_group_ids:
            params["SecurityGroupIds"] = security_group_ids
        block_mapping = _build_block_device_mappings(
            spec.get("root_device"), spec.get("root_size"), spec.get("root_volume_type")
        )
        if block_mapping:
            params["BlockDeviceMappings"] = block_mapping
        if include_key and account.key_name:
            params["KeyName"] = account.key_name
        return params

    def _run(params: dict) -> str:
        """Launch with *params*, wait until running, and return the instance id."""
        resp = client.run_instances(**params)
        instance_id = resp["Instances"][0]["InstanceId"]
        try:
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError):
            # The instance exists, but its id is lost once this error propagates.
            # Terminate it (best effort) so it is not left orphaned. The wait error
            # is the one the caller needs, so a cleanup failure is deliberately ignored.
            try:
                client.terminate_instances(InstanceIds=[instance_id])
            except (ClientError, BotoCoreError):
                pass
            raise
        return instance_id

    try:
        return _run(_build_params())
    except ClientError as exc:
        code = (getattr(exc, "response", None) or {}).get("Error", {}).get("Code")
        if code == "InvalidKeyPair.NotFound" and account.key_name:
            # The configured key pair does not exist here: retry once without it.
            try:
                return _run(_build_params(include_key=False))
            except (ClientError, BotoCoreError) as retry_exc:
                # Chain from the retry failure. It is the error that actually stopped
                # the launch; the key-pair error remains reachable as its __context__.
                raise AWSOperationError(
                    f"Failed to create instance after removing missing key pair {account.key_name}: {retry_exc}"
                ) from retry_exc
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
    except BotoCoreError as exc:
        raise AWSOperationError(f"Failed to create instance: {exc}") from exc
def _get_public_ip(client, instance_id: str) -> str:
try:
resp = client.describe_instances(InstanceIds=[instance_id])
reservations = resp.get("Reservations", [])
if not reservations:
raise AWSOperationError("Instance not found when reading IP")
instance = reservations[0]["Instances"][0]
return instance.get("PublicIpAddress") or ""
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to fetch public IP: {exc}") from exc
def _recycle_ip_until_free(client, instance_id: str, banned_ips: set[str], retry_limit: int) -> str:
    """Stop/start *instance_id* until it holds a public IP that is not banned.

    The current public IP is checked first. Every stop/start cycle is followed by
    another check, so the IP obtained by the final cycle is also inspected. Assumes
    EC2 hands out a different auto-assigned public IP after a stop/start, which is
    not the case for an Elastic IP.

    Args:
        client: boto3 EC2 client.
        instance_id: Instance whose public IP should be rotated.
        banned_ips: IPs that must not be accepted.
        retry_limit: Maximum number of stop/start cycles. With 0 or less, the
            current IP is checked once and no cycle is attempted.

    Returns:
        The first non-empty public IP that is not in *banned_ips*.

    Raises:
        AWSOperationError: If reading the IP, a stop/start call or a wait fails, or
            if no acceptable IP was seen after *retry_limit* cycles.
    """
    cycles_done = 0
    while True:
        current_ip = _get_public_ip(client, instance_id)
        # "" means no public IP is currently assigned; treat it like a banned IP.
        if current_ip and current_ip not in banned_ips:
            return current_ip
        # Checking before cycling ensures the IP from the last cycle is not wasted.
        if cycles_done >= retry_limit:
            break
        try:
            client.stop_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_stopped")
            client.start_instances(InstanceIds=[instance_id])
            _wait_for_state(client, instance_id, "instance_running")
        except (ClientError, BotoCoreError) as exc:
            raise AWSOperationError(f"Failed while cycling IP: {exc}") from exc
        cycles_done += 1
    raise AWSOperationError("Reached retry limit while attempting to obtain a free IP")
def _get_network_out_mb(cw_client, instance_id: str, days: int = 30) -> float:
"""Fetch total NetworkOut over the past window (MB)."""
end = datetime.now(timezone.utc)
start = end - timedelta(days=days)
try:
resp = cw_client.get_metric_statistics(
Namespace="AWS/EC2",
MetricName="NetworkOut",
Dimensions=[{"Name": "InstanceId", "Value": instance_id}],
StartTime=start,
EndTime=end,
Period=3600 * 6, # 6 小时粒度,覆盖 30 天
Statistics=["Sum"],
)
datapoints = resp.get("Datapoints", [])
if not datapoints:
return 0.0
total_bytes = sum(dp.get("Sum", 0.0) for dp in datapoints)
return round(total_bytes / (1024 * 1024), 2)
except (ClientError, BotoCoreError) as exc:
raise AWSOperationError(f"Failed to fetch NetworkOut metrics: {exc}") from exc
def _build_spec_from_instance(client, instance: dict, account: AccountConfig) -> InstanceSpec:
    """Capture the launch parameters of an existing instance as an ``InstanceSpec``.

    The root volume's device/size/type are read via ``describe_volumes``. The
    subnet falls back to ``account.subnet_id`` when the instance reports none,
    and ``region`` is always taken from *account*.

    Raises:
        AWSOperationError: If the instance has no ``InstanceType`` or the root
            volume lookup fails.
    """
    detected_type = instance.get("InstanceType")
    if not detected_type:
        raise AWSOperationError("Failed to detect instance type from source instance")
    device, size, volume_kind = _get_root_volume_spec(client, instance)
    placement = instance.get("Placement", {})
    return InstanceSpec(
        instance_type=detected_type,
        instance_name=_extract_name_tag(instance),
        root_device=device,
        root_size=size,
        root_volume_type=volume_kind,
        security_group_ids=_extract_security_group_ids(instance),
        security_group_names=_extract_security_group_names(instance),
        subnet_id=instance.get("SubnetId") or account.subnet_id,
        availability_zone=placement.get("AvailabilityZone"),
        region=account.region,
    )
def replace_instance_ip(
    ip: str,
    account: AccountConfig,
    disallowed_ips: set[str],
    retry_limit: int = 5,
    fallback_spec: Optional[InstanceSpec] = None,
) -> Dict[str, object]:
    """Replace the instance holding *ip* with a new one whose public IP is allowed.

    The instance currently using *ip* (public or private) is looked up. Its spec is
    copied, or *fallback_spec* is used when no instance is found. A new instance is
    launched and stop/started until its public IP is not in *disallowed_ips*. The
    old instance, if any, is then terminated without waiting.

    If no acceptable IP is obtained, the new instance is terminated (best effort)
    so that it is not left running unreported.

    Args:
        ip: IP address of the instance to replace.
        account: Account/region to operate in.
        disallowed_ips: Public IPs the new instance must not end up with.
        retry_limit: Maximum stop/start cycles for obtaining an allowed IP.
        fallback_spec: Spec to use when no instance currently holds *ip*.

    Returns:
        Dict with ``terminated_instance_id`` (None if no old instance was found),
        ``new_instance_id``, ``new_ip``, ``spec_used`` and
        ``terminated_network_out_mb`` (MiB over 30 days; None if unavailable).

    Raises:
        AWSOperationError: If no spec is available or any AWS step fails.
    """
    client = ec2_client(account)
    cw = cloudwatch_client(account)
    instance = _get_instance_by_ip(client, ip)
    spec: Optional[InstanceSpec] = None
    instance_id: Optional[str] = None
    network_out_mb: Optional[float] = None
    if instance:
        instance_id = instance["InstanceId"]
        spec = _build_spec_from_instance(client, instance, account)
        try:
            network_out_mb = _get_network_out_mb(cw, instance_id)
        except AWSOperationError:
            # Traffic stats are informational; a metrics failure must not block the replacement.
            network_out_mb = None
    elif fallback_spec:
        spec = fallback_spec
    if not spec:
        # Runtime message kept as-is; the Chinese part means
        # "and the database has no spec information for this IP".
        raise AWSOperationError(f"No instance found with IP {ip} 且数据库无该IP规格信息")
    new_instance_id = _provision_instance(client, account, spec)
    try:
        new_ip = _recycle_ip_until_free(client, new_instance_id, disallowed_ips, retry_limit)
    except AWSOperationError as exc:
        # The caller never learns new_instance_id from an exception, so without this
        # cleanup the new instance would keep running unnoticed. The old one is untouched.
        try:
            _terminate_instance(client, new_instance_id, wait_for_completion=False)
        except AWSOperationError as cleanup_exc:
            raise AWSOperationError(
                f"{exc}; additionally failed to terminate new instance {new_instance_id}: {cleanup_exc}"
            ) from exc
        raise
    if instance_id:
        # Terminate the old instance without waiting for it to finish, so the call
        # returns as soon as the replacement is ready.
        try:
            _terminate_instance(client, instance_id, wait_for_completion=False)
        except AWSOperationError as exc:
            # Report the working replacement; otherwise its id/IP are lost with the error.
            raise AWSOperationError(
                f"New instance {new_instance_id} ({new_ip}) is running, "
                f"but terminating old instance {instance_id} failed: {exc}"
            ) from exc
    return {
        "terminated_instance_id": instance_id,
        "new_instance_id": new_instance_id,
        "new_ip": new_ip,
        "spec_used": spec,
        "terminated_network_out_mb": network_out_mb,
    }