# MIT License
# Copyright (c) 2024
"""Multimodal induction: read a session directory, build prompts, call an LLM, emit a DSL."""
from __future__ import annotations

import argparse
import base64
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests  # type: ignore

try:
    # Prefer python-dotenv; fall back to manual .env parsing when it is missing.
    from dotenv import load_dotenv  # type: ignore
except Exception:
    load_dotenv = None

from .prompt_templates import SYSTEM_PROMPT, render_user_prompt
from .schema import DSLSpec, EventRecord, FramePaths, UISnapshot, UISelector


# --------- Pydantic v1/v2 compatibility helpers ---------

def _model_validate(cls, data: Any) -> Any:
    """Validate *data* into model class *cls* (pydantic v2 first, v1 fallback)."""
    if hasattr(cls, "model_validate"):
        return cls.model_validate(data)  # type: ignore[attr-defined]
    return cls.parse_obj(data)  # type: ignore[attr-defined]


def _model_dump(obj: Any, **kwargs: Any) -> Dict[str, Any]:
    """Dump model *obj* to a plain dict (pydantic v2 first, v1 fallback)."""
    if hasattr(obj, "model_dump"):
        return obj.model_dump(**kwargs)  # type: ignore[attr-defined]
    return obj.dict(**kwargs)  # type: ignore[attr-defined]


def _load_env_file() -> None:
    """Load the project-root .env via python-dotenv when available, else parse by hand."""
    env_path = Path(__file__).resolve().parent.parent / ".env"
    if load_dotenv:
        load_dotenv(env_path)
        return
    if not env_path.exists():
        return
    for line in env_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, val = line.split("=", 1)
        # setdefault: variables already in the real environment win over .env entries.
        os.environ.setdefault(key.strip(), val.strip())


