aws-mt5/db.py
2026-01-05 15:33:08 +08:00

405 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Iterable, Optional, List, Dict
from sqlalchemy import Column, DateTime, Integer, String, Float, create_engine, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base, sessionmaker
# Connection URL; override via the DATABASE_URL environment variable.
# NOTE(review): the fallback embeds placeholder credentials — real deployments
# must always set DATABASE_URL.
DATABASE_URL = os.getenv("DATABASE_URL", "mysql+pymysql://username:password@localhost:3306/ip_ops")
# pool_pre_ping revalidates pooled connections before use (guards against
# server-side idle-timeout disconnects).
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
def _now_cn() -> datetime:
return datetime.now(timezone(timedelta(hours=8)))
class AWSAccount(Base):
    """AWS credential/launch profile, one row per uniquely named account."""

    __tablename__ = "aws_accounts"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique human-readable label; used as the lookup key by the CRUD helpers.
    name = Column(String(128), unique=True, nullable=False, index=True)
    region = Column(String(64), nullable=False)
    # NOTE(review): credentials are stored in plaintext — consider encrypting
    # at rest or sourcing from a secrets manager.
    access_key_id = Column(String(128), nullable=False)
    secret_access_key = Column(String(256), nullable=False)
    ami_id = Column(String(128), nullable=False)
    subnet_id = Column(String(128), nullable=True)
    # Comma-separated list; decoded with _split_csv in list_aws_accounts.
    security_group_ids = Column(String(512), nullable=True)
    key_name = Column(String(128), nullable=True)
    # Timestamps default to China Standard Time (UTC+8) via _now_cn.
    created_at = Column(DateTime(timezone=True), nullable=False, default=_now_cn)
    updated_at = Column(DateTime(timezone=True), nullable=False, default=_now_cn, onupdate=_now_cn)
class IPOperation(Base):
    """IP addresses treated as disallowed (read back by load_disallowed_ips)."""

    __tablename__ = "ip_operations"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    # Free-form annotation for the entry.
    note = Column(String(255), nullable=True)
class IPAccountMapping(Base):
    """One-to-one mapping from an IP address to an AWS account name."""

    __tablename__ = "ip_account_mapping"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
class ServerSpec(Base):
    """Snapshot of a server's launch configuration, keyed by its IP address."""

    __tablename__ = "server_specs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    instance_type = Column(String(64), nullable=True)
    instance_name = Column(String(255), nullable=True)
    volume_type = Column(String(64), nullable=True)
    # Comma-joined lists, written by upsert_server_spec.
    security_group_names = Column(String(512), nullable=True)
    security_group_ids = Column(String(512), nullable=True)
    region = Column(String(64), nullable=True)
    subnet_id = Column(String(128), nullable=True)
    availability_zone = Column(String(64), nullable=True)
    # No column default — always set explicitly by upsert_server_spec.
    created_at = Column(DateTime(timezone=True), nullable=False)
