update
This commit is contained in:
parent
f60c761735
commit
7bb029a66e
@ -4,6 +4,14 @@ def _escape(value):
|
||||
return str(value).replace("'", "''")
|
||||
|
||||
|
||||
def _parse_bool(value, default=True):
|
||||
if value is None or value == '':
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).lower() in ('1', 'true', 'yes', 'on')
|
||||
|
||||
|
||||
def _title_from_message(ns):
|
||||
text = ns.get('message') or ns.get('text') or ''
|
||||
text = str(text).strip().replace('\n', ' ')
|
||||
@ -68,11 +76,15 @@ async def chat_send(ns={}):
|
||||
发送消息并保存多轮对话(需先执行 chat_tables.sql)。
|
||||
|
||||
参数:model, message, stream(默认true), session_id,
|
||||
image_url, image_base64, document_url, document_text
|
||||
image_url, image_base64, document_url, document_text,
|
||||
with_chunks(true时返回上游 SSE 分片列表,便于确认流式)
|
||||
|
||||
说明:本接口(chat_send.dspy)为 JSON 一次性返回。
|
||||
需要浏览器端实时流式请调用 chat_send_stream.dspy(SSE)。
|
||||
"""
|
||||
import json
|
||||
import traceback
|
||||
|
||||
|
||||
# model = ns.get('model')
|
||||
model = 'deepseek-v4-pro'
|
||||
if not model:
|
||||
@ -115,9 +127,7 @@ async def chat_send(ns={}):
|
||||
})
|
||||
|
||||
history = await _load_session_messages(sor, session_id)
|
||||
stream_val = ns.get('stream', True)
|
||||
if isinstance(stream_val, str):
|
||||
stream_val = stream_val.lower() in ('1', 'true', 'yes', 'on')
|
||||
stream_val = _parse_bool(ns.get('stream'), True)
|
||||
chat_result = await path_call('llm_chat_completions.dspy', {
|
||||
'model': model,
|
||||
'messages': history,
|
||||
@ -126,11 +136,14 @@ async def chat_send(ns={}):
|
||||
'api_url': ns.get('api_url'),
|
||||
'api_key': ns.get('api_key'),
|
||||
'model_id': ns.get('model_id'),
|
||||
'with_chunks': ns.get('with_chunks', True),
|
||||
})
|
||||
if not chat_result.get('status'):
|
||||
return chat_result
|
||||
|
||||
reply = chat_result['data']['reply']
|
||||
chunks = chat_result['data'].get('chunks') or []
|
||||
chunk_count = chat_result['data'].get('chunk_count', 0)
|
||||
await sor.C('chat_message', {
|
||||
'id': uuid(),
|
||||
'session_id': session_id,
|
||||
@ -151,6 +164,9 @@ async def chat_send(ns={}):
|
||||
'session_id': session_id,
|
||||
'reply': reply,
|
||||
'model': model,
|
||||
'stream': stream_val,
|
||||
'chunk_count': chunk_count,
|
||||
'chunks': chunks if ns.get('with_chunks', True) else None,
|
||||
},
|
||||
}
|
||||
except Exception:
|
||||
|
||||
311
b/cntoai/chat_send_stream.dspy
Normal file
311
b/cntoai/chat_send_stream.dspy
Normal file
@ -0,0 +1,311 @@
|
||||
def _escape(value):
|
||||
if value is None:
|
||||
return None
|
||||
return str(value).replace("'", "''")
|
||||
|
||||
|
||||
def _parse_bool(value, default=True):
|
||||
if value is None or value == '':
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).lower() in ('1', 'true', 'yes', 'on')
|
||||
|
||||
|
||||
def _title_from_message(ns):
|
||||
text = ns.get('message') or ns.get('text') or ''
|
||||
text = str(text).strip().replace('\n', ' ')
|
||||
if not text:
|
||||
return '新对话'
|
||||
return text[:30] + ('...' if len(text) > 30 else '')
|
||||
|
||||
|
||||
def _build_user_content(ns):
|
||||
text_parts = []
|
||||
if ns.get('message'):
|
||||
text_parts.append(str(ns.get('message')))
|
||||
if ns.get('text'):
|
||||
text_parts.append(str(ns.get('text')))
|
||||
if ns.get('document_text'):
|
||||
text_parts.append(str(ns.get('document_text')))
|
||||
|
||||
parts = []
|
||||
merged_text = '\n'.join([p for p in text_parts if p]).strip()
|
||||
if merged_text:
|
||||
parts.append({'type': 'text', 'text': merged_text})
|
||||
if ns.get('image_url'):
|
||||
parts.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}})
|
||||
if ns.get('image_base64'):
|
||||
mime = ns.get('image_mime') or 'image/jpeg'
|
||||
b64 = ns.get('image_base64')
|
||||
if not str(b64).startswith('data:'):
|
||||
b64 = 'data:%s;base64,%s' % (mime, b64)
|
||||
parts.append({'type': 'image_url', 'image_url': {'url': b64}})
|
||||
if ns.get('document_url'):
|
||||
parts.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}})
|
||||
if not parts:
|
||||
return ''
|
||||
if len(parts) == 1 and parts[0]['type'] == 'text':
|
||||
return parts[0]['text']
|
||||
return parts
|
||||
|
||||
|
||||
async def _load_session_messages(sor, session_id):
    """Load a session's messages, oldest first, as chat-completions messages.

    Rows are read from ``chat_message`` ordered by ``created_at``. Rows whose
    ``content_type`` is 'mixed' have their content JSON-decoded; if decoding
    fails the raw stored value is used instead.

    Returns a list of ``{'role': ..., 'content': ...}`` dicts.

    NOTE(review): the query is built by string interpolation guarded only by
    ``_escape``; switch to bound parameters if ``sqlExe`` supports them.
    """
    import json

    query = """
        SELECT role, content, content_type
        FROM chat_message
        WHERE session_id = '%s'
        ORDER BY created_at ASC;
    """ % _escape(session_id)

    history = []
    for record in await sor.sqlExe(query, {}):
        body = record.get('content') or ''
        if record.get('content_type') == 'mixed':
            try:
                body = json.loads(body)
            except Exception:
                # Best effort: keep the stored value when it is not valid JSON.
                pass
        history.append({'role': record['role'], 'content': body})
    return history
|
||||
|
||||
|
||||
async def _resolve_chat_config(ns, sor):
    """Resolve the upstream chat-completions URL and bearer token.

    URL resolution order:
      1. ``ns['api_url']``
      2. ``model_api_doc.api_url`` for ``ns['model_id']`` (with
         '/chat/completions' appended when missing)
      3. ``params`` row 'cntoai_llm_chat_url'
      4. ``params`` row 'cntoai_domain' + '/llmage/v1/chat/completions'
      5. the built-in default endpoint

    Key resolution order:
      1. ``ns['api_key']``
      2. ``user_api_keys`` for the user, using ``ns['apikey_action']``
         (default 'user_self_create'), then action 'sync'
      3. ``params`` row 'cntoai_llm_api_key'

    Returns:
        ``(api_url, api_key)``; ``api_key`` may be falsy when none is found.
    """
    # Secrets must not live in source control: the endpoint and key are taken
    # from the request or the database, matching the sibling implementation in
    # llm_chat_completions_stream.dspy. The key that was previously embedded
    # here has been published in the repository and should be revoked.
    api_url = ns.get('api_url')
    api_key = ns.get('api_key')

    if not api_url and ns.get('model_id'):
        doc_rows = await sor.sqlExe(
            "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;"
            % _escape(ns.get('model_id')),
            {},
        )
        if doc_rows and doc_rows[0].get('api_url'):
            api_url = doc_rows[0]['api_url']
            if not str(api_url).endswith('/chat/completions'):
                api_url = str(api_url).rstrip('/') + '/chat/completions'

    if not api_url:
        param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'})
        if param_rows:
            api_url = param_rows[0]['pvalue']
        else:
            domain_rows = await sor.R('params', {'pname': 'cntoai_domain'})
            if domain_rows:
                api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions'
            else:
                api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions'

    if not api_key:
        # NOTE(review): ``userid`` from the request takes precedence over the
        # authenticated user; confirm callers cannot use this to read another
        # user's stored key.
        userid = ns.get('userid') or await get_user()
        if userid:
            action = ns.get('apikey_action') or 'user_self_create'
            keys = await sor.R('user_api_keys', {'userid': userid, 'action': action})
            if not keys:
                keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'})
            if keys:
                api_key = keys[0].get('opc_apikey')

    if not api_key:
        key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'})
        if key_rows:
            api_key = key_rows[0]['pvalue']

    return api_url, api_key
|
||||
|
||||
|
||||
def _extract_stream_piece(payload):
|
||||
choice = (payload.get('choices') or [{}])[0]
|
||||
delta = choice.get('delta') or {}
|
||||
message = choice.get('message') or {}
|
||||
piece = (
|
||||
delta.get('content')
|
||||
or delta.get('reasoning_content')
|
||||
or message.get('content')
|
||||
or choice.get('text')
|
||||
or payload.get('content')
|
||||
or ''
|
||||
)
|
||||
if piece is None:
|
||||
return ''
|
||||
return str(piece)
|
||||
|
||||
|
||||
def _sse_event(obj):
|
||||
import json
|
||||
return 'data: %s\n\n' % json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
async def _iter_upstream_stream(api_url, api_key, payload):
    """POST a streaming chat request upstream and yield events as they arrive.

    ``payload`` is copied and sent with ``stream`` forced to True.

    Yields:
        ``{'type': 'content', 'content': str}`` for each text fragment, or a
        single ``{'type': 'error', 'msg': str}`` when the upstream status is
        not 200. Iteration ends at the upstream ``[DONE]`` marker or when the
        body is exhausted.

    If the body ends with leftover data, it is handled as a final ``data:``
    line when it looks like one, otherwise as a plain (non-SSE) JSON
    completion body.
    """
    import aiohttp
    import codecs
    import json

    def _read_sse_line(raw_line):
        """Return ``(done, piece)`` for one line of the upstream stream."""
        line = raw_line.strip()
        # Blank lines, ':' comments and non-data fields carry no text.
        if not line.startswith('data:'):
            return False, ''
        data = line[5:].strip()
        if data == '[DONE]':
            return True, ''
        try:
            return False, _extract_stream_piece(json.loads(data))
        except Exception:
            # Malformed chunks are skipped, keeping the stream best-effort.
            return False, ''

    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer %s' % api_key,
    }
    request_body = dict(payload)
    request_body['stream'] = True

    # An incremental decoder keeps a multi-byte character intact even when a
    # chunk boundary falls in the middle of it.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
    buffer = ''
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=headers, json=request_body) as response:
            if response.status != 200:
                err_text = await response.text()
                yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])}
                return

            async for raw in response.content:
                buffer += decoder.decode(raw)
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    done, piece = _read_sse_line(line)
                    if done:
                        return
                    # The yield sits outside any try block so exceptions
                    # thrown in by the consumer are not swallowed here.
                    if piece:
                        yield {'type': 'content', 'content': piece}

    buffer += decoder.decode(b'', final=True)
    tail = buffer.strip()
    if not tail:
        return

    if tail.startswith('data:'):
        # Final SSE line that arrived without a trailing newline.
        _, piece = _read_sse_line(tail)
        if piece:
            yield {'type': 'content', 'content': piece}
        return

    # Upstream may have ignored stream=true and returned one JSON body.
    try:
        body = json.loads(tail)
        choice = (body.get('choices') or [{}])[0]
        msg = choice.get('message') or {}
        piece = msg.get('content') or choice.get('text') or ''
    except Exception:
        piece = ''
    if piece:
        yield {'type': 'content', 'content': str(piece)}
|
||||
|
||||
|
||||
async def inference_generator(request, params_kw=None, **kw):
    """
    Streaming chat_send: store the user message first, push assistant
    fragments over SSE, then store the complete reply when the stream ends.

    SSE events:
      {"type":"meta","session_id":"...","model":"..."}
      {"type":"content","content":"fragment"}
      {"type":"done","session_id":"...","reply":"full text","model":"..."}
      {"type":"error","msg":"..."}

    Every exit path finishes with a ``data: [DONE]`` frame.
    """
    import json
    import traceback

    done_frame = 'data: [DONE]\n\n'

    def _error_frames(msg):
        # An error event is always followed by the terminating frame.
        return (_sse_event({'type': 'error', 'msg': msg}), done_frame)

    ns = params_kw or {}
    # NOTE(review): the model is pinned here, so the request's ``model``
    # parameter is ignored and the check below can never fire.
    model = 'deepseek-v4-pro'
    if not model:
        for frame in _error_frames('model is required'):
            yield frame
        return

    # NOTE(review): a ``userid`` supplied in the request takes precedence over
    # the authenticated user; confirm this is intended.
    userid = ns.get('userid') or await get_user()
    if not userid:
        for frame in _error_frames('未找到用户'):
            yield frame
        return

    user_content = _build_user_content(ns)
    if not user_content:
        for frame in _error_frames('请输入文本,或提供图片/文档参数'):
            yield frame
        return

    # Multi-part content is persisted as JSON, plain text as-is.
    if isinstance(user_content, list):
        content_type = 'mixed'
        store_content = json.dumps(user_content, ensure_ascii=False)
    else:
        content_type = 'text'
        store_content = str(user_content)

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            session_id = ns.get('session_id')
            if session_id:
                # Continuing a conversation: it must belong to this user.
                owned = await sor.R('chat_session', {'id': session_id, 'userid': userid})
                if not owned:
                    for frame in _error_frames('会话不存在'):
                        yield frame
                    return
            else:
                session_id = uuid()
                await sor.C('chat_session', {
                    'id': session_id,
                    'userid': userid,
                    'model': model,
                    'title': _title_from_message(ns),
                })

            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'user',
                'content': store_content,
                'content_type': content_type,
            })

            # History is loaded after the insert so it includes the new message.
            history = await _load_session_messages(sor, session_id)
            api_url, api_key = await _resolve_chat_config(ns, sor)
            if not api_key:
                for frame in _error_frames('未找到 API Key'):
                    yield frame
                return

            yield _sse_event({
                'type': 'meta',
                'session_id': session_id,
                'model': model,
                'stream': True,
            })

            pieces = []
            upstream = _iter_upstream_stream(api_url, api_key, {
                'model': model,
                'messages': history,
            })
            async for evt in upstream:
                kind = evt.get('type')
                if kind == 'error':
                    yield _sse_event(evt)
                    yield done_frame
                    return
                if kind == 'content':
                    pieces.append(evt['content'])
                    yield _sse_event(evt)

            reply = ''.join(pieces)
            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'assistant',
                'content': reply,
                'content_type': 'text',
            })
            await sor.sqlExe(
                "UPDATE chat_session SET updated_at = NOW() WHERE id = '%s';"
                % _escape(session_id),
                {},
            )

            yield _sse_event({
                'type': 'done',
                'session_id': session_id,
                'reply': reply,
                'model': model,
            })
            yield done_frame
        except Exception:
            # NOTE(review): the full traceback is sent to the client.
            for frame in _error_frames(traceback.format_exc()):
                yield frame
|
||||
|
||||
|
||||
async def inference(request, *args, params_kw=None, **kw):
    """Wrap ``inference_generator`` in a 'text/event-stream' response.

    The generator is bound to ``request``, ``params_kw`` and ``kw`` and
    handed to ``stream_response`` taken from a copy of ``request._run_ns``.
    Extra positional arguments are accepted and ignored.
    """
    import functools

    run_ns = request._run_ns.copy()
    generator_factory = functools.partial(
        inference_generator, request, params_kw=params_kw, **kw
    )
    response = await run_ns.stream_response(
        request, generator_factory, content_type='text/event-stream'
    )
    return response
|
||||
|
||||
|
||||
# Script entry point: run the streaming handler for this request; the result
# is returned by the statement that follows.
# NOTE(review): top-level await/return implies the .dspy runtime wraps this
# file in an async function, with `request` and `params_kw` injected; confirm.
ret = await inference(request, params_kw=params_kw)
|
||||
return ret
|
||||
@ -81,6 +81,7 @@ def build_user_content(ns):
|
||||
|
||||
async def _resolve_chat_config(ns, sor):
|
||||
"""解析 API 地址与 Bearer Token"""
|
||||
|
||||
api_url = 'https://api.deepseek.com/chat/completions'
|
||||
api_key = 'sk-c22d6573e85a4d3fa8ab932386cf2909'
|
||||
|
||||
@ -126,8 +127,26 @@ async def _resolve_chat_config(ns, sor):
|
||||
return api_url, api_key
|
||||
|
||||
|
||||
def _extract_stream_piece(payload):
|
||||
"""从 SSE chunk 中提取文本(兼容 OpenAI / Qwen 等格式)"""
|
||||
choice = (payload.get('choices') or [{}])[0]
|
||||
delta = choice.get('delta') or {}
|
||||
message = choice.get('message') or {}
|
||||
piece = (
|
||||
delta.get('content')
|
||||
or delta.get('reasoning_content')
|
||||
or message.get('content')
|
||||
or choice.get('text')
|
||||
or payload.get('content')
|
||||
or ''
|
||||
)
|
||||
if piece is None:
|
||||
return ''
|
||||
return str(piece)
|
||||
|
||||
|
||||
async def _read_stream_response(response):
|
||||
"""解析 SSE 流式响应,汇总 assistant 文本"""
|
||||
"""解析 SSE 流式响应;若上游未按 SSE 返回则回退解析整段 JSON"""
|
||||
import json
|
||||
chunks = []
|
||||
buffer = ''
|
||||
@ -136,21 +155,38 @@ async def _read_stream_response(response):
|
||||
while '\n' in buffer:
|
||||
line, buffer = buffer.split('\n', 1)
|
||||
line = line.strip()
|
||||
if not line or line.startswith(':'):
|
||||
continue
|
||||
if not line.startswith('data:'):
|
||||
continue
|
||||
data = line[5:].strip()
|
||||
if data == '[DONE]':
|
||||
return ''.join(chunks)
|
||||
return ''.join(chunks), chunks
|
||||
try:
|
||||
payload = json.loads(data)
|
||||
choice = (payload.get('choices') or [{}])[0]
|
||||
delta = choice.get('delta') or {}
|
||||
piece = delta.get('content') or ''
|
||||
piece = _extract_stream_piece(payload)
|
||||
if piece:
|
||||
chunks.append(piece)
|
||||
except Exception:
|
||||
continue
|
||||
return ''.join(chunks)
|
||||
|
||||
reply = ''.join(chunks)
|
||||
if reply:
|
||||
return reply, chunks
|
||||
|
||||
# 上游可能忽略 stream=true,直接返回完整 JSON
|
||||
tail = buffer.strip()
|
||||
if tail:
|
||||
try:
|
||||
body = json.loads(tail)
|
||||
choice = (body.get('choices') or [{}])[0]
|
||||
msg = choice.get('message') or {}
|
||||
reply = msg.get('content') or choice.get('text') or ''
|
||||
if reply:
|
||||
return str(reply), [str(reply)]
|
||||
except Exception:
|
||||
pass
|
||||
return reply, chunks
|
||||
|
||||
|
||||
async def llm_chat_completions(ns={}):
|
||||
@ -204,7 +240,9 @@ async def llm_chat_completions(ns={}):
|
||||
'Authorization': 'Bearer %s' % api_key,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600),
|
||||
) as session:
|
||||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
err_text = await response.text()
|
||||
@ -213,8 +251,9 @@ async def llm_chat_completions(ns={}):
|
||||
'msg': '模型请求失败 HTTP %s: %s' % (response.status, err_text[:500]),
|
||||
}
|
||||
|
||||
stream_chunks = []
|
||||
if stream:
|
||||
reply = await _read_stream_response(response)
|
||||
reply, stream_chunks = await _read_stream_response(response)
|
||||
usage = {}
|
||||
else:
|
||||
body = await response.json()
|
||||
@ -232,6 +271,8 @@ async def llm_chat_completions(ns={}):
|
||||
'messages': messages + [{'role': 'assistant', 'content': reply}],
|
||||
'usage': usage,
|
||||
'stream': stream,
|
||||
'chunk_count': len(stream_chunks),
|
||||
'chunks': stream_chunks if ns.get('with_chunks') else None,
|
||||
},
|
||||
}
|
||||
except Exception:
|
||||
|
||||
241
b/cntoai/llm_chat_completions_stream.dspy
Normal file
241
b/cntoai/llm_chat_completions_stream.dspy
Normal file
@ -0,0 +1,241 @@
|
||||
def _escape(value):
|
||||
if value is None:
|
||||
return None
|
||||
return str(value).replace("'", "''")
|
||||
|
||||
|
||||
def _parse_bool(value, default=True):
|
||||
if value is None or value == '':
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).lower() in ('1', 'true', 'yes', 'on')
|
||||
|
||||
|
||||
def _parse_messages(ns):
|
||||
raw = ns.get('messages')
|
||||
if not raw:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
import json
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except Exception:
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
def build_user_content(ns):
    """Build the user message content from request parameters.

    ``message``, ``text`` and ``document_text`` are joined with newlines
    into one text part. ``image_url``, ``image_base64`` (turned into a data
    URL with ``image_mime`` or 'image/jpeg' unless already a data URL) and
    ``document_url`` add further parts.

    Returns '' if nothing was given, the text itself if text is the only
    part, otherwise the list of part dicts.
    """
    text_chunks = []
    for field in ('message', 'text', 'document_text'):
        value = ns.get(field)
        if value:
            text_chunks.append(str(value))
    combined = '\n'.join(chunk for chunk in text_chunks if chunk).strip()

    content = []
    if combined:
        content.append({'type': 'text', 'text': combined})

    if ns.get('image_url'):
        content.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}})

    inline_image = ns.get('image_base64')
    if inline_image:
        already_data_url = str(inline_image).startswith('data:')
        if not already_data_url:
            inline_image = 'data:%s;base64,%s' % (
                ns.get('image_mime') or 'image/jpeg',
                inline_image,
            )
        content.append({'type': 'image_url', 'image_url': {'url': inline_image}})

    if ns.get('document_url'):
        content.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}})

    if len(content) == 0:
        return ''
    text_only = len(content) == 1 and content[0]['type'] == 'text'
    return content[0]['text'] if text_only else content
|
||||
|
||||
|
||||
async def _resolve_chat_config(ns, sor):
    """Work out which upstream URL and bearer token to use.

    The URL comes from ``ns['api_url']``, else the ``model_api_doc`` row for
    ``ns['model_id']`` (suffixing '/chat/completions' if absent), else the
    'cntoai_llm_chat_url' param, else the 'cntoai_domain' param plus
    '/llmage/v1/chat/completions', else a built-in default.

    The key comes from ``ns['api_key']``, else the user's ``user_api_keys``
    row for ``ns['apikey_action']`` (default 'user_self_create') or action
    'sync', else the 'cntoai_llm_api_key' param.

    Returns ``(api_url, api_key)``; the key may be falsy if none was found.
    """
    api_url = ns.get('api_url')
    api_key = ns.get('api_key')

    if not api_url and ns.get('model_id'):
        lookup_sql = (
            "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;"
            % _escape(ns.get('model_id'))
        )
        doc_rows = await sor.sqlExe(lookup_sql, {})
        if doc_rows and doc_rows[0].get('api_url'):
            api_url = doc_rows[0]['api_url']
            if not str(api_url).endswith('/chat/completions'):
                api_url = str(api_url).rstrip('/') + '/chat/completions'

    if not api_url:
        url_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'})
        if url_rows:
            api_url = url_rows[0]['pvalue']
        else:
            domain_rows = await sor.R('params', {'pname': 'cntoai_domain'})
            if domain_rows:
                base = domain_rows[0]['pvalue'].rstrip('/')
                api_url = base + '/llmage/v1/chat/completions'
            else:
                api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions'

    if not api_key:
        # NOTE(review): a request-supplied ``userid`` wins over the
        # authenticated user here; confirm this is intended.
        userid = ns.get('userid') or await get_user()
        if userid:
            action = ns.get('apikey_action') or 'user_self_create'
            # Fall back to the 'sync' key only when the preferred action has none.
            keys = (
                await sor.R('user_api_keys', {'userid': userid, 'action': action})
                or await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'})
            )
            if keys:
                api_key = keys[0].get('opc_apikey')

    if not api_key:
        key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'})
        if key_rows:
            api_key = key_rows[0]['pvalue']

    return api_url, api_key
|
||||
|
||||
|
||||
def _extract_stream_piece(payload):
|
||||
choice = (payload.get('choices') or [{}])[0]
|
||||
delta = choice.get('delta') or {}
|
||||
message = choice.get('message') or {}
|
||||
piece = (
|
||||
delta.get('content')
|
||||
or delta.get('reasoning_content')
|
||||
or message.get('content')
|
||||
or choice.get('text')
|
||||
or payload.get('content')
|
||||
or ''
|
||||
)
|
||||
if piece is None:
|
||||
return ''
|
||||
return str(piece)
|
||||
|
||||
|
||||
def _sse_event(obj):
|
||||
import json
|
||||
return 'data: %s\n\n' % json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
async def _iter_upstream_stream(api_url, api_key, payload):
    """Send a streaming request upstream and yield text fragment by fragment.

    ``payload`` is copied and sent with ``stream`` forced to True.

    Yields:
        ``{'type': 'content', 'content': str}`` for each text fragment, or a
        single ``{'type': 'error', 'msg': str}`` when the upstream status is
        not 200. Iteration ends at the upstream ``[DONE]`` marker or when the
        body is exhausted.

    If the body ends with leftover data, it is handled as a final ``data:``
    line when it looks like one, otherwise as a plain (non-SSE) JSON
    completion body.
    """
    import aiohttp
    import codecs
    import json

    def _read_sse_line(raw_line):
        """Return ``(done, piece)`` for one line of the upstream stream."""
        line = raw_line.strip()
        # Blank lines, ':' comments and non-data fields carry no text.
        if not line.startswith('data:'):
            return False, ''
        data = line[5:].strip()
        if data == '[DONE]':
            return True, ''
        try:
            return False, _extract_stream_piece(json.loads(data))
        except Exception:
            # Malformed chunks are skipped, keeping the stream best-effort.
            return False, ''

    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer %s' % api_key,
    }
    request_body = dict(payload)
    request_body['stream'] = True

    # An incremental decoder keeps a multi-byte character intact even when a
    # chunk boundary falls in the middle of it.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
    buffer = ''
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=headers, json=request_body) as response:
            if response.status != 200:
                err_text = await response.text()
                yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])}
                return

            async for raw in response.content:
                buffer += decoder.decode(raw)
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    done, piece = _read_sse_line(line)
                    if done:
                        return
                    # The yield sits outside any try block so exceptions
                    # thrown in by the consumer are not swallowed here.
                    if piece:
                        yield {'type': 'content', 'content': piece}

    buffer += decoder.decode(b'', final=True)
    tail = buffer.strip()
    if not tail:
        return

    if tail.startswith('data:'):
        # Final SSE line that arrived without a trailing newline.
        _, piece = _read_sse_line(tail)
        if piece:
            yield {'type': 'content', 'content': piece}
        return

    # Upstream may have ignored stream=true and returned one JSON body.
    try:
        body = json.loads(tail)
        choice = (body.get('choices') or [{}])[0]
        msg = choice.get('message') or {}
        piece = msg.get('content') or choice.get('text') or ''
    except Exception:
        piece = ''
    if piece:
        yield {'type': 'content', 'content': str(piece)}
|
||||
|
||||
|
||||
async def inference_generator(request, params_kw=None, **kw):
    """
    SSE streaming output. Event formats:
      {"type":"meta","model":"..."}
      {"type":"content","content":"fragment"}
      {"type":"done","reply":"full text"}
      {"type":"error","msg":"..."}
    Terminator: data: [DONE]

    The request must name a ``model`` and provide either new user content or
    a ``messages`` history; new content is appended to the history as a
    user message before the upstream call. Nothing is persisted here.
    """
    import traceback

    done_frame = 'data: [DONE]\n\n'

    def _error_frames(msg):
        # An error event is always followed by the terminating frame.
        return (_sse_event({'type': 'error', 'msg': msg}), done_frame)

    ns = params_kw or {}
    model = ns.get('model')
    if not model:
        for frame in _error_frames('model is required'):
            yield frame
        return

    history = _parse_messages(ns)
    user_content = build_user_content(ns)
    if not (user_content or history):
        for frame in _error_frames('message is required'):
            yield frame
        return

    messages = [*history]
    if user_content:
        messages.append({'role': 'user', 'content': user_content})

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            api_url, api_key = await _resolve_chat_config(ns, sor)
            if not api_key:
                for frame in _error_frames('未找到 API Key'):
                    yield frame
                return

            yield _sse_event({'type': 'meta', 'model': model, 'stream': True})

            pieces = []
            upstream = _iter_upstream_stream(api_url, api_key, {
                'model': model,
                'messages': messages,
            })
            async for evt in upstream:
                kind = evt.get('type')
                if kind == 'error':
                    yield _sse_event(evt)
                    yield done_frame
                    return
                if kind == 'content':
                    pieces.append(evt['content'])
                    yield _sse_event(evt)

            reply = ''.join(pieces)
            yield _sse_event({'type': 'done', 'reply': reply, 'model': model})
            yield done_frame
        except Exception:
            # NOTE(review): the full traceback is sent to the client.
            for frame in _error_frames(traceback.format_exc()):
                yield frame
|
||||
|
||||
|
||||
async def inference(request, *args, params_kw=None, **kw):
    """Return a 'text/event-stream' response driven by ``inference_generator``.

    The generator is pre-bound to ``request``, ``params_kw`` and ``kw`` and
    passed to ``stream_response`` from a copy of ``request._run_ns``.
    Extra positional arguments are accepted and ignored.
    """
    import functools

    run_ns = request._run_ns.copy()
    bound_generator = functools.partial(
        inference_generator, request, params_kw=params_kw, **kw
    )
    streamed = await run_ns.stream_response(
        request, bound_generator, content_type='text/event-stream'
    )
    return streamed
|
||||
|
||||
|
||||
# Script entry point: run the streaming handler for this request; the result
# is returned by the statement that follows.
# NOTE(review): top-level await/return implies the .dspy runtime wraps this
# file in an async function, with `request` and `params_kw` injected; confirm.
ret = await inference(request, params_kw=params_kw)
|
||||
return ret
|
||||
247
b/cntoai/test_demo.py
Normal file
247
b/cntoai/test_demo.py
Normal file
@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
chat_send_stream.dspy SSE 流式接口测试
|
||||
|
||||
用法:
|
||||
pip install requests
|
||||
python test_demo.py
|
||||
|
||||
环境变量(可选,覆盖下方默认值):
|
||||
CNTOAI_BASE_URL / CNTOAI_USERID / CNTOAI_API_KEY
|
||||
CNTOAI_MODEL / CNTOAI_LLM_API_URL / CNTOAI_MESSAGE
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
print("请先安装: pip install requests")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Runtime configuration. Each value can be overridden through the
# environment variable named in the corresponding lookup.
BASE_URL = os.environ.get("CNTOAI_BASE_URL", "https://dev.opencomputing.cn").rstrip("/")
USERID = os.environ.get("CNTOAI_USERID", "hSqZuekZ1yKmhKmCN9UAK").strip()
# API keys are secrets: there is deliberately no built-in default. Export
# CNTOAI_API_KEY before running; when it is empty the server resolves a key
# on its own. The key previously committed here should be revoked.
API_KEY = os.environ.get("CNTOAI_API_KEY", "").strip()
API_URL = os.environ.get("CNTOAI_LLM_API_URL", "https://api.deepseek.com/v1/chat/completions").strip()
MODEL = os.environ.get("CNTOAI_MODEL", "deepseek-chat").strip()
MESSAGE = os.environ.get("CNTOAI_MESSAGE", "你好,请用三句话介绍你自己").strip()
# Request timeout in seconds.
TIMEOUT = int(os.environ.get("CNTOAI_TIMEOUT", "300"))

# Path of the streaming endpoint under BASE_URL.
STREAM_PATH = "/cntoai/chat_send_stream.dspy"
|
||||
|
||||
|
||||
def build_payload(session_id: Optional[str] = None, message: Optional[str] = None) -> Dict[str, Any]:
    """Build the JSON body for a chat_send_stream request.

    Uses the module-level MODEL, USERID, API_KEY and API_URL settings;
    ``message`` falls back to MESSAGE when empty. ``session_id`` is only
    included when given, which continues an existing conversation.
    """
    body: Dict[str, Any] = dict(
        model=MODEL,
        message=message or MESSAGE,
        userid=USERID,
        api_key=API_KEY,
        api_url=API_URL,
    )
    if session_id:
        body["session_id"] = session_id
    return body
|
||||
|
||||
|
||||
def parse_sse_text(text: str) -> List[Dict[str, Any]]:
    """Parse a complete SSE body into a list of decoded JSON events.

    Only ``data:`` lines are considered. Parsing stops at the ``[DONE]``
    marker. Lines whose payload is not valid JSON are reported with a
    warning on stdout and skipped.
    """
    parsed_events: List[Dict[str, Any]] = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped.startswith("data:"):
            data = stripped[5:].strip()
            if data == "[DONE]":
                break
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                print(f"[warn] 无法解析: {stripped[:200]}")
            else:
                parsed_events.append(event)
    return parsed_events
|
||||
|
||||
|
||||
def parse_sse_stream(response: "requests.Response") -> Generator[Dict[str, Any], None, None]:
    """
    Parse SSE from a streamed response using a byte buffer.

    Lines are split on the raw bytes before decoding, so a multi-byte UTF-8
    character cut by a network chunk boundary is never decoded in halves.
    (Do not use ``iter_lines(decode_unicode=True)`` for this reason.)

    Yields each ``data:`` payload decoded from JSON, stopping at ``[DONE]``.
    Undecodable bytes are replaced rather than raising, and payloads that
    are not valid JSON are reported on stdout and skipped, so one bad line
    cannot abort the whole stream.
    """
    buffer = b""
    for chunk in response.iter_content(chunk_size=4096):
        if not chunk:
            continue
        buffer += chunk
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            # errors="replace" matches the tail handling below; a strict
            # decode here would crash on a single invalid byte.
            line = line_bytes.decode("utf-8", errors="replace").strip()
            if not line.startswith("data:"):
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                return
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                print(f"\n[warn] JSON 解析失败: {line[:120]}...")
                continue
            yield event

    # A final line may arrive without a trailing newline.
    tail = buffer.strip()
    if tail:
        line = tail.decode("utf-8", errors="replace").strip()
        if line.startswith("data:"):
            data = line[5:].strip()
            if data and data != "[DONE]":
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    return
                yield event
|
||||
|
||||
|
||||
def diagnose_empty_response(resp: "requests.Response") -> None:
    """Print diagnostics for a response with no parseable SSE events.

    Reports the Content-Type, the body length and, when available, the
    first 500 bytes of the body. An empty 'text/html' body additionally
    prints a hint that the .dspy script may be missing its entry statements.

    Safe to call after the stream has been iterated: ``requests`` raises
    ``RuntimeError`` from ``resp.content`` once a streamed body has been
    consumed, in which case the body is reported as empty.
    """
    ctype = resp.headers.get("Content-Type", "")
    try:
        body = resp.content or b""
    except RuntimeError:
        # Streamed body already consumed by the SSE parser.
        body = b""
    print("\n[诊断] 响应体为空或无可解析 SSE")
    print(f" Content-Type : {ctype}")
    print(f" body 长度 : {len(body)}")
    if body:
        print(f" body 前 500B : {body[:500]!r}")
    if "text/html" in ctype and len(body) == 0:
        print("\n 可能原因: chat_send_stream.dspy 未执行 inference 入口。")
        print(" 请确认文件末尾包含:")
        print(" ret = await inference(request, params_kw=params_kw)")
        print(" return ret")
        print(" 并重新部署到 dev 后再测。")
|
||||
|
||||
|
||||
def test_chat_send_stream(session_id: Optional[str] = None, message: Optional[str] = None) -> Optional[str]:
    """Call the streaming endpoint once and print events as they arrive.

    Args:
        session_id: existing conversation to continue, if any.
        message: text to send; the module default is used when omitted.

    Returns:
        The session id reported by the server (or the one passed in), or
        ``None`` when the request could not be made, the status was not 200,
        or the response was not an event stream.
    """
    url = BASE_URL + STREAM_PATH
    payload = build_payload(session_id=session_id, message=message)

    print("=" * 60)
    print("chat_send_stream.dspy 流式测试")
    print(f" URL : {url}")
    print(f" MODEL : {MODEL}")
    print(f" USERID : {USERID}")
    print(f" API_URL : {API_URL}")
    print(f" message : {payload.get('message')}")
    if session_id:
        print(f" session : {session_id}")
    print("=" * 60)

    if not USERID:
        print("错误: 请设置 CNTOAI_USERID")
        return None

    # With stream=True the connection stays checked out until the response
    # is closed; the context manager releases it on every exit path.
    with requests.post(
        url,
        json=payload,
        headers={
            "Accept": "text/event-stream",
            "Content-Type": "application/json",
        },
        stream=True,
        timeout=TIMEOUT,
    ) as resp:
        ctype = resp.headers.get("Content-Type", "")
        print(f"\nHTTP {resp.status_code} Content-Type: {ctype}\n")

        if resp.status_code != 200:
            print(resp.text[:500])
            return None

        if "text/event-stream" not in ctype:
            # Not a stream: read the whole body, then show whatever parses.
            raw = resp.content
            diagnose_empty_response(resp)
            if raw:
                for evt in parse_sse_text(raw.decode("utf-8", errors="ignore")):
                    print("[parsed]", evt)
            return None

        session_out: Optional[str] = session_id
        full_reply: List[str] = []
        has_content = False
        event_count = 0

        print("--- 流式输出 ---")
        for evt in parse_sse_stream(resp):
            event_count += 1
            etype = evt.get("type")

            if etype == "meta":
                session_out = evt.get("session_id") or session_out
                print(f"[meta] session_id={session_out} model={evt.get('model')}")
                continue

            if etype == "content":
                piece = evt.get("content") or ""
                has_content = True
                full_reply.append(piece)
                print(piece, end="", flush=True)
                continue

            if etype == "done":
                session_out = evt.get("session_id") or session_out
                reply = evt.get("reply") or ""
                print(f"\n\n[done] session_id={session_out}")
                print(f"[done] reply 长度={len(reply)}")
                # Print the full reply only if no fragments were streamed.
                if reply and not has_content:
                    print(reply)
                continue

            if etype == "error":
                print(f"\n[error] {evt.get('msg')}")
                return session_out

            print(f"\n[unknown] {evt}")

        print("\n--- 结束 ---")
        if event_count == 0:
            diagnose_empty_response(resp)
        elif full_reply:
            joined = "".join(full_reply)
            print(f"拼接回复({len(joined)}字): {joined[:300]}...")
        return session_out
|
||||
|
||||
|
||||
def main() -> int:
    """Command-line entry point; returns the process exit code.

    Runs one streaming request, optionally a second one in the same session
    (``--twice``). Returns 1 as soon as a round yields no session id,
    otherwise 0.
    """
    if sys.platform == "win32":
        # Best effort: make the Windows console accept UTF-8 output.
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass

    cli = argparse.ArgumentParser(description="chat_send_stream.dspy SSE 测试")
    cli.add_argument("--session-id", help="续聊会话 ID")
    cli.add_argument("--message", "-m", help="覆盖默认 message")
    cli.add_argument("--twice", action="store_true", help="同一会话连发两条")
    options = cli.parse_args()

    first_sid = test_chat_send_stream(
        session_id=options.session_id, message=options.message
    )
    if first_sid is None:
        return 1

    if options.twice and first_sid:
        print("\n" + "=" * 60)
        print("第二轮(多轮续聊)")
        follow_up = options.message or "继续,用一句话总结上面内容"
        second_sid = test_chat_send_stream(session_id=first_sid, message=follow_up)
        if second_sid is None:
            return 1

    if first_sid:
        print(f"\n提示: 续聊 python test_demo.py --session-id {first_sid}")
    return 0
|
||||
|
||||
|
||||
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Loading…
x
Reference in New Issue
Block a user