From 7bb029a66e21ad850a5909d7bc766d94f3333eb2 Mon Sep 17 00:00:00 2001 From: ping <1017253325@qq.com> Date: Sat, 23 May 2026 14:22:30 +0800 Subject: [PATCH] update --- b/cntoai/chat_send.dspy | 26 +- b/cntoai/chat_send_stream.dspy | 311 ++++++++++++++++++++++ b/cntoai/llm_chat_completions.dspy | 57 +++- b/cntoai/llm_chat_completions_stream.dspy | 241 +++++++++++++++++ b/cntoai/test_demo.py | 247 +++++++++++++++++ 5 files changed, 869 insertions(+), 13 deletions(-) create mode 100644 b/cntoai/chat_send_stream.dspy create mode 100644 b/cntoai/llm_chat_completions_stream.dspy create mode 100644 b/cntoai/test_demo.py diff --git a/b/cntoai/chat_send.dspy b/b/cntoai/chat_send.dspy index ba93073..60059c1 100644 --- a/b/cntoai/chat_send.dspy +++ b/b/cntoai/chat_send.dspy @@ -4,6 +4,14 @@ def _escape(value): return str(value).replace("'", "''") +def _parse_bool(value, default=True): + if value is None or value == '': + return default + if isinstance(value, bool): + return value + return str(value).lower() in ('1', 'true', 'yes', 'on') + + def _title_from_message(ns): text = ns.get('message') or ns.get('text') or '' text = str(text).strip().replace('\n', ' ') @@ -68,11 +76,15 @@ async def chat_send(ns={}): 发送消息并保存多轮对话(需先执行 chat_tables.sql)。 参数:model, message, stream(默认true), session_id, - image_url, image_base64, document_url, document_text + image_url, image_base64, document_url, document_text, + with_chunks(true时返回上游 SSE 分片列表,便于确认流式) + + 说明:本接口(chat_send.dspy)为 JSON 一次性返回。 + 需要浏览器端实时流式请调用 chat_send_stream.dspy(SSE)。 """ import json import traceback - + # model = ns.get('model') model = 'deepseek-v4-pro' if not model: @@ -115,9 +127,7 @@ async def chat_send(ns={}): }) history = await _load_session_messages(sor, session_id) - stream_val = ns.get('stream', True) - if isinstance(stream_val, str): - stream_val = stream_val.lower() in ('1', 'true', 'yes', 'on') + stream_val = _parse_bool(ns.get('stream'), True) chat_result = await path_call('llm_chat_completions.dspy', { 'model': model, 'messages': history, @@ -126,11 +136,14 @@ async def chat_send(ns={}): 'api_url': ns.get('api_url'), 'api_key': ns.get('api_key'), 'model_id': ns.get('model_id'), + 'with_chunks': ns.get('with_chunks', True), }) if not chat_result.get('status'): return chat_result reply = chat_result['data']['reply'] + chunks = chat_result['data'].get('chunks') or [] + chunk_count = chat_result['data'].get('chunk_count', 0) await sor.C('chat_message', { 'id': uuid(), 'session_id': session_id, @@ -151,6 +164,9 @@ async def chat_send(ns={}): 'session_id': session_id, 'reply': reply, 'model': model, + 'stream': stream_val, + 'chunk_count': chunk_count, + 'chunks': chunks if ns.get('with_chunks', True) else None, }, } except Exception: diff --git a/b/cntoai/chat_send_stream.dspy b/b/cntoai/chat_send_stream.dspy new file mode 100644 index 0000000..b98d116 --- /dev/null +++ b/b/cntoai/chat_send_stream.dspy @@ -0,0 +1,311 @@ +def _escape(value): + if value is None: + return None + return str(value).replace("'", "''") + + +def _parse_bool(value, default=True): + if value is None or value == '': + return default + if isinstance(value, bool): + return value + return str(value).lower() in ('1', 'true', 'yes', 'on') + + +def _title_from_message(ns): + text = ns.get('message') or ns.get('text') or '' + text = str(text).strip().replace('\n', ' ') + if not text: + return '新对话' + return text[:30] + ('...' if len(text) > 30 else '') + + +def _build_user_content(ns): + text_parts = [] + if ns.get('message'): + text_parts.append(str(ns.get('message'))) + if ns.get('text'): + text_parts.append(str(ns.get('text'))) + if ns.get('document_text'): + text_parts.append(str(ns.get('document_text'))) + + parts = [] + merged_text = '\n'.join([p for p in text_parts if p]).strip() + if merged_text: + parts.append({'type': 'text', 'text': merged_text}) + if ns.get('image_url'): + parts.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}}) + if ns.get('image_base64'): + mime = ns.get('image_mime') or 'image/jpeg' + b64 = ns.get('image_base64') + if not str(b64).startswith('data:'): + b64 = 'data:%s;base64,%s' % (mime, b64) + parts.append({'type': 'image_url', 'image_url': {'url': b64}}) + if ns.get('document_url'): + parts.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}}) + if not parts: + return '' + if len(parts) == 1 and parts[0]['type'] == 'text': + return parts[0]['text'] + return parts + + +async def _load_session_messages(sor, session_id): + sql = """ + SELECT role, content, content_type + FROM chat_message + WHERE session_id = '%s' + ORDER BY created_at ASC; + """ % _escape(session_id) + rows = await sor.sqlExe(sql, {}) + messages = [] + for row in rows: + content = row.get('content') or '' + if row.get('content_type') == 'mixed': + import json + try: + content = json.loads(content) + except Exception: + pass + messages.append({'role': row['role'], 'content': content}) + return messages + + +async def _resolve_chat_config(ns, sor): + # api_url = ns.get('api_url') + # api_key = ns.get('api_key') + api_url = 'https://api.deepseek.com/chat/completions' + api_key = 'sk-c22d6573e85a4d3fa8ab932386cf2909' + if not api_url and ns.get('model_id'): + doc_rows = await sor.sqlExe( + "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;" + % _escape(ns.get('model_id')), + {}, + ) + if doc_rows and doc_rows[0].get('api_url'): + api_url = doc_rows[0]['api_url'] + if not str(api_url).endswith('/chat/completions'): + api_url = str(api_url).rstrip('/') + '/chat/completions' + if not api_url: + param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'}) + if param_rows: + api_url = param_rows[0]['pvalue'] + else: + domain_rows = await sor.R('params', {'pname': 'cntoai_domain'}) + if domain_rows: + api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions' + else: + api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions' + if not api_key: + userid = ns.get('userid') or await get_user() + if userid: + action = ns.get('apikey_action') or 'user_self_create' + keys = await sor.R('user_api_keys', {'userid': userid, 'action': action}) + if not keys: + keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'}) + if keys: + api_key = keys[0].get('opc_apikey') + if not api_key: + key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'}) + if key_rows: + api_key = key_rows[0]['pvalue'] + return api_url, api_key + + +def _extract_stream_piece(payload): + choice = (payload.get('choices') or [{}])[0] + delta = choice.get('delta') or {} + message = choice.get('message') or {} + piece = ( + delta.get('content') + or delta.get('reasoning_content') + or message.get('content') + or choice.get('text') + or payload.get('content') + or '' + ) + if piece is None: + return '' + return str(piece) + + +def _sse_event(obj): + import json + return 'data: %s\n\n' % json.dumps(obj, ensure_ascii=False) + + +async def _iter_upstream_stream(api_url, api_key, payload): + import aiohttp + import json + + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer %s' % api_key, + } + payload = dict(payload) + payload['stream'] = True + + buffer = '' + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600)) as session: + async with session.post(api_url, headers=headers, json=payload) as response: + if response.status != 200: + err_text = await response.text() + yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])} + return + + async for raw in response.content: + buffer += raw.decode('utf-8', errors='ignore') + while '\n' in buffer: + line, buffer = buffer.split('\n', 1) + line = line.strip() + if not line or line.startswith(':') or not line.startswith('data:'): + continue + data = line[5:].strip() + if data == '[DONE]': + return + try: + payload_obj = json.loads(data) + piece = _extract_stream_piece(payload_obj) + if piece: + yield {'type': 'content', 'content': piece} + except Exception: + continue + + tail = buffer.strip() + if tail: + try: + body = json.loads(tail) + choice = (body.get('choices') or [{}])[0] + msg = choice.get('message') or {} + piece = msg.get('content') or choice.get('text') or '' + if piece: + yield {'type': 'content', 'content': str(piece)} + except Exception: + pass + + +async def inference_generator(request, params_kw=None, **kw): + """ + 流式 chat_send:先存 user 消息,SSE 推送 assistant 片段,结束后存库。 + + SSE 事件: + {"type":"meta","session_id":"...","model":"..."} + {"type":"content","content":"片段"} + {"type":"done","session_id":"...","reply":"完整文本","model":"..."} + {"type":"error","msg":"..."} + """ + import json + import traceback + + ns = params_kw or {} + # model = ns.get('model') + model = 'deepseek-v4-pro' + if not model: + yield _sse_event({'type': 'error', 'msg': 'model is required'}) + yield 'data: [DONE]\n\n' + return + + userid = ns.get('userid') or await get_user() + if not userid: + yield _sse_event({'type': 'error', 'msg': '未找到用户'}) + yield 'data: [DONE]\n\n' + return + + user_content = _build_user_content(ns) + if not user_content: + yield _sse_event({'type': 'error', 'msg': '请输入文本,或提供图片/文档参数'}) + yield 'data: [DONE]\n\n' + return + + content_type = 'mixed' if isinstance(user_content, list) else 'text' + store_content = json.dumps(user_content, ensure_ascii=False) if content_type == 'mixed' else str(user_content) + + db = DBPools() + async with db.sqlorContext('kboss') as sor: + try: + session_id = ns.get('session_id') + if not session_id: + session_id = uuid() + await sor.C('chat_session', { + 'id': session_id, + 'userid': userid, + 'model': model, + 'title': _title_from_message(ns), + }) + else: + sessions = await sor.R('chat_session', {'id': session_id, 'userid': userid}) + if not sessions: + yield _sse_event({'type': 'error', 'msg': '会话不存在'}) + yield 'data: [DONE]\n\n' + return + + await sor.C('chat_message', { + 'id': uuid(), + 'session_id': session_id, + 'role': 'user', + 'content': store_content, + 'content_type': content_type, + }) + + history = await _load_session_messages(sor, session_id) + api_url, api_key = await _resolve_chat_config(ns, sor) + if not api_key: + yield _sse_event({'type': 'error', 'msg': '未找到 API Key'}) + yield 'data: [DONE]\n\n' + return + + yield _sse_event({ + 'type': 'meta', + 'session_id': session_id, + 'model': model, + 'stream': True, + }) + + parts = [] + async for evt in _iter_upstream_stream(api_url, api_key, { + 'model': model, + 'messages': history, + }): + if evt.get('type') == 'error': + yield _sse_event(evt) + yield 'data: [DONE]\n\n' + return + if evt.get('type') == 'content': + parts.append(evt['content']) + yield _sse_event(evt) + + reply = ''.join(parts) + await sor.C('chat_message', { + 'id': uuid(), + 'session_id': session_id, + 'role': 'assistant', + 'content': reply, + 'content_type': 'text', + }) + await sor.sqlExe( + "UPDATE chat_session SET updated_at = NOW() WHERE id = '%s';" + % _escape(session_id), + {}, + ) + + yield _sse_event({ + 'type': 'done', + 'session_id': session_id, + 'reply': reply, + 'model': model, + }) + yield 'data: [DONE]\n\n' + except Exception: + yield _sse_event({'type': 'error', 'msg': traceback.format_exc()}) + yield 'data: [DONE]\n\n' + + +async def inference(request, *args, params_kw=None, **kw): + from functools import partial + env = request._run_ns.copy() + f = partial(inference_generator, request, params_kw=params_kw, **kw) + return await env.stream_response(request, f, content_type='text/event-stream') + + +ret = await inference(request, params_kw=params_kw) +return ret diff --git a/b/cntoai/llm_chat_completions.dspy b/b/cntoai/llm_chat_completions.dspy index 6303999..980be9a 100644 --- a/b/cntoai/llm_chat_completions.dspy +++ b/b/cntoai/llm_chat_completions.dspy @@ -81,6 +81,7 @@ def build_user_content(ns): async def _resolve_chat_config(ns, sor): """解析 API 地址与 Bearer Token""" + api_url = 'https://api.deepseek.com/chat/completions' api_key = 'sk-c22d6573e85a4d3fa8ab932386cf2909' @@ -126,8 +127,26 @@ async def _resolve_chat_config(ns, sor): return api_url, api_key +def _extract_stream_piece(payload): + """从 SSE chunk 中提取文本(兼容 OpenAI / Qwen 等格式)""" + choice = (payload.get('choices') or [{}])[0] + delta = choice.get('delta') or {} + message = choice.get('message') or {} + piece = ( + delta.get('content') + or delta.get('reasoning_content') + or message.get('content') + or choice.get('text') + or payload.get('content') + or '' + ) + if piece is None: + return '' + return str(piece) + + async def _read_stream_response(response): - """解析 SSE 流式响应,汇总 assistant 文本""" + """解析 SSE 流式响应;若上游未按 SSE 返回则回退解析整段 JSON""" import json chunks = [] buffer = '' @@ -136,21 +155,38 @@ async def _read_stream_response(response): while '\n' in buffer: line, buffer = buffer.split('\n', 1) line = line.strip() + if not line or line.startswith(':'): + continue if not line.startswith('data:'): continue data = line[5:].strip() if data == '[DONE]': - return ''.join(chunks) + return ''.join(chunks), chunks try: payload = json.loads(data) - choice = (payload.get('choices') or [{}])[0] - delta = choice.get('delta') or {} - piece = delta.get('content') or '' + piece = _extract_stream_piece(payload) if piece: chunks.append(piece) except Exception: continue - return ''.join(chunks) + + reply = ''.join(chunks) + if reply: + return reply, chunks + + # 上游可能忽略 stream=true,直接返回完整 JSON + tail = buffer.strip() + if tail: + try: + body = json.loads(tail) + choice = (body.get('choices') or [{}])[0] + msg = choice.get('message') or {} + reply = msg.get('content') or choice.get('text') or '' + if reply: + return str(reply), [str(reply)] + except Exception: + pass + return reply, chunks async def llm_chat_completions(ns={}): @@ -204,7 +240,9 @@ async def llm_chat_completions(ns={}): 'Authorization': 'Bearer %s' % api_key, } - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=600), + ) as session: async with session.post(api_url, headers=headers, json=payload) as response: if response.status != 200: err_text = await response.text() @@ -213,8 +251,9 @@ async def llm_chat_completions(ns={}): 'msg': '模型请求失败 HTTP %s: %s' % (response.status, err_text[:500]), } + stream_chunks = [] if stream: - reply = await _read_stream_response(response) + reply, stream_chunks = await _read_stream_response(response) usage = {} else: body = await response.json() @@ -232,6 +271,8 @@ async def llm_chat_completions(ns={}): 'messages': messages + [{'role': 'assistant', 'content': reply}], 'usage': usage, 'stream': stream, + 'chunk_count': len(stream_chunks), + 'chunks': stream_chunks if ns.get('with_chunks') else None, }, } except Exception: diff --git a/b/cntoai/llm_chat_completions_stream.dspy b/b/cntoai/llm_chat_completions_stream.dspy new file mode 100644 index 0000000..24600ff --- /dev/null +++ b/b/cntoai/llm_chat_completions_stream.dspy @@ -0,0 +1,241 @@ +def _escape(value): + if value is None: + return None + return str(value).replace("'", "''") + + +def _parse_bool(value, default=True): + if value is None or value == '': + return default + if isinstance(value, bool): + return value + return str(value).lower() in ('1', 'true', 'yes', 'on') + + +def _parse_messages(ns): + raw = ns.get('messages') + if not raw: + return [] + if isinstance(raw, list): + return raw + if isinstance(raw, str): + import json + try: + return json.loads(raw) + except Exception: + return [] + return [] + + +def build_user_content(ns): + text_parts = [] + if ns.get('message'): + text_parts.append(str(ns.get('message'))) + if ns.get('text'): + text_parts.append(str(ns.get('text'))) + if ns.get('document_text'): + text_parts.append(str(ns.get('document_text'))) + + parts = [] + merged_text = '\n'.join([p for p in text_parts if p]).strip() + if merged_text: + parts.append({'type': 'text', 'text': merged_text}) + if ns.get('image_url'): + parts.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}}) + if ns.get('image_base64'): + mime = ns.get('image_mime') or 'image/jpeg' + b64 = ns.get('image_base64') + if not str(b64).startswith('data:'): + b64 = 'data:%s;base64,%s' % (mime, b64) + parts.append({'type': 'image_url', 'image_url': {'url': b64}}) + if ns.get('document_url'): + parts.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}}) + if not parts: + return '' + if len(parts) == 1 and parts[0]['type'] == 'text': + return parts[0]['text'] + return parts + + +async def _resolve_chat_config(ns, sor): + api_url = ns.get('api_url') + api_key = ns.get('api_key') + if not api_url and ns.get('model_id'): + doc_rows = await sor.sqlExe( + "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;" + % _escape(ns.get('model_id')), + {}, + ) + if doc_rows and doc_rows[0].get('api_url'): + api_url = doc_rows[0]['api_url'] + if not str(api_url).endswith('/chat/completions'): + api_url = str(api_url).rstrip('/') + '/chat/completions' + if not api_url: + param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'}) + if param_rows: + api_url = param_rows[0]['pvalue'] + else: + domain_rows = await sor.R('params', {'pname': 'cntoai_domain'}) + if domain_rows: + api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions' + else: + api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions' + if not api_key: + userid = ns.get('userid') or await get_user() + if userid: + action = ns.get('apikey_action') or 'user_self_create' + keys = await sor.R('user_api_keys', {'userid': userid, 'action': action}) + if not keys: + keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'}) + if keys: + api_key = keys[0].get('opc_apikey') + if not api_key: + key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'}) + if key_rows: + api_key = key_rows[0]['pvalue'] + return api_url, api_key + + +def _extract_stream_piece(payload): + choice = (payload.get('choices') or [{}])[0] + delta = choice.get('delta') or {} + message = choice.get('message') or {} + piece = ( + delta.get('content') + or delta.get('reasoning_content') + or message.get('content') + or choice.get('text') + or payload.get('content') + or '' + ) + if piece is None: + return '' + return str(piece) + + +def _sse_event(obj): + import json + return 'data: %s\n\n' % json.dumps(obj, ensure_ascii=False) + + +async def _iter_upstream_stream(api_url, api_key, payload): + """向上游发起流式请求,逐片 yield 文本""" + import aiohttp + import json + + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer %s' % api_key, + } + payload = dict(payload) + payload['stream'] = True + + buffer = '' + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600)) as session: + async with session.post(api_url, headers=headers, json=payload) as response: + if response.status != 200: + err_text = await response.text() + yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])} + return + + async for raw in response.content: + buffer += raw.decode('utf-8', errors='ignore') + while '\n' in buffer: + line, buffer = buffer.split('\n', 1) + line = line.strip() + if not line or line.startswith(':') or not line.startswith('data:'): + continue + data = line[5:].strip() + if data == '[DONE]': + return + try: + payload_obj = json.loads(data) + piece = _extract_stream_piece(payload_obj) + if piece: + yield {'type': 'content', 'content': piece} + except Exception: + continue + + tail = buffer.strip() + if tail: + try: + body = json.loads(tail) + choice = (body.get('choices') or [{}])[0] + msg = choice.get('message') or {} + piece = msg.get('content') or choice.get('text') or '' + if piece: + yield {'type': 'content', 'content': str(piece)} + except Exception: + pass + + +async def inference_generator(request, params_kw=None, **kw): + """ + SSE 流式输出,事件格式: + {"type":"meta","model":"..."} + {"type":"content","content":"片段"} + {"type":"done","reply":"完整文本"} + {"type":"error","msg":"..."} + 结束:data: [DONE] + """ + import traceback + + ns = params_kw or {} + model = ns.get('model') + if not model: + yield _sse_event({'type': 'error', 'msg': 'model is required'}) + yield 'data: [DONE]\n\n' + return + + history = _parse_messages(ns) + user_content = build_user_content(ns) + if not user_content and not history: + yield _sse_event({'type': 'error', 'msg': 'message is required'}) + yield 'data: [DONE]\n\n' + return + + messages = list(history) + if user_content: + messages.append({'role': 'user', 'content': user_content}) + + db = DBPools() + async with db.sqlorContext('kboss') as sor: + try: + api_url, api_key = await _resolve_chat_config(ns, sor) + if not api_key: + yield _sse_event({'type': 'error', 'msg': '未找到 API Key'}) + yield 'data: [DONE]\n\n' + return + + yield _sse_event({'type': 'meta', 'model': model, 'stream': True}) + + parts = [] + async for evt in _iter_upstream_stream(api_url, api_key, { + 'model': model, + 'messages': messages, + }): + if evt.get('type') == 'error': + yield _sse_event(evt) + yield 'data: [DONE]\n\n' + return + if evt.get('type') == 'content': + parts.append(evt['content']) + yield _sse_event(evt) + + reply = ''.join(parts) + yield _sse_event({'type': 'done', 'reply': reply, 'model': model}) + yield 'data: [DONE]\n\n' + except Exception: + yield _sse_event({'type': 'error', 'msg': traceback.format_exc()}) + yield 'data: [DONE]\n\n' + + +async def inference(request, *args, params_kw=None, **kw): + from functools import partial + env = request._run_ns.copy() + f = partial(inference_generator, request, params_kw=params_kw, **kw) + return await env.stream_response(request, f, content_type='text/event-stream') + + +ret = await inference(request, params_kw=params_kw) +return ret diff --git a/b/cntoai/test_demo.py b/b/cntoai/test_demo.py new file mode 100644 index 0000000..b04e887 --- /dev/null +++ b/b/cntoai/test_demo.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +chat_send_stream.dspy SSE 流式接口测试 + +用法: + pip install requests + python test_demo.py + +环境变量(可选,覆盖下方默认值): + CNTOAI_BASE_URL / CNTOAI_USERID / CNTOAI_API_KEY + CNTOAI_MODEL / CNTOAI_LLM_API_URL / CNTOAI_MESSAGE +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from typing import Any, Dict, Generator, List, Optional + +try: + import requests +except ImportError: + print("请先安装: pip install requests") + sys.exit(1) + + +BASE_URL = os.environ.get("CNTOAI_BASE_URL", "https://dev.opencomputing.cn").rstrip("/") +USERID = os.environ.get("CNTOAI_USERID", "hSqZuekZ1yKmhKmCN9UAK").strip() +API_KEY = os.environ.get("CNTOAI_API_KEY", "sk-c22d6573e85a4d3fa8ab932386cf2909").strip() +API_URL = os.environ.get("CNTOAI_LLM_API_URL", "https://api.deepseek.com/v1/chat/completions").strip() +MODEL = os.environ.get("CNTOAI_MODEL", "deepseek-chat").strip() +MESSAGE = os.environ.get("CNTOAI_MESSAGE", "你好,请用三句话介绍你自己").strip() +TIMEOUT = int(os.environ.get("CNTOAI_TIMEOUT", "300")) + +STREAM_PATH = "/cntoai/chat_send_stream.dspy" + + +def build_payload(session_id: Optional[str] = None, message: Optional[str] = None) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "model": MODEL, + "message": message or MESSAGE, + "userid": USERID, + "api_key": API_KEY, + "api_url": API_URL, + } + if session_id: + payload["session_id"] = session_id + return payload + + +def parse_sse_text(text: str) -> List[Dict[str, Any]]: + events: List[Dict[str, Any]] = [] + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + data = line[5:].strip() + if data == "[DONE]": + break + try: + events.append(json.loads(data)) + except json.JSONDecodeError: + print(f"[warn] 无法解析: {line[:200]}") + return events + + +def parse_sse_stream(response: requests.Response) -> Generator[Dict[str, Any], None, None]: + """ + 按字节缓冲解析 SSE。 + 勿用 iter_lines(decode_unicode=True):TCP 分块可能截断 UTF-8 多字节字符,导致乱码和 JSON 解析失败。 + """ + buffer = b"" + for chunk in response.iter_content(chunk_size=4096): + if not chunk: + continue + buffer += chunk + while b"\n" in buffer: + line_bytes, buffer = buffer.split(b"\n", 1) + if not line_bytes.strip(): + continue + line = line_bytes.decode("utf-8").strip() + if not line.startswith("data:"): + continue + data = line[5:].strip() + if data == "[DONE]": + return + try: + yield json.loads(data) + except json.JSONDecodeError: + print(f"\n[warn] JSON 解析失败: {line[:120]}...") + + tail = buffer.strip() + if tail: + line = tail.decode("utf-8", errors="replace").strip() + if line.startswith("data:"): + data = line[5:].strip() + if data and data != "[DONE]": + try: + yield json.loads(data) + except json.JSONDecodeError: + pass + + +def diagnose_empty_response(resp: requests.Response) -> None: + ctype = resp.headers.get("Content-Type", "") + body = resp.content or b"" + print("\n[诊断] 响应体为空或无可解析 SSE") + print(f" Content-Type : {ctype}") + print(f" body 长度 : {len(body)}") + if body: + print(f" body 前 500B : {body[:500]!r}") + if "text/html" in ctype and len(body) == 0: + print("\n 可能原因: chat_send_stream.dspy 未执行 inference 入口。") + print(" 请确认文件末尾包含:") + print(" ret = await inference(request, params_kw=params_kw)") + print(" return ret") + print(" 并重新部署到 dev 后再测。") + + +def test_chat_send_stream(session_id: Optional[str] = None, message: Optional[str] = None) -> Optional[str]: + url = BASE_URL + STREAM_PATH + payload = build_payload(session_id=session_id, message=message) + + print("=" * 60) + print("chat_send_stream.dspy 流式测试") + print(f" URL : {url}") + print(f" MODEL : {MODEL}") + print(f" USERID : {USERID}") + print(f" API_URL : {API_URL}") + print(f" message : {payload.get('message')}") + if session_id: + print(f" session : {session_id}") + print("=" * 60) + + if not USERID: + print("错误: 请设置 CNTOAI_USERID") + return None + + resp = requests.post( + url, + json=payload, + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + stream=True, + timeout=TIMEOUT, + ) + + ctype = resp.headers.get("Content-Type", "") + print(f"\nHTTP {resp.status_code} Content-Type: {ctype}\n") + + if resp.status_code != 200: + print(resp.text[:500]) + return None + + if "text/event-stream" not in ctype: + raw = resp.content + diagnose_empty_response(resp) + if raw: + for evt in parse_sse_text(raw.decode("utf-8", errors="ignore")): + print("[parsed]", evt) + return None + + session_out: Optional[str] = session_id + full_reply: List[str] = [] + has_content = False + event_count = 0 + + print("--- 流式输出 ---") + for evt in parse_sse_stream(resp): + event_count += 1 + etype = evt.get("type") + + if etype == "meta": + session_out = evt.get("session_id") or session_out + print(f"[meta] session_id={session_out} model={evt.get('model')}") + continue + + if etype == "content": + piece = evt.get("content") or "" + has_content = True + full_reply.append(piece) + print(piece, end="", flush=True) + continue + + if etype == "done": + session_out = evt.get("session_id") or session_out + reply = evt.get("reply") or "" + print(f"\n\n[done] session_id={session_out}") + print(f"[done] reply 长度={len(reply)}") + if reply and not has_content: + print(reply) + continue + + if etype == "error": + print(f"\n[error] {evt.get('msg')}") + return session_out + + print(f"\n[unknown] {evt}") + + print("\n--- 结束 ---") + if event_count == 0: + diagnose_empty_response(resp) + elif full_reply: + joined = "".join(full_reply) + print(f"拼接回复({len(joined)}字): {joined[:300]}...") + return session_out + + +def main() -> int: + if sys.platform == "win32": + try: + sys.stdout.reconfigure(encoding="utf-8") + except Exception: + pass + + parser = argparse.ArgumentParser(description="chat_send_stream.dspy SSE 测试") + parser.add_argument("--session-id", help="续聊会话 ID") + parser.add_argument("--message", "-m", help="覆盖默认 message") + parser.add_argument("--twice", action="store_true", help="同一会话连发两条") + args = parser.parse_args() + + sid = test_chat_send_stream(session_id=args.session_id, message=args.message) + if sid is None: + return 1 + + if args.twice and sid: + print("\n" + "=" * 60) + print("第二轮(多轮续聊)") + sid2 = test_chat_send_stream( + session_id=sid, + message=args.message or "继续,用一句话总结上面内容", + ) + if sid2 is None: + return 1 + + if sid: + print(f"\n提示: 续聊 python test_demo.py --session-id {sid}") + return 0 + + +if __name__ == "__main__": + sys.exit(main())