kboss/b/cntoai/test_demo.py
2026-05-23 14:22:30 +08:00

248 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
chat_send_stream.dspy SSE streaming endpoint test.

Usage:
    pip install requests
    python test_demo.py

Optional environment variables (override the defaults defined below):
    CNTOAI_BASE_URL / CNTOAI_USERID / CNTOAI_API_KEY
    CNTOAI_MODEL / CNTOAI_LLM_API_URL / CNTOAI_MESSAGE / CNTOAI_TIMEOUT
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from typing import Any, Dict, Generator, List, Optional
try:
import requests
except ImportError:
print("请先安装: pip install requests")
sys.exit(1)
# Runtime configuration. Each value may be overridden with the matching
# CNTOAI_* environment variable; string values are stripped of surrounding
# whitespace (BASE_URL only has trailing slashes removed).
BASE_URL = os.environ.get("CNTOAI_BASE_URL", "https://dev.opencomputing.cn").rstrip("/")
USERID = os.environ.get("CNTOAI_USERID", "hSqZuekZ1yKmhKmCN9UAK").strip()
# SECURITY: the API key sent in the request payload must be supplied through
# CNTOAI_API_KEY. A literal key used to be committed here as the default; it
# should be treated as leaked and rotated, and must not be re-added to source.
API_KEY = os.environ.get("CNTOAI_API_KEY", "").strip()
API_URL = os.environ.get("CNTOAI_LLM_API_URL", "https://api.deepseek.com/v1/chat/completions").strip()
MODEL = os.environ.get("CNTOAI_MODEL", "deepseek-chat").strip()
MESSAGE = os.environ.get("CNTOAI_MESSAGE", "你好,请用三句话介绍你自己").strip()
# Passed to requests as the request timeout (seconds).
TIMEOUT = int(os.environ.get("CNTOAI_TIMEOUT", "300"))
# Path of the streaming endpoint, appended to BASE_URL.
STREAM_PATH = "/cntoai/chat_send_stream.dspy"
def build_payload(session_id: Optional[str] = None, message: Optional[str] = None) -> Dict[str, Any]:
    """Assemble the JSON body posted to the streaming endpoint.

    Args:
        session_id: Existing conversation id; only included when truthy.
        message: User message; a falsy value falls back to the module-level
            MESSAGE default.

    Returns:
        A dict with the model, message, userid, api_key and api_url fields,
        plus ``session_id`` when one was supplied.
    """
    body: Dict[str, Any] = dict(
        model=MODEL,
        message=message if message else MESSAGE,
        userid=USERID,
        api_key=API_KEY,
        api_url=API_URL,
    )
    if not session_id:
        return body
    body["session_id"] = session_id
    return body
def parse_sse_text(text: str) -> List[Dict[str, Any]]:
    """Parse a fully-buffered SSE body into a list of decoded JSON events.

    Only ``data:`` lines are considered. Parsing stops at the first
    ``[DONE]`` marker. Lines whose payload is not valid JSON are reported
    with a warning on stdout and skipped.

    Args:
        text: The complete response body as text.

    Returns:
        The JSON-decoded payloads, in the order they appear.
    """
    parsed: List[Dict[str, Any]] = []
    for entry in (raw.strip() for raw in text.splitlines()):
        if entry[:5] != "data:":
            continue
        body = entry[5:].strip()
        if body == "[DONE]":
            # End-of-stream marker: ignore anything that follows it.
            return parsed
        try:
            decoded = json.loads(body)
        except json.JSONDecodeError:
            print(f"[warn] 无法解析: {entry[:200]}")
        else:
            parsed.append(decoded)
    return parsed
def _decode_sse_data_line(raw: bytes) -> Optional[str]:
    """Decode one raw SSE line; return it stripped if it is a ``data:`` line, else None."""
    # errors="replace" keeps a single malformed byte from aborting the whole
    # stream; the affected character becomes U+FFFD instead.
    line = raw.decode("utf-8", errors="replace").strip()
    return line if line.startswith("data:") else None


