#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
|
||
chat_send_stream.dspy SSE 流式接口测试

用法:
    pip install requests
    python test_demo.py [--session-id ID] [--message TEXT] [--twice]

环境变量(可选,覆盖下方默认值):
    CNTOAI_BASE_URL / CNTOAI_USERID / CNTOAI_API_KEY
    CNTOAI_MODEL / CNTOAI_LLM_API_URL / CNTOAI_MESSAGE
    CNTOAI_TIMEOUT(请求超时秒数,默认 300)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import sys
|
||
from typing import Any, Dict, Generator, List, Optional
|
||
|
||
try:
|
||
import requests
|
||
except ImportError:
|
||
print("请先安装: pip install requests")
|
||
sys.exit(1)
|
||
|
||
|
||
def _env_int(name: str, default: int) -> int:
    """Read a positive integer from environment variable *name*.

    Unset or blank values yield *default*. Malformed or non-positive values also
    fall back to *default* (with a warning) instead of raising at import time, as
    a bare ``int(os.environ[...])`` would.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        print(f"[warn] 环境变量 {name}={raw!r} 不是整数,使用默认值 {default}")
        return default
    if value <= 0:
        print(f"[warn] 环境变量 {name}={raw!r} 必须为正整数,使用默认值 {default}")
        return default
    return value


# Server / LLM settings. Each value can be overridden by the environment
# variable named in its os.environ.get(...) call.
BASE_URL = os.environ.get("CNTOAI_BASE_URL", "https://dev.opencomputing.cn").strip().rstrip("/")
USERID = os.environ.get("CNTOAI_USERID", "hSqZuekZ1yKmhKmCN9UAK").strip()
# NOTE(review): a real-looking "sk-..." key used to be hard-coded here as the
# default. Secrets must not be committed to source control: it was removed and
# should be rotated. Provide the key through CNTOAI_API_KEY instead.
API_KEY = os.environ.get("CNTOAI_API_KEY", "").strip()
API_URL = os.environ.get("CNTOAI_LLM_API_URL", "https://api.deepseek.com/v1/chat/completions").strip()
MODEL = os.environ.get("CNTOAI_MODEL", "deepseek-chat").strip()
MESSAGE = os.environ.get("CNTOAI_MESSAGE", "你好,请用三句话介绍你自己").strip()
# Seconds; passed as ``timeout=`` to requests.post.
TIMEOUT = _env_int("CNTOAI_TIMEOUT", 300)

# Path of the streaming endpoint, appended to BASE_URL.
STREAM_PATH = "/cntoai/chat_send_stream.dspy"
|
||
|
||
|
||
def build_payload(session_id: Optional[str] = None, message: Optional[str] = None) -> Dict[str, Any]:
    """Assemble the JSON body for a chat_send_stream.dspy request.

    Args:
        session_id: Conversation to continue. The ``session_id`` key is only
            added when this is truthy, so a new conversation omits it.
        message: User message; falls back to the module-level ``MESSAGE``
            when ``None`` or empty.

    Returns:
        A new dict with ``model``, ``message``, ``userid``, ``api_key`` and
        ``api_url`` (in that order), plus ``session_id`` when provided.
    """
    body: Dict[str, Any] = dict(
        model=MODEL,
        message=message if message else MESSAGE,
        userid=USERID,
        api_key=API_KEY,
        api_url=API_URL,
    )
    if not session_id:
        return body
    body["session_id"] = session_id
    return body
|
||
|
||
|
||
def parse_sse_text(text: str) -> List[Dict[str, Any]]:
    """Parse a fully buffered SSE body into a list of JSON events.

    Only ``data:`` lines are considered. Every other line (``event:``/``id:``
    fields, ``:`` comments, blank lines) is ignored. Parsing stops at the first
    ``data: [DONE]`` sentinel. A payload that is not valid JSON is reported on
    stdout and skipped.
    """
    events: List[Dict[str, Any]] = []
    data_lines = (entry for entry in map(str.strip, text.splitlines()) if entry.startswith("data:"))
    for entry in data_lines:
        body = entry[len("data:"):].strip()
        if body == "[DONE]":
            return events
        try:
            decoded = json.loads(body)
        except json.JSONDecodeError:
            print(f"[warn] 无法解析: {entry[:200]}")
        else:
            events.append(decoded)
    return events
|
||
|
||
|
||
def _iter_byte_lines(response: "requests.Response") -> Generator[bytes, None, None]:
    """Yield each complete ``b"\\n"``-terminated line (newline removed), then any unterminated tail."""
    buffer = bytearray()
    for chunk in response.iter_content(chunk_size=4096):
        if not chunk:
            continue
        buffer.extend(chunk)
        start = 0
        while True:
            newline = buffer.find(b"\n", start)
            if newline < 0:
                break
            yield bytes(buffer[start:newline])
            start = newline + 1
        # Drop consumed bytes once per chunk rather than re-copying the rest of
        # the buffer for every line.
        del buffer[:start]
    if buffer:
        yield bytes(buffer)


def parse_sse_stream(response: "requests.Response") -> Generator[Dict[str, Any], None, None]:
    """Parse an SSE response incrementally, buffering raw bytes.

    Do not use ``iter_lines(decode_unicode=True)``: TCP chunk boundaries can
    split a multi-byte UTF-8 character, which produces mojibake and JSON parse
    failures. Lines are split on the ``b"\\n"`` byte first. This is safe because
    0x0A never occurs inside a multi-byte UTF-8 sequence, so each line is
    decoded only once it is complete. Bytes that are still invalid UTF-8 are
    replaced with U+FFFD instead of aborting the whole stream.

    Only ``data:`` lines are used. Other lines (``event:``/``id:`` fields,
    comments, blanks) are skipped. Each ``data:`` payload is parsed as one JSON
    value. Parsing stops at a ``data: [DONE]`` sentinel. A final line without a
    trailing newline is processed exactly like the others.

    Args:
        response: A ``requests`` response opened with ``stream=True``; only its
            ``iter_content`` method is used.

    Yields:
        Each successfully parsed JSON payload. Payloads that are not valid JSON
        are reported on stdout and skipped.
    """
    for line_bytes in _iter_byte_lines(response):
        line = line_bytes.decode("utf-8", errors="replace").strip()
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if data == "[DONE]":
            return
        try:
            event = json.loads(data)
        except json.JSONDecodeError:
            print(f"\n[warn] JSON 解析失败: {line[:120]}...")
            continue
        yield event
|
||
|
||
|
||
def diagnose_empty_response(resp: "requests.Response") -> None:
    """Print diagnostics for a response that produced no parseable SSE events.

    Prints the Content-Type, the body length and, if non-empty, the first 500
    bytes of the body. When the server answered ``text/html`` with an empty
    body, it also prints the likely deployment cause (the .dspy file not
    reaching its ``inference`` entry point).

    The body may no longer be readable. ``test_chat_send_stream`` calls this
    after ``parse_sse_stream`` has drained a ``stream=True`` response, and in
    that case ``requests`` raises ``RuntimeError`` on ``resp.content``. That
    case is reported as "already consumed" instead of crashing the diagnostic.

    Args:
        resp: Response to inspect; only ``headers`` and ``content`` are read.
    """
    ctype = resp.headers.get("Content-Type", "")
    body: Optional[bytes]
    try:
        body = resp.content or b""
    except RuntimeError:
        # requests refuses to re-read a streamed body that iter_content has
        # already consumed.
        body = None

    print("\n[诊断] 响应体为空或无可解析 SSE")
    print(f" Content-Type : {ctype}")
    if body is None:
        print(" body : <已被流式读取,无法再次获取>")
        return
    print(f" body 长度 : {len(body)}")
    if body:
        print(f" body 前 500B : {body[:500]!r}")

    # An HTML content type with an empty body suggests the .dspy script never
    # reached its inference entry point.
    if "text/html" in ctype and len(body) == 0:
        print("\n 可能原因: chat_send_stream.dspy 未执行 inference 入口。")
        print(" 请确认文件末尾包含:")
        print(" ret = await inference(request, params_kw=params_kw)")
        print(" return ret")
        print(" 并重新部署到 dev 后再测。")
|
||
|
||
|
||
def _consume_sse_events(resp: "requests.Response", session_id: Optional[str]) -> Optional[str]:
    """Print each event of an SSE response as it arrives; return the resulting session id.

    Events are dispatched on the ``type`` field of each JSON event:
      * ``meta``    - records ``session_id`` (if present), prints it with ``model``.
      * ``content`` - prints the ``content`` piece immediately and collects it.
      * ``done``    - records ``session_id``; prints ``reply`` only when no
                      ``content`` events were received.
      * ``error``   - prints ``msg`` and stops reading.
      * anything else, including non-object JSON, is printed as ``[unknown]``.

    Args:
        resp: An open HTTP 200 ``text/event-stream`` response (``stream=True``).
        session_id: Session id the request was sent with; returned unchanged
            if the server never reports one.

    Returns:
        The most recent session id reported by the server, else *session_id*.
    """
    session_out: Optional[str] = session_id
    full_reply: List[str] = []
    event_count = 0

    print("--- 流式输出 ---")
    for evt in parse_sse_stream(resp):
        event_count += 1
        # A valid but non-object JSON payload has no .get(); report it as
        # unknown instead of crashing.
        etype = evt.get("type") if isinstance(evt, dict) else None

        if etype == "meta":
            session_out = evt.get("session_id") or session_out
            print(f"[meta] session_id={session_out} model={evt.get('model')}")
        elif etype == "content":
            piece = evt.get("content") or ""
            full_reply.append(piece)
            print(piece, end="", flush=True)
        elif etype == "done":
            session_out = evt.get("session_id") or session_out
            reply = evt.get("reply") or ""
            print(f"\n\n[done] session_id={session_out}")
            print(f"[done] reply 长度={len(reply)}")
            # Fall back to the final reply only when nothing was streamed.
            if reply and not full_reply:
                print(reply)
        elif etype == "error":
            print(f"\n[error] {evt.get('msg')}")
            return session_out
        else:
            print(f"\n[unknown] {evt}")

    print("\n--- 结束 ---")
    if event_count == 0:
        diagnose_empty_response(resp)
    elif full_reply:
        joined = "".join(full_reply)
        # Only add an ellipsis when the 300-char preview is actually truncated.
        ellipsis = "..." if len(joined) > 300 else ""
        print(f"拼接回复({len(joined)}字): {joined[:300]}{ellipsis}")
    return session_out


def test_chat_send_stream(session_id: Optional[str] = None, message: Optional[str] = None) -> Optional[str]:
    """Send one chat message to chat_send_stream.dspy and print the streamed reply.

    Prints a banner with the request settings, POSTs the ``build_payload``
    body with ``stream=True``, and prints events as they arrive.

    NOTE(review): the ``test_`` prefix means pytest will collect this function
    (the usage line names the file test_demo.py) and hit the live server.

    Args:
        session_id: Existing session to continue; ``None`` starts a new one.
        message: Message to send; ``build_payload`` falls back to ``MESSAGE``.

    Returns:
        The session id from the stream (see ``_consume_sse_events``). Returns
        ``None`` in any of these cases:
          * USERID is unset;
          * the request or the stream read fails with a ``requests`` error;
          * the status is not 200;
          * the response is not ``text/event-stream``.
    """
    url = BASE_URL + STREAM_PATH
    payload = build_payload(session_id=session_id, message=message)

    print("=" * 60)
    print("chat_send_stream.dspy 流式测试")
    print(f" URL : {url}")
    print(f" MODEL : {MODEL}")
    print(f" USERID : {USERID}")
    print(f" API_URL : {API_URL}")
    print(f" message : {payload.get('message')}")
    if session_id:
        print(f" session : {session_id}")
    print("=" * 60)

    if not USERID:
        print("错误: 请设置 CNTOAI_USERID")
        return None

    try:
        resp = requests.post(
            url,
            json=payload,
            headers={
                "Accept": "text/event-stream",
                "Content-Type": "application/json",
            },
            stream=True,
            timeout=TIMEOUT,
        )
    except requests.RequestException as exc:
        print(f"\n请求失败: {exc}")
        return None

    # With stream=True the connection stays checked out until the body is fully
    # read or the response is closed, so close it on every exit path.
    try:
        ctype = resp.headers.get("Content-Type", "")
        print(f"\nHTTP {resp.status_code} Content-Type: {ctype}\n")

        if resp.status_code != 200:
            print(resp.text[:500])
            return None

        if "text/event-stream" not in ctype:
            raw = resp.content
            diagnose_empty_response(resp)
            if raw:
                for evt in parse_sse_text(raw.decode("utf-8", errors="ignore")):
                    print("[parsed]", evt)
            return None

        return _consume_sse_events(resp, session_id)
    except requests.RequestException as exc:
        # e.g. the connection dropped or timed out while the body was streaming.
        print(f"\n读取响应失败: {exc}")
        return None
    finally:
        resp.close()
|
||
|
||
|
||
def main() -> int:
    """Command-line entry point.

    Runs one streaming round. With ``--twice``, it runs a second round in the
    session returned by the first.

    Returns:
        Process exit status: 1 if any round returned ``None``, otherwise 0.
    """
    if sys.platform == "win32":
        # Best effort: make the Windows console encode the Chinese output as UTF-8.
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass

    cli = argparse.ArgumentParser(description="chat_send_stream.dspy SSE 测试")
    cli.add_argument("--session-id", help="续聊会话 ID")
    cli.add_argument("--message", "-m", help="覆盖默认 message")
    cli.add_argument("--twice", action="store_true", help="同一会话连发两条")
    opts = cli.parse_args()

    first_sid = test_chat_send_stream(session_id=opts.session_id, message=opts.message)
    if first_sid is None:
        return 1

    if opts.twice and first_sid:
        print("\n" + "=" * 60)
        print("第二轮(多轮续聊)")
        follow_up = opts.message or "继续,用一句话总结上面内容"
        if test_chat_send_stream(session_id=first_sid, message=follow_up) is None:
            return 1

    if first_sid:
        print(f"\n提示: 续聊 python test_demo.py --session-id {first_sid}")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|