excel-toolkit/scripts/auth_runtime.py

396 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Python 版本的 OpenClaw Auth Runtime
基于 ~/clawd/skills/_shared/auth-runtime 的 TypeScript 实现
功能:
- 使用 CLIENT_KEY 获取访问令牌
- 支持令牌缓存(可配置 TTL)
- 自动刷新过期令牌
- 401/403 自动重试
"""
import hashlib
import json
import os
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

import requests  # type: ignore
@dataclass
class EnvConfig:
    """Runtime configuration for the auth client.

    ``client_key`` is a credential, so it is excluded from ``repr()`` to keep it
    out of logs and tracebacks. It still participates in equality.
    """

    auth_base: str = "https://api-gw-test.yuanwei-lnc.com"
    client_key: str = field(default="", repr=False)
    auth_cache_dir: str = "/tmp/skill-auth-cache"
    # Minimum remaining token lifetime (seconds) required to reuse a cached token.
    auth_min_ttl_sec: int = 60
@dataclass
class SessionResponse:
    """Session data returned by the auth endpoint.

    The token fields are excluded from ``repr()`` so printing or logging a
    session does not leak credentials.
    """

    access_token: str = field(repr=False)
    hook_url: Optional[str] = None
    hook_token: Optional[str] = field(default=None, repr=False)
    # Token lifetime in seconds (added to time.time() when the token is cached).
    expires_in: int = 900
@dataclass
class CachedTokenData:
    """Token record stored in the on-disk cache."""

    # The bearer token itself; hidden from repr() so it does not leak into logs.
    access_token: str = field(repr=False)
    # Absolute expiry as a Unix timestamp in seconds (compared against time.time()).
    expires_at: float
@dataclass()
class ApiResponse:
    """Normalized outcome of an HTTP call made by ``request_api``."""

    status: int  # HTTP status code; 0 when no response was received
    body: str  # response text, or the exception message on transport failure
    headers: dict[str, str]  # response headers copied into a plain dict
# HTTP status codes that may signal a rejected or expired session; such a
# response can be retried once with a freshly fetched token.
RETRYABLE_STATUS = {401, 403}

# Lower-case substrings that, when found in a 401/403 response body, mark the
# failure as a session problem worth retrying.
RETRYABLE_BODY_MARKERS = [
    "session not found or expired",
    "invalid or expired token",
    "unauthorized",
    "client key expired",
    "client key revoked",
]
def _env_or_default(name: str, default: str) -> str:
    """Return the stripped value of env var ``name``, or ``default`` if unset or blank."""
    value = os.getenv(name, "").strip()
    return value or default


def create_env_config() -> EnvConfig:
    """Build an ``EnvConfig`` from environment variables.

    Environment variables:
        AUTH_BASE: auth base URL (default: https://api-gw-test.yuanwei-lnc.com);
            trailing slashes are removed.
        CLIENT_KEY: client key (required by the token functions; may be empty here).
        AUTH_CACHE_DIR: token cache directory (default: /tmp/skill-auth-cache).
        AUTH_MIN_TTL_SEC: minimum remaining token lifetime in seconds (default: 60).

    AUTH_BASE, AUTH_CACHE_DIR and AUTH_MIN_TTL_SEC are treated as unset when they
    are present but blank. Otherwise an empty ``AUTH_BASE=`` would yield an empty
    URL, and an empty ``AUTH_MIN_TTL_SEC=`` would crash in ``int("")``.

    Raises:
        ValueError: if AUTH_MIN_TTL_SEC is not a non-negative integer.
    """
    auth_base = _env_or_default("AUTH_BASE", "https://api-gw-test.yuanwei-lnc.com").rstrip("/")
    client_key = os.getenv("CLIENT_KEY", "")
    auth_cache_dir = _env_or_default("AUTH_CACHE_DIR", "/tmp/skill-auth-cache")

    raw_ttl = _env_or_default("AUTH_MIN_TTL_SEC", "60")
    try:
        auth_min_ttl_sec = int(raw_ttl)
    except ValueError:
        raise ValueError(
            f"AUTH_MIN_TTL_SEC must be a non-negative integer (seconds), got {raw_ttl!r}"
        ) from None
    if auth_min_ttl_sec < 0:
        # A negative TTL would let already-expired cached tokens be reused.
        raise ValueError(
            f"AUTH_MIN_TTL_SEC must be a non-negative integer (seconds), got {raw_ttl!r}"
        )

    return EnvConfig(
        auth_base=auth_base,
        client_key=client_key,
        auth_cache_dir=auth_cache_dir,
        auth_min_ttl_sec=auth_min_ttl_sec,
    )
def get_cache_file(auth_base: str, client_key: str, cache_dir: str) -> Path:
    """Return the token cache file path for an (auth_base, client_key) pair.

    The file name embeds the first 16 hex digits of SHA-256 of
    ``"{auth_base}:{client_key}"``. Each endpoint/key combination therefore gets
    its own file, and the raw key never appears in the file name.

    The cache directory is created if missing. ``~`` in ``cache_dir`` is expanded.
    A newly created leaf directory gets mode 0o700 because it holds bearer tokens.
    """
    cache_path = Path(cache_dir).expanduser()
    # ``mode`` only applies to the leaf when it is newly created. Existing
    # directories and intermediate parents keep their permissions (like mkdir -p).
    cache_path.mkdir(mode=0o700, parents=True, exist_ok=True)

    digest = hashlib.sha256(f"{auth_base}:{client_key}".encode("utf-8")).hexdigest()[:16]
    return cache_path / f"token_{digest}.json"
def read_cached_token(cache_file: Path, min_ttl_sec: int) -> Optional[str]:
    """Return the cached access token if it is still fresh enough, else None.

    A token is reused only when it stays valid for more than ``min_ttl_sec``
    seconds from now. Expired or near-expiry entries, and unreadable or malformed
    cache files, are deleted (best effort) so the caller fetches a new token.

    Args:
        cache_file: Cache file path (as returned by ``get_cache_file``).
        min_ttl_sec: Minimum remaining lifetime, in seconds, needed to reuse the token.

    Returns:
        The cached token string, or None if there is no usable cache entry.
    """
    try:
        with cache_file.open("r", encoding="utf-8") as f:
            data = json.load(f)
        access_token = data["access_token"]
        expires_at = data["expires_at"]
        still_fresh = time.time() + min_ttl_sec < expires_at
    except FileNotFoundError:
        # No cache yet (EAFP instead of exists()+open, which races).
        return None
    except (OSError, ValueError, KeyError, TypeError):
        # Unreadable or corrupt cache: bad JSON (ValueError), missing keys,
        # non-object JSON or a non-numeric expiry (TypeError). Drop it and refetch.
        _discard_cache_file(cache_file)
        return None

    if not isinstance(access_token, str) or not access_token:
        # A non-string or empty token cannot be used as a bearer token.
        _discard_cache_file(cache_file)
        return None
    if still_fresh:
        return access_token

    # Expired or about to expire: remove it so later reads do not consider it.
    _discard_cache_file(cache_file)
    return None


def _discard_cache_file(cache_file: Path) -> None:
    """Delete ``cache_file``, ignoring errors, since cache cleanup is best effort."""
    try:
        cache_file.unlink(missing_ok=True)
    except OSError:
        pass
def write_cache(cache_file: Path, session: SessionResponse) -> None:
    """Persist ``session``'s access token and its absolute expiry to ``cache_file``.

    The JSON record holds ``access_token``, ``expires_at`` (Unix time, seconds)
    and ``expires_in``. It is first written to a private temporary file in the
    same directory; ``tempfile.mkstemp`` creates it readable/writable only by the
    current user. The file is then moved into place with ``os.replace``, so the
    cached token is not world-readable. On POSIX the replace is atomic, so a
    concurrent reader never sees a half-written file.

    Raises:
        OSError: if the directory or the file cannot be written.
    """
    cache_dir = cache_file.parent
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Build the record before creating the temp file so a bad ``expires_in``
    # cannot leave a stray file behind.
    record = {
        "access_token": session.access_token,
        "expires_at": time.time() + session.expires_in,
        "expires_in": session.expires_in,
    }

    fd, tmp_name = tempfile.mkstemp(dir=cache_dir, prefix=f".{cache_file.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(record, f, indent=2)
        os.replace(tmp_name, cache_file)
    except BaseException:
        # Clean up the temporary file on any failure, then propagate.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
def delete_cache(cache_file: Path) -> None:
    """Remove the token cache file; an already-missing file is not an error."""
    target = cache_file
    target.unlink(missing_ok=True)
def fetch_session_json(dry_run: bool, config: EnvConfig) -> SessionResponse:
    """Request a new session (access token) from the auth endpoint.

    POSTs ``{"clientKey": config.client_key}`` to
    ``{config.auth_base}/auth/skill-credit/session`` with a 30 s timeout. In
    dry-run mode no request is made and placeholder values are returned.

    Raises:
        ValueError: if ``config.client_key`` is empty.
        RuntimeError: on a non-2xx status, a body that is not a JSON object, or a
            response without ``accessToken``.
        requests.exceptions.RequestException: on transport errors (not caught here).
    """
    if dry_run:
        return SessionResponse(
            access_token="<dry-run-token>",
            hook_url="<dry-run-hook-url>",
            hook_token="<dry-run-hook-token>",
            expires_in=900,
        )
    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    url = f"{config.auth_base}/auth/skill-credit/session"
    response = requests.post(url, json={"clientKey": config.client_key}, timeout=30)
    if not 200 <= response.status_code < 300:
        raise RuntimeError(
            f"Auth session request failed: HTTP {response.status_code} - {response.text}"
        )

    try:
        session_data = response.json()
    except ValueError as err:  # requests' JSON decode errors subclass ValueError
        raise RuntimeError(f"Auth session response is not valid JSON: {response.text}") from err
    if not isinstance(session_data, dict) or not session_data.get("accessToken"):
        raise RuntimeError(f"Missing accessToken in session response: {response.text}")

    return SessionResponse(
        access_token=session_data["accessToken"],
        hook_url=session_data.get("hookUrl"),
        hook_token=session_data.get("hookToken"),
        expires_in=_parse_expires_in(session_data.get("expiresIn")),
    )


def _parse_expires_in(value: Any, default: int = 900) -> int:
    """Coerce the server's ``expiresIn`` to whole non-negative seconds.

    Falls back to ``default`` when the value is missing, null, boolean or
    non-numeric. Without this, ``write_cache`` would fail with a TypeError when
    it adds the value to ``time.time()``.
    """
    if value is None or isinstance(value, bool):
        return default
    try:
        seconds = int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default
    # 0 simply means "do not reuse from cache"; negative values are clamped.
    return max(seconds, 0)
def get_access_token(dry_run: bool, config: EnvConfig) -> str:
    """Return an access token, reusing the on-disk cache when possible.

    1. Locate the cache file for (auth_base, client_key).
    2. If it holds a token still valid for more than ``config.auth_min_ttl_sec``
       seconds, return it.
    3. Otherwise fetch a new session and cache it.

    Writing the cache is best effort. If it fails with an OSError (for example a
    read-only or foreign-owned cache directory), the freshly fetched token is
    still returned and a warning is printed.

    Raises:
        ValueError: if ``config.client_key`` is empty (outside dry-run mode).
    """
    if dry_run:
        return "<dry-run-token>"
    if not config.client_key:
        raise ValueError("CLIENT_KEY is required")

    cache_file = get_cache_file(config.auth_base, config.client_key, config.auth_cache_dir)
    cached_token = read_cached_token(cache_file, config.auth_min_ttl_sec)
    if cached_token:
        print("✓ 使用缓存的访问令牌")
        return cached_token

    print("🔑 请求新的访问令牌...")
    session = fetch_session_json(dry_run, config)
    try:
        write_cache(cache_file, session)
    except OSError as err:
        # Caching is an optimization; do not throw away a valid token over it.
        print(f"⚠️ 令牌缓存写入失败:{err}")
    else:
        print(f"✓ 令牌已缓存到:{cache_file}")
    return session.access_token
def refresh_access_token(dry_run: bool, config: EnvConfig) -> str:
    """Force a new access token by discarding any cached one first.

    In dry-run mode a placeholder is returned. Otherwise the cache file is
    deleted and ``get_access_token`` fetches (and re-caches) a new token.

    Raises:
        ValueError: if ``config.client_key`` is empty (outside dry-run mode).
    """
    if dry_run:
        return "<dry-run-token>"
    key = config.client_key
    if not key:
        raise ValueError("CLIENT_KEY is required")
    delete_cache(get_cache_file(config.auth_base, key, config.auth_cache_dir))
    print("🔄 刷新访问令牌...")
    return get_access_token(dry_run, config)
def is_retryable_session_error(response: ApiResponse) -> bool:
    """Tell whether ``response`` looks like an expired or invalid session.

    Only statuses in RETRYABLE_STATUS qualify. For those, an empty body counts
    as retryable; otherwise the lower-cased body must contain one of
    RETRYABLE_BODY_MARKERS.
    """
    if response.status in RETRYABLE_STATUS:
        text = (response.body or "").lower()
        return not text or any(marker in text for marker in RETRYABLE_BODY_MARKERS)
    return False
def request_api(
    method: str,
    url: str,
    auth_token: Optional[str] = None,
    body: Optional[dict[str, Any]] = None,
    timeout: int = 30,
) -> ApiResponse:
    """Send an HTTP request and wrap the outcome in an ``ApiResponse``.

    Args:
        method: HTTP verb, case-insensitive; one of GET, POST, PUT, DELETE.
        url: Target URL.
        auth_token: Bearer token for the ``Authorization`` header, if any.
        body: JSON payload; only sent for POST and PUT.
        timeout: Request timeout in seconds.

    Returns:
        The response status, text and headers. A
        ``requests.exceptions.RequestException`` is reported as status 0 with the
        exception message as the body instead of being raised.

    Raises:
        ValueError: for an unsupported HTTP method.
    """
    headers = {"Content-Type": "application/json"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"

    senders = {
        "GET": requests.get,
        "POST": requests.post,
        "PUT": requests.put,
        "DELETE": requests.delete,
    }
    try:
        verb = method.upper()
        send = senders.get(verb)
        if send is None:
            raise ValueError(f"Unsupported HTTP method: {method}")
        if verb in ("POST", "PUT"):
            reply = send(url, headers=headers, json=body, timeout=timeout)
        else:
            reply = send(url, headers=headers, timeout=timeout)
        return ApiResponse(
            status=reply.status_code,
            body=reply.text,
            headers=dict(reply.headers),
        )
    except requests.exceptions.RequestException as exc:
        return ApiResponse(status=0, body=str(exc), headers={})
def request_api_with_auto_refresh(
    method: str,
    url: str,
    dry_run: bool,
    config: EnvConfig,
    body: Optional[dict[str, Any]] = None,
    access_token: Optional[str] = None,
) -> ApiResponse:
    """Send an authenticated request, refreshing the token once on session errors.

    Uses ``access_token`` when given, otherwise ``get_access_token``. If the first
    response is a retryable session error (see ``is_retryable_session_error``),
    a fresh token is obtained via ``refresh_access_token``. The request is then
    sent exactly once more, and that second response is returned as-is.
    """
    first = request_api(method, url, access_token or get_access_token(dry_run, config), body)
    if is_retryable_session_error(first):
        print("⚠️ 检测到会话过期,刷新令牌后重试...")
        return request_api(method, url, refresh_access_token(dry_run, config), body)
    return first
# Convenience helper: load CLIENT_KEY from the environment or a .env file.
def load_client_key_from_env(env_path: Optional[Path] = None) -> str:
    """Return CLIENT_KEY from the environment or a ``.env`` file.

    Lookup order:
        1. The ``CLIENT_KEY`` environment variable, if non-empty.
        2. The first non-empty ``CLIENT_KEY`` entry in ``env_path`` (a Path or
           str). When ``env_path`` is None, the first existing file among
           ``<cwd>/.env`` and ``<directory above this script's directory>/.env``
           is used.
        3. Otherwise raise.

    ``.env`` parsing follows common dotenv conventions:
        - blank lines and ``#`` comments are skipped;
        - an optional ``export`` prefix is accepted;
        - whitespace around key and value is trimmed;
        - one pair of matching surrounding quotes is removed;
        - a UTF-8 BOM is ignored.

    Raises:
        ValueError: if no non-empty CLIENT_KEY can be found.
    """
    client_key = os.getenv("CLIENT_KEY")
    if client_key:
        return client_key

    if env_path is None:
        # resolve() so the script-relative candidate is correct even when
        # __file__ is a relative path.
        candidates = [
            Path.cwd() / ".env",
            Path(__file__).resolve().parent.parent / ".env",
        ]
        env_path = next((p for p in candidates if p.exists()), None)
    else:
        env_path = Path(env_path)

    if env_path is not None and env_path.exists():
        # utf-8-sig transparently drops a leading BOM written by some editors.
        with open(env_path, "r", encoding="utf-8-sig") as f:
            for raw_line in f:
                value = _parse_client_key_line(raw_line)
                if value:
                    return value

    raise ValueError(
        "CLIENT_KEY not found. Please set CLIENT_KEY environment variable "
        "or add it to .env file"
    )


def _parse_client_key_line(line: str) -> Optional[str]:
    """Return the CLIENT_KEY value defined on one ``.env`` line, or None if it defines none."""
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
        return None
    key, value = line.split("=", 1)
    key = key.strip()
    if key.startswith("export "):
        key = key[len("export "):].strip()
    if key != "CLIENT_KEY":
        return None
    value = value.strip()
    # Strip one pair of matching surrounding quotes: CLIENT_KEY="abc" -> abc
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1]
    return value