def parse_sse_stream(response: requests.Response) -> Generator[Dict[str, Any], None, None]:
    """Incrementally parse an SSE response, yielding each decoded JSON event.

    The body is buffered as bytes and split on newlines before decoding.
    Do not switch to ``iter_lines(decode_unicode=True)``: network chunking can
    cut a multi-byte UTF-8 character in half, which produces mojibake and
    JSON parse failures.

    Only ``data:`` lines are considered. Iteration ends at the first
    ``[DONE]`` marker. A payload that is not valid JSON is reported with a
    warning on stdout and skipped, both for complete lines and for a final
    line that arrives without a trailing newline.

    Args:
        response: Object exposing ``iter_content(chunk_size=...)`` that
            yields ``bytes`` chunks (a streamed ``requests.Response``).

    Yields:
        The JSON-decoded payload of each ``data:`` line.
    """
    buffer = b""
    for chunk in response.iter_content(chunk_size=4096):
        if not chunk:
            continue
        buffer += chunk
        # Only complete lines are decoded; the unterminated remainder stays in
        # the buffer until more bytes arrive.
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            line = _decode_sse_data_line(line_bytes)
            if line is None:
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                return
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                print(f"\n[warn] JSON 解析失败: {line[:120]}...")
                continue
            yield event

    # The stream may end without a final newline; handle what is left over.
    line = _decode_sse_data_line(buffer)
    if line is None:
        return
    data = line[5:].strip()
    if not data or data == "[DONE]":
        return
    try:
        event = json.loads(data)
    except json.JSONDecodeError:
        print(f"\n[warn] JSON 解析失败: {line[:120]}...")
        return
    yield event
def diagnose_empty_response(resp: requests.Response) -> None:
    """Print diagnostics for a response that yielded no parseable SSE events.

    Reports the Content-Type and the body length (plus the first 500 bytes
    when the body is non-empty). When the server answered with an empty
    ``text/html`` body, a hint about the likely server-side cause is printed.

    Args:
        resp: The response to inspect. It may be a streamed response whose
            body has already been drained with ``iter_content``; in that case
            the body can no longer be read and this is reported instead of
            raising.
    """
    ctype = resp.headers.get("Content-Type", "")
    try:
        body = resp.content or b""
        body_available = True
    except RuntimeError:
        # requests raises RuntimeError from .content once a stream=True body
        # has been fully consumed through iter_content.
        body = b""
        body_available = False

    print("\n[诊断] 响应体为空或无可解析 SSE")
    print(f" Content-Type : {ctype}")
    if not body_available:
        print(" body : (流已被读取完毕,无法再次获取)")
        return
    print(f" body 长度 : {len(body)}")
    if body:
        print(f" body 前 500B : {body[:500]!r}")
    if "text/html" in ctype and len(body) == 0:
        print("\n 可能原因: chat_send_stream.dspy 未执行 inference 入口。")
        print(" 请确认文件末尾包含:")
        print(" ret = await inference(request, params_kw=params_kw)")
        print(" return ret")
        print(" 并重新部署到 dev 后再测。")
