import os
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Iterable, Optional, List, Dict

from sqlalchemy import Column, DateTime, Integer, String, Float, create_engine, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import declarative_base, sessionmaker

DATABASE_URL = os.getenv("DATABASE_URL", "mysql+pymysql://username:password@localhost:3306/ip_ops")

engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
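
# DATABASE_URL is read from the environment above, so deployments can point this
# module at a different MySQL instance without code changes, e.g. (placeholder
# credentials and host, adjust for your environment):
#   export DATABASE_URL="mysql+pymysql://user:pass@db-host:3306/ip_ops"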


class IPOperation(Base):
    __tablename__ = "ip_operations"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    note = Column(String(255), nullable=True)


class IPAccountMapping(Base):
    __tablename__ = "ip_account_mapping"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)


class ServerSpec(Base):
    __tablename__ = "server_specs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    ip_address = Column(String(64), unique=True, nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    instance_type = Column(String(64), nullable=True)
    instance_name = Column(String(255), nullable=True)
    volume_type = Column(String(64), nullable=True)
    security_group_names = Column(String(512), nullable=True)
    security_group_ids = Column(String(512), nullable=True)
    region = Column(String(64), nullable=True)
    subnet_id = Column(String(128), nullable=True)
    availability_zone = Column(String(64), nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False)


class IPReplacementHistory(Base):
    __tablename__ = "ip_replacement_history"

    id = Column(Integer, primary_key=True, autoincrement=True)
    old_ip = Column(String(64), nullable=False, index=True)
    new_ip = Column(String(64), nullable=False, index=True)
    account_name = Column(String(128), nullable=False)
    group_id = Column(String(128), nullable=True, index=True)
    terminated_network_out_mb = Column(Float, nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False)


def resolve_group_id(old_ip: str) -> str:
    """Inherit the group id of the latest record whose new_ip equals old_ip; otherwise start a new group keyed by old_ip."""
    with db_session() as session:
        prev = session.scalar(
            select(IPReplacementHistory.group_id)
            .where(IPReplacementHistory.new_ip == old_ip)
            .order_by(IPReplacementHistory.id.desc())
        )
        return prev or old_ip


def init_db() -> None:
    Base.metadata.create_all(bind=engine)


@contextmanager
def db_session():
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except SQLAlchemyError:
        session.rollback()
        raise
    finally:
        session.close()
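
# Illustrative usage of db_session() (the IP below is a documentation placeholder):
# the context manager commits on a clean exit and rolls back on SQLAlchemyError,
# so callers never manage transactions directly.
#
#     with db_session() as session:
#         session.add(IPOperation(ip_address="203.0.113.10", note="manual block"))
#     # committed here; a SQLAlchemyError raised inside the block would have been rolled back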


def load_disallowed_ips() -> set[str]:
    with db_session() as session:
        rows: Iterable[str] = session.scalars(select(IPOperation.ip_address))
        return {row for row in rows}


def get_account_by_ip(ip: str) -> Optional[str]:
    with db_session() as session:
        return session.scalar(
            select(IPAccountMapping.account_name).where(IPAccountMapping.ip_address == ip)
        )


def update_ip_account_mapping(old_ip: str, new_ip: str, account_name: str) -> None:
    """Move the account mapping from old_ip to new_ip (creating it if old_ip has no mapping); refuses to overwrite a different row that already owns new_ip."""
    with db_session() as session:
        existing_mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == old_ip)
        )
        conflict_mapping = session.scalar(
            select(IPAccountMapping).where(IPAccountMapping.ip_address == new_ip)
        )
        if conflict_mapping and (not existing_mapping or conflict_mapping.id != existing_mapping.id):
            raise ValueError(f"IP {new_ip} is already mapped to account {conflict_mapping.account_name}")

        if existing_mapping:
            existing_mapping.ip_address = new_ip
            existing_mapping.account_name = account_name
        else:
            session.add(IPAccountMapping(ip_address=new_ip, account_name=account_name))


def _now_cn() -> datetime:
    """Current time in UTC+8 (China Standard Time)."""
    return datetime.now(timezone(timedelta(hours=8)))


def upsert_server_spec(
    *,
    ip_address: str,
    account_name: str,
    instance_type: Optional[str],
    instance_name: Optional[str],
    volume_type: Optional[str],
    security_group_names: List[str],
    security_group_ids: List[str],
    region: Optional[str],
    subnet_id: Optional[str],
    availability_zone: Optional[str],
    created_at: Optional[datetime] = None,
) -> None:
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        payload = {
            "account_name": account_name,
            "instance_type": instance_type,
            "instance_name": instance_name,
            "volume_type": volume_type,
            "security_group_names": ",".join(security_group_names),
            "security_group_ids": ",".join(security_group_ids),
            "region": region,
            "subnet_id": subnet_id,
            "availability_zone": availability_zone,
            "created_at": created_at or _now_cn(),
        }
        if spec:
            for key, val in payload.items():
                setattr(spec, key, val)
        else:
            session.add(ServerSpec(ip_address=ip_address, **payload))


def get_server_spec(ip_address: str) -> Optional[Dict[str, object]]:
    with db_session() as session:
        spec = session.scalar(select(ServerSpec).where(ServerSpec.ip_address == ip_address))
        if not spec:
            return None
        return {
            "ip_address": spec.ip_address,
            "account_name": spec.account_name,
            "instance_type": spec.instance_type,
            "instance_name": spec.instance_name,
            "volume_type": spec.volume_type,
            "security_group_names": spec.security_group_names.split(",") if spec.security_group_names else [],
            "security_group_ids": spec.security_group_ids.split(",") if spec.security_group_ids else [],
            "region": spec.region,
            "subnet_id": spec.subnet_id,
            "availability_zone": spec.availability_zone,
            "created_at": spec.created_at,
        }


def add_replacement_history(
    old_ip: str,
    new_ip: str,
    account_name: str,
    group_id: Optional[str],
    terminated_network_out_mb: Optional[float] = None,
) -> None:
    resolved_group = group_id or resolve_group_id(old_ip)
    with db_session() as session:
        session.add(
            IPReplacementHistory(
                old_ip=old_ip,
                new_ip=new_ip,
                account_name=account_name,
                group_id=resolved_group,
                terminated_network_out_mb=terminated_network_out_mb,
                created_at=_now_cn(),
            )
        )


def get_replacement_history(limit: int = 50) -> List[Dict[str, object]]:
    with db_session() as session:
        rows: Iterable[IPReplacementHistory] = session.scalars(
            select(IPReplacementHistory).order_by(IPReplacementHistory.id.desc()).limit(limit)
        )
        return [
            {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "group_id": row.group_id,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            for row in rows
        ]


def get_history_by_ip_or_group(ip: Optional[str], group_id: Optional[str], limit: int = 200) -> List[Dict[str, object]]:
    with db_session() as session:
        stmt = select(IPReplacementHistory).order_by(IPReplacementHistory.id.desc()).limit(limit)
        if group_id:
            stmt = stmt.where(IPReplacementHistory.group_id == group_id)
        elif ip:
            stmt = stmt.where(
                (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
            )
        rows: Iterable[IPReplacementHistory] = session.scalars(stmt)
        return [
            {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "group_id": row.group_id,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            for row in rows
        ]


def get_history_chains(ip: Optional[str] = None, group_id: Optional[str] = None, limit: int = 500) -> List[Dict[str, object]]:
    """Return chain information aggregated by group_id (chains are built in ascending created_at order)."""
    with db_session() as session:
        stmt = select(IPReplacementHistory).order_by(IPReplacementHistory.created_at.asc())
        if group_id:
            stmt = stmt.where(IPReplacementHistory.group_id == group_id)
        elif ip:
            stmt = stmt.where(
                (IPReplacementHistory.old_ip == ip) | (IPReplacementHistory.new_ip == ip)
            )
        stmt = stmt.limit(limit)
        rows: Iterable[IPReplacementHistory] = session.scalars(stmt)

        groups: Dict[str, Dict[str, object]] = {}
        for row in rows:
            gid = row.group_id or row.old_ip
            if gid not in groups:
                groups[gid] = {"group_id": gid, "items": [], "chain": [], "first_ip_start": None}
            entry = {
                "old_ip": row.old_ip,
                "new_ip": row.new_ip,
                "account_name": row.account_name,
                "terminated_network_out_mb": row.terminated_network_out_mb,
                "created_at": row.created_at.isoformat(),
            }
            groups[gid]["items"].append(entry)

        # Build the replacement chain for each group
        for gid, data in groups.items():
            items = data["items"]
            items.sort(key=lambda x: x["created_at"])
            chain: List[str] = []
            for it in items:
                if not chain:
                    chain.append(it["old_ip"])
                if chain[-1] != it["old_ip"] and it["old_ip"] not in chain:
                    chain.append(it["old_ip"])
                if chain[-1] != it["new_ip"]:
                    chain.append(it["new_ip"])
            data["chain"] = chain
            # Look up the creation time of the first IP in the chain (server_specs.created_at)
            if chain:
                first_ip = chain[0]
                spec_time = session.scalar(
                    select(ServerSpec.created_at).where(ServerSpec.ip_address == first_ip)
                )
                if spec_time:
                    data["first_ip_start"] = spec_time.isoformat()

        # Return the groups ordered by their earliest record time
        ordered = sorted(
            groups.values(),
            key=lambda g: g["items"][0]["created_at"] if g["items"] else "",
        )
        return ordered
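

# --- Minimal smoke test (illustrative sketch, not part of the module's API) ---
# Assumes DATABASE_URL points at a reachable MySQL instance; the IP addresses and
# account name below are placeholder values, not real infrastructure.
if __name__ == "__main__":
    init_db()
    update_ip_account_mapping(old_ip="198.51.100.10", new_ip="198.51.100.11", account_name="demo-account")
    add_replacement_history(
        old_ip="198.51.100.10",
        new_ip="198.51.100.11",
        account_name="demo-account",
        group_id=None,
        terminated_network_out_mb=12.5,
    )
    for chain in get_history_chains(ip="198.51.100.11"):
        print(chain["group_id"], "->", " -> ".join(chain["chain"]))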