396 lines
11 KiB
Python
396 lines
11 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
Python 版本的 OpenClaw Auth Runtime
|
|||
|
|
基于 ~/clawd/skills/_shared/auth-runtime 的 TypeScript 实现
|
|||
|
|
|
|||
|
|
功能:
|
|||
|
|
- 使用 CLIENT_KEY 获取访问令牌
|
|||
|
|
- 支持令牌缓存(可配置 TTL)
|
|||
|
|
- 自动刷新过期令牌
|
|||
|
|
- 401/403 自动重试
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import hashlib
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any, Optional
|
|||
|
|
|
|||
|
|
import requests # type: ignore
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class EnvConfig:
    """Runtime configuration for the auth helpers.

    Normally produced by ``create_env_config()`` from environment variables;
    the class-level defaults apply to any field that is not supplied.
    """

    # Base URL of the auth gateway; request URLs are built as f"{auth_base}/...".
    auth_base: str = "https://api-gw-test.yuanwei-lnc.com"
    # Client key exchanged for an access token; empty means "not configured".
    client_key: str = ""
    # Directory that holds the per-(auth_base, client_key) token cache files.
    auth_cache_dir: str = "/tmp/skill-auth-cache"
    # A cached token is reused only while more than this many seconds remain.
    auth_min_ttl_sec: int = 60
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class SessionResponse:
    """Session data returned by the auth endpoint.

    Only ``access_token`` is mandatory. The hook fields are optional, and
    ``expires_in`` is the token lifetime in seconds (default 900).
    """

    access_token: str
    hook_url: Optional[str] = None
    hook_token: Optional[str] = None
    # Lifetime in seconds; write_cache() adds it to time.time() for expiry.
    expires_in: int = 900
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class CachedTokenData:
    """A token record as stored in a cache file."""

    # The bearer token itself.
    access_token: str
    # Absolute expiry as a Unix timestamp in seconds (time.time()-based).
    expires_at: float
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass()
class ApiResponse:
    """Minimal, transport-independent view of an HTTP response."""

    # HTTP status code; request_api() uses 0 for transport-level failures.
    status: int
    # Response body as text (or the exception text when status is 0).
    body: str
    # Response headers copied into a plain dict.
    headers: dict[str, str]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# HTTP status codes that may signal an expired or invalid session
# (401 Unauthorized, 403 Forbidden).
RETRYABLE_STATUS = {401, 403}

