audoWin/autodemo/infer.py
2025-12-19 16:24:04 +08:00

397 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# MIT License
# Copyright (c) 2024
"""多模态归纳:读取 session 目录,组装提示,调用 LLM生成 DSL"""
from __future__ import annotations
import argparse
import base64
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests # type: ignore
try:
# 优先使用 python-dotenv缺失则退回手动解析
from dotenv import load_dotenv # type: ignore
except Exception:
load_dotenv = None
from .prompt_templates import SYSTEM_PROMPT, render_user_prompt
from .schema import DSLSpec, EventRecord, FramePaths, UISnapshot, UISelector
# --------- Pydantic v1/v2 兼容辅助 ---------
def _model_validate(cls, data: Any) -> Any:
if hasattr(cls, "model_validate"):
return cls.model_validate(data) # type: ignore[attr-defined]
return cls.parse_obj(data) # type: ignore[attr-defined]
def _model_dump(obj: Any, **kwargs: Any) -> Dict[str, Any]:
if hasattr(obj, "model_dump"):
return obj.model_dump(**kwargs) # type: ignore[attr-defined]
return obj.dict(**kwargs) # type: ignore[attr-defined]
def _load_env_file() -> None:
    """Load the project-root ``.env``: prefer python-dotenv, else parse by hand.

    Hand parsing skips blanks, ``#`` comments, and lines without ``=``; values
    are applied with ``setdefault`` so real environment variables take priority.
    """
    env_path = Path(__file__).resolve().parent.parent / ".env"
    if load_dotenv:  # python-dotenv was importable
        load_dotenv(env_path)
        return
    if not env_path.exists():
        return
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        name, _, value = entry.partition("=")
        # setdefault: never clobber variables already present in the environment.
        os.environ.setdefault(name.strip(), value.strip())
