#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Python port of the OpenClaw Auth Runtime.

Based on the TypeScript implementation in ~/clawd/skills/_shared/auth-runtime.

Features:
- Obtain an access token using CLIENT_KEY
- Cache the token on disk (configurable minimum TTL)
- Refresh tokens that are expired or about to expire
- Retry once after a 401/403 that looks like a session error
"""

import contextlib
import hashlib
import json
import os
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

try:
    import requests  # type: ignore
except ImportError:  # pragma: no cover - only hit when requests is absent
    # Keep the cache/.env/dry-run helpers importable without the HTTP client;
    # the functions that actually talk to the network fail with a clear error.
    requests = None  # type: ignore

# Defaults shared by EnvConfig and create_env_config so they cannot drift.
_DEFAULT_AUTH_BASE = "https://api-gw-test.yuanwei-lnc.com"
_DEFAULT_CACHE_DIR = "/tmp/skill-auth-cache"
_DEFAULT_MIN_TTL_SEC = 60
_DEFAULT_EXPIRES_IN = 900

# HTTP verbs accepted by request_api, split by whether the JSON body is sent.
_METHODS_WITHOUT_BODY = frozenset({"GET", "DELETE"})
_METHODS_WITH_BODY = frozenset({"POST", "PUT", "PATCH"})


@dataclass
class EnvConfig:
    """Runtime configuration (normally built by create_env_config)."""

    auth_base: str = _DEFAULT_AUTH_BASE
    client_key: str = ""
    auth_cache_dir: str = _DEFAULT_CACHE_DIR
    auth_min_ttl_sec: int = _DEFAULT_MIN_TTL_SEC


@dataclass
class SessionResponse:
    """Session data returned by the auth endpoint."""

    access_token: str
    hook_url: Optional[str] = None
    hook_token: Optional[str] = None
    expires_in: int = _DEFAULT_EXPIRES_IN


@dataclass
class CachedTokenData:
    """Token data as stored in the cache file."""

    access_token: str
    expires_at: float


@dataclass
class ApiResponse:
    """Simplified HTTP response; status 0 means the request never completed."""

    status: int
    body: str
    headers: dict[str, str]


# Status codes that may indicate an expired/invalid session.
RETRYABLE_STATUS = {401, 403}

# Lower-case substrings of a response body that mark a session error.
RETRYABLE_BODY_MARKERS = [
    'session not found or expired',
    'invalid or expired token',
    'unauthorized',
    'client key expired',
    'client key revoked',
]


def _require_requests() -> Any:
    """Return the requests module, or raise ImportError if it is missing."""
    if requests is None:
        raise ImportError(
            "The 'requests' package is required to perform HTTP calls"
        )
    return requests


def _discard_cache_file(cache_file: Path) -> None:
    """Best-effort removal of a stale or corrupted cache file."""
    with contextlib.suppress(OSError):
        cache_file.unlink(missing_ok=True)


def create_env_config() -> EnvConfig:
    """
    Build the configuration from environment variables.

    Environment variables:
    - AUTH_BASE: auth base URL (default: https://api-gw-test.yuanwei-lnc.com);
      trailing slashes are removed
    - CLIENT_KEY: client key (defaults to "" here; the token functions
      reject an empty key)
    - AUTH_CACHE_DIR: cache directory (default: /tmp/skill-auth-cache)
    - AUTH_MIN_TTL_SEC: minimum remaining token TTL in seconds (default: 60)

    Raises:
        ValueError: if AUTH_MIN_TTL_SEC is not an integer.
    """
    raw_ttl = os.getenv("AUTH_MIN_TTL_SEC", str(_DEFAULT_MIN_TTL_SEC))
    try:
        auth_min_ttl_sec = int(raw_ttl)
    except ValueError as err:
        # Same exception type as before, but name the offending variable.
        raise ValueError(
            f"AUTH_MIN_TTL_SEC must be an integer, got {raw_ttl!r}"
        ) from err

    return EnvConfig(
        auth_base=os.getenv("AUTH_BASE", _DEFAULT_AUTH_BASE).rstrip("/"),
        client_key=os.getenv("CLIENT_KEY", ""),
        auth_cache_dir=os.getenv("AUTH_CACHE_DIR", _DEFAULT_CACHE_DIR),
        auth_min_ttl_sec=auth_min_ttl_sec,
    )


def get_cache_file(auth_base: str, client_key: str, cache_dir: str) -> Path:
    """
    Return the cache file path for an (auth_base, client_key) pair.

    The directory is created if needed. The file name embeds the first 16 hex
    digits of the SHA-256 of "auth_base:client_key", so different endpoints
    or keys get different files and the key itself is not part of the name.
    """
    cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)

    key_str = f"{auth_base}:{client_key}"
    hash_value = hashlib.sha256(key_str.encode()).hexdigest()[:16]
    return cache_path / f"token_{hash_value}.json"


def read_cached_token(cache_file: Path, min_ttl_sec: int) -> Optional[str]:
    """
    Return the cached token if it is still valid for more than min_ttl_sec.

    A missing file yields None. An unreadable/corrupted file, or a token that
    is expired or about to expire, is removed (best effort) and yields None.
    """
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        cached = CachedTokenData(
            access_token=data["access_token"],
            expires_at=float(data["expires_at"]),
        )
    except FileNotFoundError:
        return None
    except (OSError, ValueError, KeyError, TypeError):
        # ValueError covers json.JSONDecodeError and bad numbers; TypeError
        # covers a payload that is not a JSON object.
        _discard_cache_file(cache_file)
        return None

    # Require at least min_ttl_sec of remaining lifetime.
    if time.time() + min_ttl_sec < cached.expires_at:
        return cached.access_token

    _discard_cache_file(cache_file)
    return None


def write_cache(cache_file: Path, session: SessionResponse) -> None:
    """
    Persist the token and its absolute expiry time to cache_file.

    The payload is written to a temporary file in the same directory and then
    moved into place, so readers never see a partial file. mkstemp creates the
    file readable and writable only by the current user, which matters because
    the file holds a bearer token and the default directory is under /tmp.
    """
    cache_file.parent.mkdir(parents=True, exist_ok=True)

    data = {
        "access_token": session.access_token,
        "expires_at": time.time() + session.expires_in,
        "expires_in": session.expires_in,
    }

    fd, tmp_name = tempfile.mkstemp(
        dir=str(cache_file.parent),
        prefix=f"{cache_file.name}.",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_name, cache_file)
    except BaseException:
        # Do not leave temp files behind when writing or renaming fails.
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise


def delete_cache(cache_file: Path) -> None:
    """Delete the cache file; a missing file is not an error."""
    cache_file.unlink(missing_ok=True)


def fetch_session_json(dry_run: bool, config: EnvConfig) -> SessionResponse:
    """
    Request a new session from the auth endpoint using CLIENT_KEY.

    In dry-run mode no request is made and an empty session is returned.

    Raises:
        ValueError: if config.client_key is empty.
        ImportError: if the requests package is not installed.
        RuntimeError: on a non-2xx response or a response without accessToken.
    """
    if dry_run:
        return SessionResponse(
            access_token="",
            hook_url="",
            hook_token="",
            expires_in=_DEFAULT_EXPIRES_IN,
        )

    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    http = _require_requests()
    payload = {"clientKey": config.client_key}
    url = f"{config.auth_base}/auth/skill-credit/session"

    response = http.post(url, json=payload, timeout=30)

    if not 200 <= response.status_code < 300:
        raise RuntimeError(
            f"Auth session request failed: HTTP {response.status_code} - {response.text}"
        )

    session_data = response.json()

    # A non-object body (list, string, null) cannot carry an accessToken.
    if not isinstance(session_data, dict) or not session_data.get("accessToken"):
        raise RuntimeError(f"Missing accessToken in session response: {response.text}")

    # Fall back to the default when expiresIn is absent, null or non-numeric;
    # otherwise write_cache would fail when computing the expiry time.
    raw_expires_in = session_data.get("expiresIn")
    try:
        expires_in = (
            _DEFAULT_EXPIRES_IN if raw_expires_in is None else int(raw_expires_in)
        )
    except (TypeError, ValueError):
        expires_in = _DEFAULT_EXPIRES_IN

    return SessionResponse(
        access_token=session_data["accessToken"],
        hook_url=session_data.get("hookUrl"),
        hook_token=session_data.get("hookToken"),
        expires_in=expires_in,
    )


def get_access_token(dry_run: bool, config: EnvConfig) -> str:
    """
    Return an access token, using the on-disk cache when possible.

    1. Look up the cache file for this auth_base/client_key.
    2. If it holds a token with enough remaining TTL, return it.
    3. Otherwise request a new session, cache it and return its token.

    In dry-run mode an empty string is returned without touching the cache.

    Raises:
        ValueError: if config.client_key is empty.
    """
    if dry_run:
        return ""

    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    cache_file = get_cache_file(config.auth_base, config.client_key, config.auth_cache_dir)
    cached_token = read_cached_token(cache_file, config.auth_min_ttl_sec)

    if cached_token:
        print("✓ 使用缓存的访问令牌")
        return cached_token

    print("🔑 请求新的访问令牌...")
    session = fetch_session_json(dry_run, config)
    write_cache(cache_file, session)
    print(f"✓ 令牌已缓存到:{cache_file}")

    return session.access_token


def refresh_access_token(dry_run: bool, config: EnvConfig) -> str:
    """
    Force a token refresh by deleting the cache and requesting a new token.

    In dry-run mode an empty string is returned.

    Raises:
        ValueError: if config.client_key is empty.
    """
    if dry_run:
        return ""

    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    cache_file = get_cache_file(config.auth_base, config.client_key, config.auth_cache_dir)
    delete_cache(cache_file)

    print("🔄 刷新访问令牌...")
    return get_access_token(dry_run, config)


def is_retryable_session_error(response: ApiResponse) -> bool:
    """
    Tell whether a response indicates an expired or invalid session.

    Only 401/403 responses qualify. Among those, the response is retryable
    when the body is empty or contains (case-insensitively) one of
    RETRYABLE_BODY_MARKERS; a 401/403 with an unrelated body is not retried.
    """
    if response.status not in RETRYABLE_STATUS:
        return False

    body = (response.body or "").lower()
    if not body:
        return True

    return any(marker in body for marker in RETRYABLE_BODY_MARKERS)


def request_api(
    method: str,
    url: str,
    auth_token: Optional[str] = None,
    body: Optional[dict[str, Any]] = None,
    timeout: int = 30,
) -> ApiResponse:
    """
    Send an HTTP request and return a simplified response.

    Args:
        method: HTTP method, case-insensitive (GET, POST, PUT, PATCH, DELETE).
        url: Request URL.
        auth_token: Access token sent as a Bearer token (optional).
        body: Request body, serialized as JSON for POST/PUT/PATCH; ignored
            for GET/DELETE.
        timeout: Timeout in seconds.

    Returns:
        ApiResponse: the response; transport-level failures are reported as
        status 0 with the error text in body instead of being raised.

    Raises:
        ValueError: if the method is not supported.
        ImportError: if the requests package is not installed.
    """
    verb = method.upper()
    if verb not in _METHODS_WITHOUT_BODY and verb not in _METHODS_WITH_BODY:
        raise ValueError(f"Unsupported HTTP method: {method}")

    http = _require_requests()

    headers = {"Content-Type": "application/json"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"

    kwargs: dict[str, Any] = {"headers": headers, "timeout": timeout}
    if verb in _METHODS_WITH_BODY:
        kwargs["json"] = body

    # Dispatch to requests.get/post/put/patch/delete.
    send = getattr(http, verb.lower())

    try:
        response = send(url, **kwargs)
    except http.exceptions.RequestException as e:
        return ApiResponse(
            status=0,
            body=str(e),
            headers={},
        )

    return ApiResponse(
        status=response.status_code,
        body=response.text,
        headers=dict(response.headers),
    )


def request_api_with_auto_refresh(
    method: str,
    url: str,
    dry_run: bool,
    config: EnvConfig,
    body: Optional[dict[str, Any]] = None,
    access_token: Optional[str] = None,
) -> ApiResponse:
    """
    Send an API request, refreshing the token and retrying once if needed.

    1. Send the request with access_token, or with a token from
       get_access_token when none is given.
    2. If the response looks like a session error (see
       is_retryable_session_error), refresh the token and retry once.
    """
    token = access_token or get_access_token(dry_run, config)

    first = request_api(method, url, token, body)

    if not is_retryable_session_error(first):
        return first

    print("⚠️ 检测到会话过期,刷新令牌后重试...")
    fresh_token = refresh_access_token(dry_run, config)

    return request_api(method, url, fresh_token, body)


# Convenience helper: load CLIENT_KEY from the environment or a .env file.
def load_client_key_from_env(env_path: Optional[Path] = None) -> str:
    """
    Load CLIENT_KEY from the environment or a .env file.

    Priority:
    1. The CLIENT_KEY environment variable (if non-empty)
    2. CLIENT_KEY in the .env file
    3. Otherwise raise

    When env_path is None, ".env" is looked up in the current working
    directory and then in the parent of this module's directory. Lines may
    use an "export " prefix and values may be wrapped in matching single or
    double quotes; empty values are treated as missing.

    Raises:
        ValueError: if no non-empty CLIENT_KEY is found.
    """
    client_key = os.getenv("CLIENT_KEY")
    if client_key:
        return client_key

    if env_path is None:
        candidates = (
            Path.cwd() / ".env",
            Path(__file__).parent.parent / ".env",
        )
        env_path = next((p for p in candidates if p.exists()), None)

    if env_path and env_path.exists():
        with open(env_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue

                key, value = line.split("=", 1)
                key = key.strip()
                if key.startswith("export "):
                    key = key[len("export "):].strip()
                if key != "CLIENT_KEY":
                    continue

                value = value.strip()
                # Drop one pair of matching surrounding quotes, as .env
                # files commonly quote their values.
                if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                    value = value[1:-1]
                if value:
                    return value

    raise ValueError(
        "CLIENT_KEY not found. Please set CLIENT_KEY environment variable "
        "or add it to .env file"
    )