def _coerce_assertions(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Convert non-string entries inside ``assertions`` to strings so validation passes."""
    assertions = spec_dict.get("assertions")
    if isinstance(assertions, list):
        new_items = []
        for item in assertions:
            if isinstance(item, str):
                new_items.append(item)
            else:
                try:
                    new_items.append(json.dumps(item, ensure_ascii=False))
                except Exception:
                    # Last resort: repr-style stringification for non-JSON-able items.
                    new_items.append(str(item))
        spec_dict["assertions"] = new_items
    return spec_dict


def _strip_code_fences(text: str) -> str:
    """Strip a surrounding ```json ... ``` or ``` ... ``` fence from *text*, if present."""
    stripped = text.strip()
    if stripped.startswith("```"):
        parts = stripped.split("```")
        if len(parts) >= 3:
            body = parts[1]
            # BUGFIX: the original used lstrip("json"), which strips any of the
            # characters j/s/o/n and could eat leading payload characters;
            # drop exactly the "json" language tag instead.
            if body.startswith("json"):
                body = body[len("json"):]
            return body.strip()
    return stripped


def _normalize_steps(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the ``steps`` field to the actions/keys the schema supports."""
    steps = spec_dict.get("steps")
    if not isinstance(steps, list):
        return spec_dict
    normalized = []
    for step in steps:
        if not isinstance(step, dict):
            continue
        # selector -> target
        if "target" not in step and "selector" in step:
            step["target"] = step["selector"]
            step.pop("selector", None)
        action = step.get("action")
        # value -> text normalization, for set_value/type compatibility
        if "value" in step and "text" not in step:
            step["text"] = step.get("value")
            step.pop("value", None)
        # Translate the custom wait_for_window action into a schema wait_for step.
        if action == "wait_for_window":
            title = step.pop("window_title_part", None)
            timeout = step.pop("timeout", None)
            step["action"] = "wait_for"
            step["target"] = step.get("target") or {}
            if title:
                step["target"].setdefault("Name", title)
            step["target"].setdefault("ControlType", "WindowControl")
            if timeout:
                secs = float(timeout) / 1000.0  # incoming timeout is in milliseconds
                step["waits"] = {"appear": secs, "disappear": 5.0}
        # Actions outside the allowed set degrade to assert_exists.
        if step.get("action") not in {"click", "type", "set_value", "assert_exists", "wait_for"}:
            step["action"] = "assert_exists"
        # Canonicalize the ControlType name.
        tgt = step.get("target", {})
        if isinstance(tgt, dict) and tgt.get("ControlType") == "Window":
            tgt["ControlType"] = "WindowControl"
        normalized.append(step)
    spec_dict["steps"] = normalized
    return spec_dict


# ---------------- LLM abstraction ----------------

class LLMClient:
    """Abstract LLM interface."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        raise NotImplementedError


class DummyLLM(LLMClient):
    """Offline, text-only generation driven by simple event heuristics."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        # Simple rules: mouse_click -> click, text_input -> type; if a window
        # title contains 记事本 (Notepad) and text was typed, append a save click.
        data = json.loads(user_prompt.split("事件摘要(JSON):")[-1])
        steps: List[Dict[str, Any]] = []
        params: Dict[str, Any] = {}
        assertions: List[str] = []
        saw_text = False
        saw_notepad = False
        for ev in data:
            ev_type = ev.get("event_type")
            selector = ev.get("uia_selector") or {}
            if ev_type == "mouse_click":
                steps.append({"action": "click", "target": selector})
            elif ev_type == "text_input":
                saw_text = True
                # First typed text becomes the reusable {{text}} parameter.
                params.setdefault("text", ev.get("text", ""))
                steps.append({"action": "type", "target": selector, "text": "{{text}}"})
            if ev.get("window_title") and "记事本" in ev.get("window_title", ""):
                saw_notepad = True
        if saw_notepad and saw_text:
            assertions.append("文本已输入记事本")
            steps.append({"action": "click", "target": {"Name": "保存", "ControlType": "Button"}})
        if not assertions:
            assertions.append("关键控件存在")
        spec = {
            "params": params,
            "steps": steps or [{"action": "assert_exists", "target": {"Name": "dummy"}}],
            "assertions": assertions,
            "retry_policy": {"max_attempts": 2, "interval": 1.0},
            "waits": {"appear": 5.0, "disappear": 5.0},
        }
        return json.dumps(spec, ensure_ascii=False)


class OpenAIVisionClient(LLMClient):
    """Multimodal client for OpenAI-compatible APIs, with custom base_url and model."""

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-5.1-high",
        base_url: str = "https://api.wgetai.com/v1",
        timeout: float = 120.0,
        retries: int = 1,
    ) -> None:
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.retries = max(0, retries)

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
        for img in images or []:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img['b64']}"}})
        payload = {
            "model": self.model,
            "messages": [
                # BUGFIX: use the system_prompt argument; the original hardcoded the
                # module-level SYSTEM_PROMPT and silently ignored the parameter.
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "temperature": 0.2,
        }
        url = f"{self.base_url}/chat/completions"
        last_err: Optional[Exception] = None
        for attempt in range(self.retries + 1):
            try:
                resp = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
                resp.raise_for_status()
                text = resp.json()["choices"][0]["message"]["content"]
                return text
            except Exception as exc:  # noqa: BLE001
                last_err = exc
                if attempt < self.retries:
                    continue
                raise
        # Defensive: unreachable in practice (loop either returns or re-raises).
        raise last_err or RuntimeError("LLM 调用失败")


# ---------------- Data loading and compression ----------------

def _load_events(session_dir: Path) -> List[EventRecord]:
    """Read events.jsonl from *session_dir* into EventRecord models, skipping blank lines."""
    events_path = session_dir / "events.jsonl"
    events: List[EventRecord] = []
    with events_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            events.append(_model_validate(EventRecord, json.loads(line)))
    return events


def _load_snapshot(path: Optional[str]) -> Optional[UISnapshot]:
    """Load a UI snapshot JSON file; return None for a missing path or file."""
    if not path:
        return None
    p = Path(path)
    if not p.exists():
        return None
    with p.open("r", encoding="utf-8") as f:
        data = json.load(f)
    return _model_validate(UISnapshot, data)


def _best_image(frame_paths: Optional[FramePaths]) -> Optional[str]:
    """Pick the most specific existing frame image: element crop > mouse crop > full frame."""
    if not frame_paths:
        return None
    for cand in [frame_paths.crop_element, frame_paths.crop_mouse, frame_paths.full]:
        if cand and Path(cand).exists():
            return cand
    return None


def _selector_summary(selector: Optional[UISelector]) -> Dict[str, Any]:
    """Summarize a UIA selector into its four key fields (empty dict when None)."""
    if not selector:
        return {}
    return {
        "AutomationId": selector.automation_id,
        "Name": selector.name,
        "ClassName": selector.class_name,
        "ControlType": selector.control_type,
    }


def _compress_tree(snapshot: Optional[UISnapshot], selector: Optional[UISelector]) -> List[Dict[str, Any]]:
    """Compress the UI tree: keep depth <= 2, plus deeper nodes matching the hit control's name/type."""
    if not snapshot:
        return []
    nodes = []
    for node in snapshot.tree:
        if node.depth <= 2:
            nodes.append(_model_dump(node, exclude_none=True))
        else:
            if selector and (node.name == selector.name or node.control_type == selector.control_type):
                nodes.append(_model_dump(node, exclude_none=True))
    return nodes
def _encode_image_b64(path: Optional[str]) -> Optional[str]: if not path: return None try: with open(path, "rb") as f: return base64.b64encode(f.read()).decode("ascii") except Exception: return None def _pack_events(events: List[EventRecord], multimodal: bool) -> List[Dict[str, Any]]: packed: List[Dict[str, Any]] = [] for ev in events: if ev.event_type not in {"mouse_click", "text_input", "window_change"}: continue img_path = _best_image(ev.frame_paths) snapshot = _load_snapshot(ev.ui_snapshot) selector = ev.uia tree = _compress_tree(snapshot, selector) item: Dict[str, Any] = { "event_type": ev.event_type, "ts": ev.ts, "video_time_offset_ms": ev.video_time_offset_ms, "text": ev.text, "window_title": ev.window.title if ev.window else None, "window_process": ev.window.process_name if ev.window else None, "uia_selector": _selector_summary(selector), "uia_tree": tree, "frame_path": img_path, } if multimodal and img_path: b64 = _encode_image_b64(img_path) if b64: item["image_base64"] = b64 packed.append(item) return packed # ---------------- 主入口 ---------------- def infer_session( session_dir: Path, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "gpt-5.1-high", timeout: float = 120.0, retries: int = 1, ) -> DSLSpec: """读取 session 目录,返回 DSLSpec""" events = _load_events(session_dir) multimodal = api_key is not None packed = _pack_events(events, multimodal=multimodal) user_prompt = render_user_prompt(packed) client: LLMClient images_payload = [{"b64": e["image_base64"]} for e in packed if "image_base64" in e] if multimodal else None raw: str if multimodal: client = OpenAIVisionClient( api_key=api_key, base_url=base_url or "https://api.wgetai.com/v1", model=model, timeout=timeout, retries=retries, ) try: raw = client.generate(SYSTEM_PROMPT, user_prompt, images=images_payload) except Exception as exc: # noqa: BLE001 print(f"[warn] 多模态归纳失败,降级为文本-only(原因: {exc})") client = DummyLLM() raw = client.generate(SYSTEM_PROMPT, user_prompt, images=None) 
else: client = DummyLLM() raw = client.generate(SYSTEM_PROMPT, user_prompt, images=None) if not raw or not raw.strip(): raise RuntimeError("LLM 返回为空,无法解析为 JSON") cleaned = _strip_code_fences(raw) try: spec_dict = json.loads(cleaned) except Exception as exc: preview = cleaned[:500] raise RuntimeError(f"LLM 返回非 JSON,可见前 500 字符: {preview}") from exc spec_dict = _coerce_assertions(spec_dict) spec_dict = _normalize_steps(spec_dict) return _model_validate(DSLSpec, spec_dict) def main() -> None: parser = argparse.ArgumentParser(description="从 session 目录归纳 DSL(支持多模态)") parser.add_argument("--session-dir", type=str, required=True, help="session 目录,包含 events.jsonl / manifest.json / frames / ui_snapshots") parser.add_argument("--out", type=str, default="dsl.json", help="输出 DSL JSON 路径") parser.add_argument("--api-key", type=str, help="LLM API Key,缺省读取环境变量 OPENAI_API_KEY") parser.add_argument("--base-url", type=str, default="https://api.wgetai.com/v1", help="LLM Base URL") parser.add_argument("--model", type=str, default="gpt-5.1-high", help="LLM 模型名") parser.add_argument("--timeout", type=float, default=120.0, help="LLM 请求超时时间(秒)") parser.add_argument("--retries", type=int, default=1, help="LLM 请求重试次数(额外重试次数)") args = parser.parse_args() _load_env_file() session_dir = Path(args.session_dir) api_key = args.api_key or os.environ.get("OPENAI_API_KEY") base_url = args.base_url or os.environ.get("OPENAI_BASE_URL") spec = infer_session( session_dir, api_key=api_key, base_url=base_url, model=args.model, timeout=args.timeout, retries=args.retries, ) out_path = Path(args.out) out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w", encoding="utf-8") as f: f.write(json.dumps(_model_dump(spec), ensure_ascii=False, indent=2)) print(f"DSL 写入: {out_path}") if __name__ == "__main__": main()