def _coerce_assertions(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
"""将 assertions 内的非字符串条目转换为字符串,防止验证失败"""
assertions = spec_dict.get("assertions")
if isinstance(assertions, list):
new_items = []
for item in assertions:
if isinstance(item, str):
new_items.append(item)
else:
try:
new_items.append(json.dumps(item, ensure_ascii=False))
except Exception:
new_items.append(str(item))
spec_dict["assertions"] = new_items
return spec_dict
def _strip_code_fences(text: str) -> str:
"""去除 ```json ... ``` 或 ``` ... ``` 包裹"""
stripped = text.strip()
if stripped.startswith("```"):
parts = stripped.split("```")
if len(parts) >= 3:
return parts[1].lstrip("json").strip() if parts[1].startswith("json") else parts[1].strip()
return stripped
def _normalize_steps(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the ``steps`` field (in place) to the actions/fields the schema supports.

    Transformations applied to each dict step:
    - rename ``selector`` -> ``target`` and ``value`` -> ``text``
    - rewrite the custom ``wait_for_window`` action into ``wait_for``
    - downgrade unknown actions to ``assert_exists``
    - canonicalize ``ControlType`` "Window" -> "WindowControl"

    NOTE(review): step dicts are mutated in place; non-dict entries are dropped.
    """
    steps = spec_dict.get("steps")
    if not isinstance(steps, list):
        return spec_dict
    normalized: List[Dict[str, Any]] = []
    for step in steps:
        if not isinstance(step, dict):
            continue
        # Rename selector -> target (only when target is absent).
        if "target" not in step and "selector" in step:
            step["target"] = step["selector"]
            step.pop("selector", None)
        action = step.get("action")
        # Normalize value -> text, for compatibility with set_value/type.
        if "value" in step and "text" not in step:
            step["text"] = step.get("value")
            step.pop("value", None)
        # Handle the custom wait_for_window action.
        if action == "wait_for_window":
            title = step.pop("window_title_part", None)
            timeout = step.pop("timeout", None)
            step["action"] = "wait_for"
            step["target"] = step.get("target") or {}
            if title:
                step["target"].setdefault("Name", title)
            step["target"].setdefault("ControlType", "WindowControl")
            if timeout:
                # timeout arrives in milliseconds; waits are expressed in seconds.
                secs = float(timeout) / 1000.0
                step["waits"] = {"appear": secs, "disappear": 5.0}
        # If the action is not in the allowed set, downgrade to assert_exists.
        if step.get("action") not in {"click", "type", "set_value", "assert_exists", "wait_for"}:
            step["action"] = "assert_exists"
        # Canonicalize ControlType naming.
        tgt = step.get("target", {})
        if isinstance(tgt, dict) and tgt.get("ControlType") == "Window":
            tgt["ControlType"] = "WindowControl"
        normalized.append(step)
    spec_dict["steps"] = normalized
    return spec_dict
# ---------------- LLM 抽象 ----------------
class LLMClient:
    """Abstract interface for LLM backends."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """Return raw model output for the given prompts; subclasses must override."""
        raise NotImplementedError
class DummyLLM(LLMClient):
    """Offline, text-only fallback that induces a DSL from simple event heuristics."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """Map events to steps: mouse_click -> click, text_input -> type.

        If a Notepad window title and a text input were both observed, a
        save-button click and a matching assertion are appended.
        """
        events = json.loads(user_prompt.split("事件摘要(JSON)")[-1])
        steps: List[Dict[str, Any]] = []
        params: Dict[str, Any] = {}
        assertions: List[str] = []
        typed_any_text = False
        notepad_seen = False
        for event in events:
            selector = event.get("uia_selector") or {}
            kind = event.get("event_type")
            if kind == "mouse_click":
                steps.append({"action": "click", "target": selector})
            elif kind == "text_input":
                typed_any_text = True
                params.setdefault("text", event.get("text", ""))
                steps.append({"action": "type", "target": selector, "text": "{{text}}"})
            title = event.get("window_title")
            if title and "记事本" in title:
                notepad_seen = True
        if notepad_seen and typed_any_text:
            assertions.append("文本已输入记事本")
            steps.append({"action": "click", "target": {"Name": "保存", "ControlType": "Button"}})
        if not assertions:
            assertions.append("关键控件存在")
        spec = {
            "params": params,
            "steps": steps or [{"action": "assert_exists", "target": {"Name": "dummy"}}],
            "assertions": assertions,
            "retry_policy": {"max_attempts": 2, "interval": 1.0},
            "waits": {"appear": 5.0, "disappear": 5.0},
        }
        return json.dumps(spec, ensure_ascii=False)
class OpenAIVisionClient(LLMClient):
    """OpenAI-compatible multimodal client with configurable base_url and model.

    Sends a chat-completions request; images are attached as base64 PNG data
    URLs. ``retries`` is the number of extra attempts after the first call.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-5.1-high",
        base_url: str = "https://api.wgetai.com/v1",
        timeout: float = 120.0,
        retries: int = 1,
    ) -> None:
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip("/")  # avoid a double slash when joining the path
        self.timeout = timeout
        self.retries = max(0, retries)

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """POST to ``{base_url}/chat/completions`` and return the first choice's text.

        Raises the last network/HTTP/parsing error once retries are exhausted.
        """
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
        for img in images or []:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img['b64']}"}})
        payload = {
            "model": self.model,
            "messages": [
                # BUG FIX: previously hardcoded the module-level SYSTEM_PROMPT here,
                # silently ignoring the caller-supplied system_prompt argument.
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "temperature": 0.2,
        }
        url = f"{self.base_url}/chat/completions"
        last_err: Optional[Exception] = None
        for attempt in range(self.retries + 1):
            try:
                resp = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
                resp.raise_for_status()
                return resp.json()["choices"][0]["message"]["content"]
            except Exception as exc:  # noqa: BLE001
                last_err = exc
                if attempt < self.retries:
                    continue
                raise
        # Defensive: the loop always returns or re-raises before reaching here.
        raise last_err or RuntimeError("LLM 调用失败")
# ---------------- 数据加载与压缩 ----------------
def _load_events(session_dir: Path) -> List[EventRecord]:
    """Parse ``events.jsonl`` under *session_dir* into EventRecord models.

    Blank lines are skipped; each remaining line must be a JSON object.
    """
    records: List[EventRecord] = []
    with (session_dir / "events.jsonl").open("r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            if raw:
                records.append(_model_validate(EventRecord, json.loads(raw)))
    return records
def _load_snapshot(path: Optional[str]) -> Optional[UISnapshot]:
    """Read a UI-snapshot JSON file; return None for an empty or missing path."""
    if not path:
        return None
    snapshot_file = Path(path)
    if not snapshot_file.exists():
        return None
    with snapshot_file.open("r", encoding="utf-8") as handle:
        return _model_validate(UISnapshot, json.load(handle))
def _best_image(frame_paths: Optional[FramePaths]) -> Optional[str]:
if not frame_paths:
return None
for cand in [frame_paths.crop_element, frame_paths.crop_mouse, frame_paths.full]:
if cand and Path(cand).exists():
return cand
return None
def _selector_summary(selector: Optional[UISelector]) -> Dict[str, Any]:
if not selector:
return {}
return {
"AutomationId": selector.automation_id,
"Name": selector.name,
"ClassName": selector.class_name,
"ControlType": selector.control_type,
}
def _compress_tree(snapshot: Optional[UISnapshot], selector: Optional[UISelector]) -> List[Dict[str, Any]]:
    """Compress a UI tree: keep nodes at depth <= 2, plus deeper nodes sharing
    the hit control's name or control type."""
    if not snapshot:
        return []
    kept: List[Dict[str, Any]] = []
    for node in snapshot.tree:
        matches_selector = selector and (
            node.name == selector.name or node.control_type == selector.control_type
        )
        if node.depth <= 2 or matches_selector:
            kept.append(_model_dump(node, exclude_none=True))
    return kept
def _encode_image_b64(path: Optional[str]) -> Optional[str]:
if not path:
return None
try:
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("ascii")
except Exception:
return None
def _pack_events(events: List[EventRecord], multimodal: bool) -> List[Dict[str, Any]]:
    """Condense raw events into prompt-ready dicts.

    Only click/typing/window-change events are kept. Each entry carries the
    compressed UI tree and best frame path; when *multimodal* is True and an
    image exists, its base64 encoding is attached as ``image_base64``.
    """
    relevant = {"mouse_click", "text_input", "window_change"}
    packed: List[Dict[str, Any]] = []
    for event in events:
        if event.event_type not in relevant:
            continue
        image_path = _best_image(event.frame_paths)
        selector = event.uia
        window = event.window
        entry: Dict[str, Any] = {
            "event_type": event.event_type,
            "ts": event.ts,
            "video_time_offset_ms": event.video_time_offset_ms,
            "text": event.text,
            "window_title": window.title if window else None,
            "window_process": window.process_name if window else None,
            "uia_selector": _selector_summary(selector),
            "uia_tree": _compress_tree(_load_snapshot(event.ui_snapshot), selector),
            "frame_path": image_path,
        }
        if multimodal and image_path:
            encoded = _encode_image_b64(image_path)
            if encoded:
                entry["image_base64"] = encoded
        packed.append(entry)
    return packed
# ---------------- 主入口 ----------------
def infer_session(
    session_dir: Path,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    model: str = "gpt-5.1-high",
    timeout: float = 120.0,
    retries: int = 1,
) -> DSLSpec:
    """Read a session directory and induce a validated DSLSpec.

    With an api_key the multimodal OpenAI-compatible client is used, falling
    back to the offline DummyLLM on any error; without a key the DummyLLM is
    used directly. The raw reply is defenced, JSON-parsed, coerced, and
    normalized before schema validation.
    """
    events = _load_events(session_dir)
    multimodal = api_key is not None
    packed = _pack_events(events, multimodal=multimodal)
    user_prompt = render_user_prompt(packed)
    images_payload = None
    if multimodal:
        images_payload = [{"b64": item["image_base64"]} for item in packed if "image_base64" in item]
        vision = OpenAIVisionClient(
            api_key=api_key,
            base_url=base_url or "https://api.wgetai.com/v1",
            model=model,
            timeout=timeout,
            retries=retries,
        )
        try:
            raw = vision.generate(SYSTEM_PROMPT, user_prompt, images=images_payload)
        except Exception as exc:  # noqa: BLE001
            # Degrade gracefully to the offline heuristic generator.
            print(f"[warn] 多模态归纳失败,降级为文本-only原因: {exc}")
            raw = DummyLLM().generate(SYSTEM_PROMPT, user_prompt, images=None)
    else:
        raw = DummyLLM().generate(SYSTEM_PROMPT, user_prompt, images=None)
    if not raw or not raw.strip():
        raise RuntimeError("LLM 返回为空,无法解析为 JSON")
    cleaned = _strip_code_fences(raw)
    try:
        spec_dict = json.loads(cleaned)
    except Exception as exc:
        preview = cleaned[:500]
        raise RuntimeError(f"LLM 返回非 JSON可见前 500 字符: {preview}") from exc
    spec_dict = _coerce_assertions(spec_dict)
    spec_dict = _normalize_steps(spec_dict)
    return _model_validate(DSLSpec, spec_dict)
def main() -> None:
    """CLI entry point: induce a DSL from a session directory and write it as JSON."""
    parser = argparse.ArgumentParser(description="从 session 目录归纳 DSL支持多模态")
    parser.add_argument("--session-dir", type=str, required=True, help="session 目录,包含 events.jsonl / manifest.json / frames / ui_snapshots")
    parser.add_argument("--out", type=str, default="dsl.json", help="输出 DSL JSON 路径")
    parser.add_argument("--api-key", type=str, help="LLM API Key缺省读取环境变量 OPENAI_API_KEY")
    parser.add_argument("--base-url", type=str, default="https://api.wgetai.com/v1", help="LLM Base URL")
    parser.add_argument("--model", type=str, default="gpt-5.1-high", help="LLM 模型名")
    parser.add_argument("--timeout", type=float, default=120.0, help="LLM 请求超时时间(秒)")
    parser.add_argument("--retries", type=int, default=1, help="LLM 请求重试次数(额外重试次数)")
    args = parser.parse_args()

    _load_env_file()  # surface .env values before reading os.environ
    # CLI flags win; environment variables are the fallback.
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY")
    base_url = args.base_url or os.environ.get("OPENAI_BASE_URL")
    spec = infer_session(
        Path(args.session_dir),
        api_key=api_key,
        base_url=base_url,
        model=args.model,
        timeout=args.timeout,
        retries=args.retries,
    )
    destination = Path(args.out)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        handle.write(json.dumps(_model_dump(spec), ensure_ascii=False, indent=2))
    print(f"DSL 写入: {destination}")
# Script entry point: allows direct execution as well as module import.
if __name__ == "__main__":
    main()