# Lower-case substrings which, when found in the lower-cased body of a
# 401/403 response, mark it as a session problem worth retrying.
RETRYABLE_BODY_MARKERS = [
    "session not found or expired",
    "invalid or expired token",
    "unauthorized",
    "client key expired",
    "client key revoked",
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _env_or_default(name: str, default: str) -> str:
    """Return env var ``name`` stripped of whitespace, or ``default`` if unset or blank."""
    value = os.getenv(name, "").strip()
    return value if value else default


def create_env_config() -> EnvConfig:
    """
    Build an EnvConfig from environment variables.

    Environment variables:
    - AUTH_BASE: auth base URL (default: https://api-gw-test.yuanwei-lnc.com);
      trailing slashes are removed.
    - CLIENT_KEY: client key (required by the token functions; may be empty
      here, in which case they raise ValueError later).
    - AUTH_CACHE_DIR: token cache directory (default: /tmp/skill-auth-cache).
    - AUTH_MIN_TTL_SEC: minimum remaining token TTL in seconds (default: 60).

    A variable that is unset *or* set to an empty/whitespace-only string falls
    back to its default, so e.g. ``AUTH_MIN_TTL_SEC=`` no longer crashes.

    Returns:
        A populated EnvConfig.

    Raises:
        ValueError: if AUTH_MIN_TTL_SEC is set to something that is not an integer.
    """
    auth_base = _env_or_default("AUTH_BASE", "https://api-gw-test.yuanwei-lnc.com").rstrip("/")
    # Stray whitespace/newlines around a secret are never intended and would
    # make the auth request fail, so strip them (the .env loader does the same).
    client_key = os.getenv("CLIENT_KEY", "").strip()
    auth_cache_dir = _env_or_default("AUTH_CACHE_DIR", "/tmp/skill-auth-cache")

    raw_ttl = _env_or_default("AUTH_MIN_TTL_SEC", "60")
    try:
        auth_min_ttl_sec = int(raw_ttl)
    except ValueError:
        raise ValueError(
            f"AUTH_MIN_TTL_SEC must be an integer number of seconds, got {raw_ttl!r}"
        ) from None

    return EnvConfig(
        auth_base=auth_base,
        client_key=client_key,
        auth_cache_dir=auth_cache_dir,
        auth_min_ttl_sec=auth_min_ttl_sec,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_cache_file(auth_base: str, client_key: str, cache_dir: str) -> Path:
    """
    Return the token-cache file path for an (auth_base, client_key) pair.

    ``cache_dir`` is created (including parents) if it does not exist. The
    file name embeds the first 16 hex characters of
    SHA-256("<auth_base>:<client_key>"), so the raw key never appears in the
    path and distinct endpoint/key pairs normally map to distinct files.
    """
    directory = Path(cache_dir)
    directory.mkdir(parents=True, exist_ok=True)

    digest = hashlib.sha256(f"{auth_base}:{client_key}".encode()).hexdigest()
    return directory / f"token_{digest[:16]}.json"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _discard_cache_file(cache_file: Path) -> None:
    """Best-effort removal of a stale or corrupt cache file; OS errors are ignored."""
    try:
        cache_file.unlink(missing_ok=True)
    except OSError:
        # Cleanup must never turn a cache miss into a crash (e.g. read-only dir).
        pass


def read_cached_token(cache_file: Path, min_ttl_sec: int) -> Optional[str]:
    """
    Return the cached access token if it stays valid for more than ``min_ttl_sec``.

    The cache file is expected to be a JSON object with a non-empty string
    ``access_token`` and a numeric ``expires_at`` (Unix timestamp in seconds,
    as written by write_cache). A token expiring within ``min_ttl_sec``
    seconds counts as expired.

    Expired, unreadable or malformed cache files are removed (best effort) so
    the next call starts from a clean state.

    Args:
        cache_file: path of the cache file (see get_cache_file).
        min_ttl_sec: minimum remaining lifetime, in seconds, required to reuse the token.

    Returns:
        The cached token, or None when there is no usable cached token.
    """
    # EAFP: avoids the exists()/open() race with concurrent writers/deleters.
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError):
        # Unreadable file, or corrupt content (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        _discard_cache_file(cache_file)
        return None

    token = data.get("access_token") if isinstance(data, dict) else None
    expires_at = data.get("expires_at") if isinstance(data, dict) else None

    has_valid_shape = (
        isinstance(token, str)
        and bool(token)
        # bool is an int subclass; a JSON true/false is not a timestamp.
        and isinstance(expires_at, (int, float))
        and not isinstance(expires_at, bool)
    )
    if not has_valid_shape:
        _discard_cache_file(cache_file)
        return None

    if time.time() + min_ttl_sec < expires_at:
        return token

    # Expired or about to expire: drop it so it is not re-read next time.
    _discard_cache_file(cache_file)
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def write_cache(cache_file: Path, session: "SessionResponse") -> None:
    """
    Persist a session's access token and its absolute expiry to ``cache_file``.

    The JSON written has ``access_token``, ``expires_at`` (time.time() +
    expires_in) and ``expires_in``.

    The file is written atomically: data goes to a temporary file in the same
    directory, which is then renamed over ``cache_file``. Concurrent readers
    therefore never see a half-written file. The temporary file is created by
    tempfile.mkstemp, which uses mode 0600, so the token is not readable by
    other users even in a shared directory such as /tmp.

    Args:
        cache_file: destination path; its parent directory is created if needed.
        session: object providing ``access_token`` and ``expires_in`` (seconds).

    Raises:
        OSError: if the directory or file cannot be created or replaced.
    """
    import tempfile

    cache_file.parent.mkdir(parents=True, exist_ok=True)

    data = {
        "access_token": session.access_token,
        "expires_at": time.time() + session.expires_in,
        "expires_in": session.expires_in,
    }

    fd, tmp_name = tempfile.mkstemp(
        dir=cache_file.parent, prefix=f".{cache_file.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        # os.replace is atomic on the same filesystem and overwrites on all platforms.
        os.replace(tmp_name, cache_file)
    except BaseException:
        # Don't leave orphaned temp files behind on any failure.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def delete_cache(cache_file: Path) -> None:
    """Remove the token cache file; a file that is already gone is not an error."""
    try:
        cache_file.unlink()
    except FileNotFoundError:
        pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fetch_session_json(dry_run: bool, config: EnvConfig) -> SessionResponse:
    """
    Request a new session (access token) from the auth endpoint.

    POSTs ``{"clientKey": config.client_key}`` to
    ``{config.auth_base}/auth/skill-credit/session`` with a 30 second timeout.

    Args:
        dry_run: if true, return a placeholder session without any network call.
        config: auth configuration (base URL and client key).

    Returns:
        SessionResponse built from the JSON fields accessToken, hookUrl,
        hookToken and expiresIn. A missing or non-integer expiresIn falls
        back to 900 seconds.

    Raises:
        ValueError: if CLIENT_KEY is missing.
        RuntimeError: on a non-2xx status, a body that is not a JSON object,
            or a missing accessToken.
        requests.exceptions.RequestException: network-level failures (propagated).
    """
    if dry_run:
        return SessionResponse(
            access_token="<dry-run-token>",
            hook_url="<dry-run-hook-url>",
            hook_token="<dry-run-hook-token>",
            expires_in=900,
        )

    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    payload = {"clientKey": config.client_key}
    url = f"{config.auth_base}/auth/skill-credit/session"

    response = requests.post(url, json=payload, timeout=30)

    if not 200 <= response.status_code < 300:
        raise RuntimeError(
            f"Auth session request failed: HTTP {response.status_code} - {response.text}"
        )

    try:
        session_data = response.json()
    except ValueError as err:
        # requests' JSONDecodeError subclasses ValueError; report it like the
        # other auth failures instead of leaking a bare decode error.
        raise RuntimeError(
            f"Auth session response is not valid JSON: {response.text}"
        ) from err

    # A JSON array/string/number would otherwise fail with AttributeError on .get().
    if not isinstance(session_data, dict) or not session_data.get("accessToken"):
        raise RuntimeError(f"Missing accessToken in session response: {response.text}")

    # expiresIn is later added to time.time() by write_cache, so it must be a
    # number; tolerate null/strings from the server instead of crashing there.
    raw_expires_in = session_data.get("expiresIn", 900)
    try:
        expires_in = int(raw_expires_in)
    except (TypeError, ValueError):
        expires_in = 900

    return SessionResponse(
        access_token=session_data["accessToken"],
        hook_url=session_data.get("hookUrl"),
        hook_token=session_data.get("hookToken"),
        expires_in=expires_in,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_access_token(dry_run: bool, config: EnvConfig) -> str:
    """
    Return an access token, reusing the on-disk cache when possible.

    1. Locate the cache file for (auth_base, client_key).
    2. If it holds a token with more than ``config.auth_min_ttl_sec`` seconds
       left, return that token.
    3. Otherwise request a new token and cache it. If the cache cannot be
       written (OSError), a warning is printed and the freshly obtained token
       is still returned, because caching is only an optimisation.

    Args:
        dry_run: if true, return a placeholder token without touching cache or network.
        config: auth configuration.

    Returns:
        The access token string.

    Raises:
        ValueError: if CLIENT_KEY is missing (outside dry-run).
        RuntimeError: propagated from fetch_session_json on auth failures.
    """
    if dry_run:
        return "<dry-run-token>"

    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    cache_file = get_cache_file(config.auth_base, config.client_key, config.auth_cache_dir)
    cached_token = read_cached_token(cache_file, config.auth_min_ttl_sec)

    if cached_token:
        print("✓ 使用缓存的访问令牌")
        return cached_token

    print("🔑 请求新的访问令牌...")
    session = fetch_session_json(dry_run, config)

    try:
        write_cache(cache_file, session)
    except OSError as err:
        # The token is valid; a read-only or foreign-owned cache dir must not
        # turn a successful auth into a failure.
        print(f"⚠️ 令牌缓存写入失败:{err}")
    else:
        print(f"✓ 令牌已缓存到:{cache_file}")

    return session.access_token
|
|||
|
|
|
|||
|
|
|
|||
|
|
def refresh_access_token(dry_run: bool, config: EnvConfig) -> str:
    """
    Obtain a new access token, bypassing any cached one.

    The cache file for this (auth_base, client_key) pair is deleted first, so
    the subsequent get_access_token() call normally has to request a new token.

    Args:
        dry_run: if true, return a placeholder token and do nothing else.
        config: auth configuration.

    Returns:
        The token returned by get_access_token().

    Raises:
        ValueError: if CLIENT_KEY is missing (outside dry-run).
    """
    if dry_run:
        return "<dry-run-token>"
    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    delete_cache(
        get_cache_file(config.auth_base, config.client_key, config.auth_cache_dir)
    )
    print("🔄 刷新访问令牌...")

    return get_access_token(dry_run, config)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_retryable_session_error(response: ApiResponse) -> bool:
    """
    Decide whether a response looks like an expired or invalid session.

    Both conditions must hold:
    - the status is in RETRYABLE_STATUS (401/403), and
    - the body is empty, or its lower-cased text contains one of
      RETRYABLE_BODY_MARKERS.

    A 401/403 whose body matches no marker, or any other status, is not retryable.
    """
    if response.status in RETRYABLE_STATUS:
        text = (response.body or "").lower()
        return not text or any(marker in text for marker in RETRYABLE_BODY_MARKERS)
    return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def request_api(
    method: str,
    url: str,
    auth_token: Optional[str] = None,
    body: Optional[dict[str, Any]] = None,
    timeout: int = 30,
) -> ApiResponse:
    """
    Send an HTTP request with a JSON content type and optional bearer auth.

    Args:
        method: HTTP method, case-insensitive. GET, HEAD and OPTIONS are sent
            without a body. POST, PUT, PATCH and DELETE send ``body`` as JSON
            when it is given.
        url: request URL.
        auth_token: if truthy, sent as ``Authorization: Bearer <token>``.
        body: JSON-serialisable payload (ignored for body-less methods).
        timeout: timeout in seconds passed to requests.

    Returns:
        ApiResponse with status code, body text and headers. Transport-level
        failures (requests.exceptions.RequestException) are not raised; they
        are reported as ``status=0`` with the exception text as ``body``.

    Raises:
        ValueError: if ``method`` is not one of the methods listed above.
    """
    verb = method.upper()
    if verb in ("GET", "HEAD", "OPTIONS"):
        payload = None
    elif verb in ("POST", "PUT", "PATCH", "DELETE"):
        # DELETE used to drop a caller-supplied body silently; now it is sent.
        payload = body
    else:
        raise ValueError(f"Unsupported HTTP method: {method}")

    headers = {"Content-Type": "application/json"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"

    try:
        response = requests.request(
            verb, url, headers=headers, json=payload, timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        return ApiResponse(
            status=0,
            body=str(e),
            headers={},
        )

    return ApiResponse(
        status=response.status_code,
        body=response.text,
        headers=dict(response.headers),
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def request_api_with_auto_refresh(
    method: str,
    url: str,
    dry_run: bool,
    config: EnvConfig,
    body: Optional[dict[str, Any]] = None,
    access_token: Optional[str] = None,
    timeout: int = 30,
) -> ApiResponse:
    """
    Send an API request; on a session error, refresh the token and retry once.

    1. Send the request with ``access_token`` if given, otherwise with a token
       from get_access_token() (which may come from the cache).
    2. If is_retryable_session_error() flags the response (401/403 with an
       empty body or a known marker), force a token refresh and send the
       request exactly one more time. The second response is returned as-is,
       even if it fails again.

    Args:
        method: HTTP method (see request_api).
        url: request URL.
        dry_run: forwarded to the token helpers.
        config: auth configuration.
        body: optional JSON payload.
        access_token: token to try first instead of looking one up.
        timeout: per-attempt request timeout in seconds (default 30, as in request_api).

    Returns:
        The ApiResponse of the first attempt, or of the retry if one was made.
    """
    token = access_token or get_access_token(dry_run, config)

    first = request_api(method, url, token, body, timeout)
    if not is_retryable_session_error(first):
        return first

    print("⚠️ 检测到会话过期,刷新令牌后重试...")
    fresh_token = refresh_access_token(dry_run, config)
    return request_api(method, url, fresh_token, body, timeout)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Convenience helpers: load CLIENT_KEY from the environment or a .env file.
def _parse_env_value(raw: str) -> str:
    """Strip whitespace and one pair of matching surrounding quotes from a .env value."""
    value = raw.strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        value = value[1:-1]
    return value


def load_client_key_from_env(env_path: Optional[Path] = None) -> str:
    """
    Return CLIENT_KEY from the environment or, failing that, from a .env file.

    Lookup order:
    1. The CLIENT_KEY environment variable, if set and non-empty.
    2. The first non-empty CLIENT_KEY entry in a .env file. The file is
       ``env_path`` if given (str or Path); otherwise it is the first existing
       file among ``<cwd>/.env`` and ``.env`` in the parent of this module's
       directory.
    3. Otherwise raise ValueError.

    In the .env file, blank lines and lines starting with ``#`` are ignored.
    An optional ``export`` prefix is accepted, and values may be wrapped in
    single or double quotes (the quotes are removed).

    Raises:
        ValueError: if no non-empty CLIENT_KEY can be found.
    """
    client_key = os.getenv("CLIENT_KEY")
    if client_key:
        return client_key

    if env_path is None:
        # Path(".env") and Path.cwd() / ".env" are the same file; check it once.
        candidates = (Path.cwd() / ".env", Path(__file__).parent.parent / ".env")
        env_path = next((p for p in candidates if p.exists()), None)
    else:
        env_path = Path(env_path)

    if env_path is not None and env_path.exists():
        with open(env_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, _, raw_value = line.partition("=")
                key_parts = key.split()
                # Accept shell-style "export CLIENT_KEY=..." lines.
                if len(key_parts) == 2 and key_parts[0] == "export":
                    key_parts = key_parts[1:]
                if key_parts != ["CLIENT_KEY"]:
                    continue
                value = _parse_env_value(raw_value)
                if value:
                    return value

    raise ValueError(
        "CLIENT_KEY not found. Please set CLIENT_KEY environment variable "
        "or add it to .env file"
    )
|