class IPReplacementHistory(Base):
    """Audit record of one IP replacement: old_ip was replaced by new_ip."""

    __tablename__ = "ip_replacement_history"

    id = Column(Integer, primary_key=True, autoincrement=True)
    old_ip = Column(String(64), nullable=False, index=True)
    new_ip = Column(String(64), nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    # Chain identifier linking successive replacements; inherited through
    # resolve_group_id, falling back to the chain's first old_ip.
    group_id = Column(String(128), nullable=True, index=True)
    # Optional metric supplied by the caller (MB of outbound traffic of the
    # terminated instance — source not visible here; TODO confirm).
    terminated_network_out_mb = Column(Float, nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False)
def resolve_group_id(old_ip: str) -> str:
    """Resolve the replacement-chain group id for *old_ip*.

    Inherits the group_id of the most recent history record whose new_ip is
    *old_ip*; when no such record exists, *old_ip* itself starts a new group.
    """
    stmt = (
        select(IPReplacementHistory.group_id)
        .where(IPReplacementHistory.new_ip == old_ip)
        .order_by(IPReplacementHistory.id.desc())
    )
    with db_session() as session:
        inherited = session.scalar(stmt)
    return inherited if inherited else old_ip
def init_db() -> None:
    """Create all mapped tables if they do not already exist (idempotent)."""
    Base.metadata.create_all(bind=engine)
def _split_csv(value: Optional[str]) -> List[str]:
if not value:
return []
return [part for part in value.split(",") if part]
@contextmanager
def db_session():
    """Yield a Session, committing on success.

    Fix: the original only rolled back on SQLAlchemyError, so a non-SQLAlchemy
    exception raised inside the ``with`` body (e.g. the ValueError raised by
    update_ip_account_mapping) skipped the explicit rollback. Now any exception
    triggers a rollback before being re-raised; the session is always closed.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except BaseException:
        # Roll back on *any* error so business-logic exceptions cannot leave a
        # half-applied transaction behind.
        session.rollback()
        raise
    finally:
        session.close()
def load_disallowed_ips() -> set[str]:
    """Return every IP recorded in ip_operations as a set of strings."""
    with db_session() as session:
        return set(session.scalars(select(IPOperation.ip_address)))
def list_aws_accounts() -> List[Dict[str, object]]:
    """Return all configured AWS accounts as plain dicts, newest first."""

    def as_dict(acct: AWSAccount) -> Dict[str, object]:
        # Flatten the ORM row; CSV column decoded into a list.
        return {
            "name": acct.name,
            "region": acct.region,
            "access_key_id": acct.access_key_id,
            "secret_access_key": acct.secret_access_key,
            "ami_id": acct.ami_id,
            "subnet_id": acct.subnet_id,
            "security_group_ids": _split_csv(acct.security_group_ids),
            "key_name": acct.key_name,
            "created_at": acct.created_at,
            "updated_at": acct.updated_at,
        }

    with db_session() as session:
        accounts = session.scalars(select(AWSAccount).order_by(AWSAccount.id.desc()))
        return [as_dict(acct) for acct in accounts]
def upsert_aws_account(
    *,
    name: str,
    region: str,
    access_key_id: str,
    secret_access_key: str,
    ami_id: str,
    subnet_id: Optional[str] = None,
    security_group_ids: Optional[List[str]] = None,
    key_name: Optional[str] = None,
) -> None:
    """Insert a new AWS account row, or update the existing one matched by name.

    security_group_ids is stored as a comma-joined string (empty string when
    None or empty list).
    """
    fields = {
        "region": region,
        "access_key_id": access_key_id,
        "secret_access_key": secret_access_key,
        "ami_id": ami_id,
        "subnet_id": subnet_id,
        "security_group_ids": ",".join(security_group_ids or []),
        "key_name": key_name,
    }
    with db_session() as session:
        existing = session.scalar(select(AWSAccount).where(AWSAccount.name == name))
        if existing is None:
            session.add(AWSAccount(name=name, **fields))
        else:
            for attr, value in fields.items():
                setattr(existing, attr, value)
def delete_aws_account(name: str) -> None:
    """Delete the AWS account row named *name*; no-op when it does not exist."""
    with db_session() as session:
        found = session.scalar(select(AWSAccount).where(AWSAccount.name == name))
        if found is not None:
            session.delete(found)
def get_account_by_ip(ip: str) -> Optional[str]:
    """Return the account name mapped to *ip*, or None when unmapped."""
    stmt = select(IPAccountMapping.account_name).where(IPAccountMapping.ip_address == ip)
    with db_session() as session:
        return session.scalar(stmt)
def list_account_mappings() -> List[Dict[str, str]]:
    """Return all IP-to-account mappings as dicts, newest first."""
    with db_session() as session:
        mappings = session.scalars(
            select(IPAccountMapping).order_by(IPAccountMapping.id.desc())
        )
        result: List[Dict[str, str]] = []
        for mapping in mappings:
            result.append(
                {"ip_address": mapping.ip_address, "account_name": mapping.account_name}
            )
        return result
def upsert_account_mapping(ip: str, account_name: str) -> None:
    """Create the mapping for *ip*, or overwrite its account name if present."""
    with db_session() as session:
        mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == ip)
        )
        if mapping is None:
            session.add(IPAccountMapping(ip_address=ip, account_name=account_name))
        else:
            mapping.account_name = account_name
def delete_account_mapping(ip: str) -> None:
    """Delete the mapping for *ip*; no-op when it does not exist."""
    with db_session() as session:
        mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == ip)
        )
        if mapping is not None:
            session.delete(mapping)
def update_ip_account_mapping(old_ip: str, new_ip: str, account_name: str) -> None:
    """Re-point the mapping from *old_ip* to *new_ip* with *account_name*.

    Raises ValueError when *new_ip* is already mapped by a different row.
    Creates a fresh mapping for *new_ip* when *old_ip* has none.
    """
    with db_session() as session:
        current = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == old_ip)
        )
        conflict = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == new_ip)
        )
        # A row already holding new_ip blocks the move unless it is the very
        # row we are updating.
        if conflict is not None and (current is None or conflict.id != current.id):
            raise ValueError(f"IP {new_ip} 已经映射到账户 {conflict.account_name}")
        if current is None:
            session.add(IPAccountMapping(ip_address=new_ip, account_name=account_name))
        else:
            current.ip_address = new_ip
            current.account_name = account_name
def upsert_server_spec(
    *,
    ip_address: str,
    account_name: str,
    instance_type: Optional[str],
    instance_name: Optional[str],
    volume_type: Optional[str],
    security_group_names: List[str],
    security_group_ids: List[str],
    region: Optional[str],
    subnet_id: Optional[str],
    availability_zone: Optional[str],
    created_at: Optional[datetime] = None,
) -> None:
    """Insert or refresh the server-spec row keyed by *ip_address*.

    Security-group lists are stored comma-joined; created_at falls back to the
    current UTC+8 time when not supplied.
    """
    values = {
        "account_name": account_name,
        "instance_type": instance_type,
        "instance_name": instance_name,
        "volume_type": volume_type,
        "security_group_names": ",".join(security_group_names),
        "security_group_ids": ",".join(security_group_ids),
        "region": region,
        "subnet_id": subnet_id,
        "availability_zone": availability_zone,
        "created_at": created_at if created_at is not None else _now_cn(),
    }
    with db_session() as session:
        existing = session.scalar(
            select(ServerSpec).where(ServerSpec.ip_address == ip_address)
        )
        if existing is None:
            session.add(ServerSpec(ip_address=ip_address, **values))
        else:
            for attr, value in values.items():
                setattr(existing, attr, value)
def get_server_spec(ip_address: str) -> Optional[Dict[str, Optional[str]]]:
    """Fetch the stored server spec for *ip_address* as a dict, or None.

    Fix: the CSV columns were decoded with a raw ``.split(",")`` which keeps
    empty segments (e.g. "a,,b" -> ["a", "", "b"]), while list_aws_accounts
    decodes the same kind of column with _split_csv. Both columns now go
    through _split_csv so empty segments are dropped consistently.
    """
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        if not spec:
            return None
        return {
            "ip_address": spec.ip_address,
            "account_name": spec.account_name,
            "instance_type": spec.instance_type,
            "instance_name": spec.instance_name,
            "volume_type": spec.volume_type,
            "security_group_names": _split_csv(spec.security_group_names),
            "security_group_ids": _split_csv(spec.security_group_ids),
            "region": spec.region,
            "subnet_id": spec.subnet_id,
            "availability_zone": spec.availability_zone,
            "created_at": spec.created_at,
        }
def add_replacement_history(
    old_ip: str,
    new_ip: str,
    account_name: str,
    group_id: Optional[str],
    terminated_network_out_mb: Optional[float] = None,
) -> None:
    """Append one replacement record stamped with the current UTC+8 time.

    When *group_id* is falsy it is derived via resolve_group_id(old_ip) so the
    record joins the existing replacement chain.
    """
    if group_id:
        resolved = group_id
    else:
        resolved = resolve_group_id(old_ip)
    record = IPReplacementHistory(
        old_ip=old_ip,
        new_ip=new_ip,
        account_name=account_name,
        group_id=resolved,
        terminated_network_out_mb=terminated_network_out_mb,
        created_at=_now_cn(),
    )
    with db_session() as session:
        session.add(record)
def get_replacement_history(limit: int = 50) -> List[Dict[str, str]]:
    """Return the most recent *limit* replacement records, newest first."""
    stmt = (
        select(IPReplacementHistory)
        .order_by(IPReplacementHistory.id.desc())
        .limit(limit)
    )
    with db_session() as session:
        records: List[Dict[str, str]] = []
        for rec in session.scalars(stmt):
            records.append(
                {
                    "old_ip": rec.old_ip,
                    "new_ip": rec.new_ip,
                    "account_name": rec.account_name,
                    "group_id": rec.group_id,
                    "terminated_network_out_mb": rec.terminated_network_out_mb,
                    "created_at": rec.created_at.isoformat(),
                }
            )
        return records
def get_history_by_ip_or_group(ip: Optional[str], group_id: Optional[str], limit: int = 200) -> List[Dict[str, str]]:
    """Return replacement history filtered by group_id (preferred) or by *ip*
    matching either side of a replacement; newest first, at most *limit* rows.
    """
    stmt = select(IPReplacementHistory)
    # group_id takes precedence over ip; with neither, all rows qualify.
    if group_id:
        stmt = stmt.where(IPReplacementHistory.group_id == group_id)
    elif ip:
        stmt = stmt.where(
            (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
        )
    stmt = stmt.order_by(IPReplacementHistory.id.desc()).limit(limit)
    with db_session() as session:
        return [
            {
                "old_ip": rec.old_ip,
                "new_ip": rec.new_ip,
                "account_name": rec.account_name,
                "group_id": rec.group_id,
                "terminated_network_out_mb": rec.terminated_network_out_mb,
                "created_at": rec.created_at.isoformat(),
            }
            for rec in session.scalars(stmt)
        ]
def get_history_chains(ip: Optional[str] = None, group_id: Optional[str] = None, limit: int = 500) -> List[Dict[str, object]]:
    """Return chain information aggregated by group_id (chains are built in
    created-time ascending order).

    Each returned dict has: group_id, items (raw history entries), chain (the
    ordered list of IPs the group moved through), and first_ip_start (ISO
    timestamp of the chain's first IP from server_specs, when available).
    """
    with db_session() as session:
        stmt = select(IPReplacementHistory).order_by(IPReplacementHistory.created_at.asc())
        # group_id filter takes precedence over the ip filter.
        if group_id:
            stmt = stmt.where(IPReplacementHistory.group_id == group_id)
        elif ip:
            stmt = stmt.where(
                (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
            )
        stmt = stmt.limit(limit)
        rows: Iterable[IPReplacementHistory] = session.scalars(stmt)
        groups: Dict[str, Dict[str, object]] = {}
        for row in rows:
            # Rows without a group_id fall back to their old_ip as group key.
            gid = row.group_id or row.old_ip
            if gid not in groups:
                groups[gid] = {"group_id": gid, "items": [], "chain": [], "first_ip_start": None}
            entry = {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            groups[gid]["items"].append(entry)
        # Build the IP chain for each group by walking its items in time order.
        for gid, data in groups.items():
            items = data["items"]
            # ISO-8601 strings sort chronologically, so a string sort is safe.
            items.sort(key=lambda x: x["created_at"])
            chain: List[str] = []
            for it in items:
                if not chain:
                    chain.append(it["old_ip"])
                # Append old_ip only when it is neither the current tail nor
                # already present somewhere in the chain.
                if chain[-1] != it["old_ip"] and it["old_ip"] not in chain:
                    chain.append(it["old_ip"])
                # Extend the chain with new_ip unless it is already the tail.
                if chain[-1] != it["new_ip"]:
                    chain.append(it["new_ip"])
            data["chain"] = chain
            # Look up the creation time of the chain's first IP
            # (server_specs.created_at).
            if chain:
                first_ip = chain[0]
                spec_time = session.scalar(
                    select(ServerSpec.created_at).where(ServerSpec.ip_address == first_ip)
                )
                if spec_time:
                    data["first_ip_start"] = spec_time.isoformat()
        # Return the groups ordered by their earliest item's created_at.
        ordered = sorted(
            groups.values(),
            key=lambda g: g["items"][0]["created_at"] if g["items"] else "",
        )
        return ordered