def _consume_stream_response(resp: requests.Response, session_id: Optional[str]) -> Optional[str]:
    """Print the events of an already-opened streaming response and return the session id."""
    ctype = resp.headers.get("Content-Type", "")
    print(f"\nHTTP {resp.status_code} Content-Type: {ctype}\n")
    if resp.status_code != 200:
        print(resp.text[:500])
        return None
    if "text/event-stream" not in ctype:
        # Not an SSE response: dump diagnostics and whatever can still be
        # parsed from the buffered body, then report failure.
        raw = resp.content
        diagnose_empty_response(resp)
        if raw:
            for evt in parse_sse_text(raw.decode("utf-8", errors="ignore")):
                print("[parsed]", evt)
        return None

    session_out: Optional[str] = session_id
    full_reply: List[str] = []
    has_content = False
    event_count = 0
    print("--- 流式输出 ---")
    for evt in parse_sse_stream(resp):
        event_count += 1
        etype = evt.get("type")
        if etype == "meta":
            session_out = evt.get("session_id") or session_out
            print(f"[meta] session_id={session_out} model={evt.get('model')}")
            continue
        if etype == "content":
            piece = evt.get("content") or ""
            has_content = True
            full_reply.append(piece)
            print(piece, end="", flush=True)
            continue
        if etype == "done":
            session_out = evt.get("session_id") or session_out
            reply = evt.get("reply") or ""
            print(f"\n\n[done] session_id={session_out}")
            print(f"[done] reply 长度={len(reply)}")
            # Only echo the final reply when no incremental content was shown.
            if reply and not has_content:
                print(reply)
            continue
        if etype == "error":
            print(f"\n[error] {evt.get('msg')}")
            return session_out
        print(f"\n[unknown] {evt}")
    print("\n--- 结束 ---")
    if event_count == 0:
        diagnose_empty_response(resp)
    elif full_reply:
        joined = "".join(full_reply)
        print(f"拼接回复({len(joined)}字): {joined[:300]}...")
    return session_out


def test_chat_send_stream(session_id: Optional[str] = None, message: Optional[str] = None) -> Optional[str]:
    """Send one message to the streaming endpoint and print the SSE events.

    Args:
        session_id: Conversation id to continue; omitted from the request
            when falsy.
        message: Message to send; falls back to the configured default.

    Returns:
        The session id reported by the server (or the one passed in when the
        server did not report one). ``None`` when USERID is not configured,
        the request fails at the network level, the HTTP status is not 200,
        or the response is not an event stream.
    """
    url = BASE_URL + STREAM_PATH
    payload = build_payload(session_id=session_id, message=message)
    print("=" * 60)
    print("chat_send_stream.dspy 流式测试")
    print(f" URL : {url}")
    print(f" MODEL : {MODEL}")
    print(f" USERID : {USERID}")
    print(f" API_URL : {API_URL}")
    print(f" message : {payload.get('message')}")
    if session_id:
        print(f" session : {session_id}")
    print("=" * 60)
    if not USERID:
        print("错误: 请设置 CNTOAI_USERID")
        return None
    try:
        # The context manager closes the streamed response (and releases its
        # connection) on every exit path, including the early returns.
        with requests.post(
            url,
            json=payload,
            headers={
                "Accept": "text/event-stream",
                "Content-Type": "application/json",
            },
            stream=True,
            timeout=TIMEOUT,
        ) as resp:
            return _consume_stream_response(resp, session_id)
    except requests.RequestException as exc:
        # Connection failures and timeouts, whether on connect or mid-stream,
        # are reported as a failed run rather than an unhandled traceback.
        print(f"\n[error] 请求失败: {exc}")
        return None
def main() -> int:
    """Parse the command line, run one or two chat rounds, and return an exit code.

    Returns:
        1 when any round returns ``None`` from test_chat_send_stream,
        otherwise 0.
    """
    if sys.platform == "win32":
        # Best effort: make the Windows console accept the UTF-8 output.
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass

    cli = argparse.ArgumentParser(description="chat_send_stream.dspy SSE 测试")
    cli.add_argument("--session-id", help="续聊会话 ID")
    cli.add_argument("--message", "-m", help="覆盖默认 message")
    cli.add_argument("--twice", action="store_true", help="同一会话连发两条")
    opts = cli.parse_args()

    first_sid = test_chat_send_stream(session_id=opts.session_id, message=opts.message)
    if first_sid is None:
        return 1
    if not first_sid:
        # A non-None but empty session id: nothing to continue or to hint at.
        return 0

    if opts.twice:
        print("\n" + "=" * 60)
        print("第二轮(多轮续聊)")
        follow_up = opts.message or "继续,用一句话总结上面内容"
        if test_chat_send_stream(session_id=first_sid, message=follow_up) is None:
            return 1

    print(f"\n提示: 续聊 python test_demo.py --session-id {first_sid}")
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())