kboss/b/cntoai/chat_send_stream.dspy
2026-06-12 09:43:11 +08:00

322 lines
11 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

def _escape(value):
    """Return ``value`` as text with every single quote doubled.

    ``None`` is passed through unchanged. Any other value is converted with
    ``str()`` first. Only the single-quote character is rewritten; no other
    character is touched.
    """
    return None if value is None else str(value).replace("'", "''")
def _parse_bool(value, default=True):
    """Interpret a loosely typed flag (e.g. a request parameter) as a boolean.

    Args:
        value: The raw flag. Real booleans are returned as-is; everything else
            is compared by its stripped, lower-cased text form.
        default: Returned when ``value`` is ``None``, empty, or only
            whitespace.

    Returns:
        ``True`` for ``1``, ``true``, ``yes`` or ``on`` (case-insensitive,
        surrounding whitespace ignored); ``default`` for missing input;
        ``False`` for anything else.
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    # Strip before matching so values such as " true " or "yes\n" are still
    # recognised, and whitespace-only input counts as "not provided".
    text = str(value).strip().lower()
    if not text:
        return default
    return text in ('1', 'true', 'yes', 'on')
def _title_from_message(ns):
    """Derive a short, single-line session title from the request text.

    Args:
        ns: Request parameters; ``message`` is preferred, then ``text``.

    Returns:
        The text with every run of whitespace (spaces, tabs, ``\\r``, ``\\n``)
        collapsed to one space, cut to 30 characters with a trailing ``...``
        when longer. ``'新对话'`` is returned when there is no text.
    """
    raw = ns.get('message') or ns.get('text') or ''
    # split()/join flattens *all* whitespace, so carriage returns and tabs do
    # not leak into the title the way a bare "\n" replacement would allow.
    title = ' '.join(str(raw).split())
    if not title:
        return '新对话'
    if len(title) > 30:
        return title[:30] + '...'
    return title
# System prompt sent as the first message of the upstream chat request, ahead
# of the stored conversation history. English gloss of the text: "Whenever
# asked about model information, always answer: this is a large model
# provided by the Kaiyuan platform."
MODEL_INFO_PROMPT = '凡是提问模型信息的,都回答:这是开元平台提供的大模型。'
def _build_user_content(ns):
    """Assemble the user message content from the request parameters.

    Text comes from ``message``, ``text`` and ``document_text`` (joined with
    newlines, in that order). Attachments come from ``image_url``,
    ``image_base64`` (wrapped into a ``data:`` URL using ``image_mime`` or
    ``image/jpeg`` unless it already is one) and ``document_url``.

    Returns:
        ``''`` when nothing usable was supplied, a plain string when the
        content is text only, otherwise a list of typed content parts.
    """
    texts = [
        str(ns.get(key))
        for key in ('message', 'text', 'document_text')
        if ns.get(key)
    ]
    body = '\n'.join(filter(None, texts)).strip()

    content = []
    if body:
        content.append({'type': 'text', 'text': body})

    image_url = ns.get('image_url')
    if image_url:
        content.append({'type': 'image_url', 'image_url': {'url': image_url}})

    encoded = ns.get('image_base64')
    if encoded:
        if not str(encoded).startswith('data:'):
            mime_type = ns.get('image_mime') or 'image/jpeg'
            encoded = 'data:%s;base64,%s' % (mime_type, encoded)
        content.append({'type': 'image_url', 'image_url': {'url': encoded}})

    document_url = ns.get('document_url')
    if document_url:
        content.append({'type': 'file', 'file': {'file_url': document_url}})

    if not content:
        return ''
    first = content[0]
    if len(content) == 1 and first['type'] == 'text':
        # Text-only input is sent as a bare string rather than a parts list.
        return first['text']
    return content
async def _load_session_messages(sor, session_id):
    """Load a session's stored messages, ordered by ``created_at`` ascending.

    Args:
        sor: Database handle exposing an awaitable ``sqlExe(sql, params)``.
        session_id: Identifier of the chat session whose rows are read.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts. Rows whose
        ``content_type`` is ``'mixed'`` have their content decoded from JSON;
        when decoding fails the stored text is used unchanged.
    """
    import json

    # NOTE(review): the statement is built by string interpolation with
    # single quotes doubled by _escape; if the database layer supports bound
    # parameters, prefer them here — confirm against its documentation.
    sql = """
        SELECT role, content, content_type
        FROM chat_message
        WHERE session_id = '%s'
        ORDER BY created_at ASC;
    """ % _escape(session_id)
    rows = await sor.sqlExe(sql, {})

    messages = []
    for row in rows or []:
        content = row.get('content') or ''
        if row.get('content_type') == 'mixed':
            try:
                content = json.loads(content)
            except (TypeError, ValueError, RecursionError):
                # Best effort: a row that is not valid JSON is passed through
                # as the raw stored text instead of aborting the whole load.
                pass
        messages.append({'role': row['role'], 'content': content})
    return messages
async def _resolve_chat_config(ns, sor):
    """Resolve the upstream chat-completions URL and API key for a request.

    No credential or endpoint is embedded in this function; everything comes
    from the request or from the database.

    URL resolution order:
        1. ``ns['api_url']``.
        2. ``model_api_doc.api_url`` for ``ns['model_id']`` (with
           ``/chat/completions`` appended when missing).
        3. The ``cntoai_llm_chat_url`` row of ``params``.
        4. The ``cntoai_domain`` row of ``params`` plus
           ``/llmage/v1/chat/completions``.
        5. ``https://ai.atvoe.com/llmage/v1/chat/completions``.

    Key resolution order:
        1. ``ns['api_key']``.
        2. ``user_api_keys.opc_apikey`` of the user for the action
           ``ns['apikey_action']`` (default ``user_self_create``), then for
           the action ``sync``.
        3. The ``cntoai_llm_api_key`` row of ``params``.

    Args:
        ns: Request parameters.
        sor: Database handle exposing awaitable ``sqlExe`` and ``R``.

    Returns:
        ``(api_url, api_key)``; ``api_key`` may be ``None`` when no source
        provides one.
    """
    # NOTE(review): a provider API key used to be committed in this function.
    # Treat that key as compromised and rotate it; secrets belong in the
    # ``params`` / ``user_api_keys`` tables, never in source control.
    # NOTE(review): a caller-supplied ``api_url`` combined with a key looked
    # up server-side would send that key to an arbitrary host — confirm the
    # request parameters are trusted or restrict the accepted URLs.
    api_url = ns.get('api_url')
    api_key = ns.get('api_key')

    model_id = ns.get('model_id')
    if not api_url and model_id:
        doc_rows = await sor.sqlExe(
            "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;"
            % _escape(model_id),
            {},
        )
        if doc_rows and doc_rows[0].get('api_url'):
            api_url = str(doc_rows[0]['api_url'])
            if not api_url.endswith('/chat/completions'):
                api_url = api_url.rstrip('/') + '/chat/completions'

    if not api_url:
        param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'})
        if param_rows:
            api_url = param_rows[0]['pvalue']
        else:
            domain_rows = await sor.R('params', {'pname': 'cntoai_domain'})
            if domain_rows:
                api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions'
            else:
                api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions'

    if not api_key:
        userid = ns.get('userid') or await get_user()
        if userid:
            action = ns.get('apikey_action') or 'user_self_create'
            keys = await sor.R('user_api_keys', {'userid': userid, 'action': action})
            if not keys:
                keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'})
            if keys:
                api_key = keys[0].get('opc_apikey')

    if not api_key:
        key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'})
        if key_rows:
            api_key = key_rows[0]['pvalue']

    return api_url, api_key
def _extract_stream_piece(payload):
    """Pull the text fragment out of one decoded upstream JSON object.

    The first non-empty value among these fields is used, in order:
    ``choices[0].delta.content``, ``choices[0].delta.reasoning_content``,
    ``choices[0].message.content``, ``choices[0].text`` and the top-level
    ``content``.

    Args:
        payload: The decoded JSON value of one ``data:`` line.

    Returns:
        The fragment as ``str``, or ``''`` when the payload carries no text
        or does not have the expected object shape.
    """
    def _mapping(value):
        # Unexpected shapes (None, strings, lists) are treated as "no data"
        # instead of raising AttributeError on ``.get``.
        return value if isinstance(value, dict) else {}

    if not isinstance(payload, dict):
        return ''

    choices = payload.get('choices')
    first = choices[0] if isinstance(choices, (list, tuple)) and choices else None
    choice = _mapping(first)
    delta = _mapping(choice.get('delta'))
    message = _mapping(choice.get('message'))

    piece = (
        delta.get('content')
        or delta.get('reasoning_content')
        or message.get('content')
        or choice.get('text')
        or payload.get('content')
        or ''
    )
    # The trailing ``or ''`` guarantees piece is never None at this point.
    return str(piece)
def _sse_event(obj):
    """Serialise ``obj`` as one Server-Sent Events ``data:`` frame.

    The object is JSON-encoded with non-ASCII characters left unescaped and
    the frame is terminated by a blank line.
    """
    import json

    encoded = json.dumps(obj, ensure_ascii=False)
    return 'data: %s\n\n' % encoded
async def _iter_upstream_stream(api_url, api_key, payload):
    """POST a streaming chat request upstream and yield normalised events.

    A copy of ``payload`` with ``stream`` forced to ``True`` is sent as JSON
    with a bearer ``Authorization`` header and a 600 second total timeout.

    Yields:
        ``{'type': 'error', 'msg': ...}`` once (then stops) when the response
        status is not 200; the message holds the status and the first 500
        characters of the body.
        ``{'type': 'content', 'content': ...}`` for every ``data:`` line that
        carries a non-empty text fragment.

    Lines that do not start with ``data:`` are skipped, ``data: [DONE]`` ends
    the generator, and any error while decoding a line is ignored. Text left
    over after the stream ends (no trailing newline) is parsed as one
    complete JSON completion body.
    """
    import aiohttp
    import json

    request_body = dict(payload)
    request_body['stream'] = True
    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer %s' % api_key,
    }
    pending = ''
    timeout = aiohttp.ClientTimeout(total=600)

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=request_headers, json=request_body) as response:
            if response.status != 200:
                err_text = await response.text()
                yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])}
                return

            async for chunk in response.content:
                pending += chunk.decode('utf-8', errors='ignore')
                while '\n' in pending:
                    line, _, pending = pending.partition('\n')
                    line = line.strip()
                    # Blank lines, ":" comments and non-data fields all fail
                    # this single prefix test.
                    if not line.startswith('data:'):
                        continue
                    data = line[5:].strip()
                    if data == '[DONE]':
                        return
                    try:
                        fragment = _extract_stream_piece(json.loads(data))
                        if fragment:
                            yield {'type': 'content', 'content': fragment}
                    except Exception:
                        # Undecodable or oddly shaped lines are dropped.
                        pass

            leftover = pending.strip()
            if leftover:
                try:
                    body = json.loads(leftover)
                    choice = (body.get('choices') or [{}])[0]
                    message = choice.get('message') or {}
                    fragment = message.get('content') or choice.get('text') or ''
                    if fragment:
                        yield {'type': 'content', 'content': str(fragment)}
                except Exception:
                    pass
async def inference_generator(request, params_kw=None, **kw):
    """Stream one chat turn to the client as Server-Sent Events.

    The user message is stored first, assistant fragments are pushed to the
    client as they arrive, and the complete reply is stored once the upstream
    stream has ended.

    SSE events (each one JSON object on a ``data:`` line):
        {"type":"meta","session_id":"...","model":"...","stream":true}
        {"type":"content","content":"fragment"}
        {"type":"done","session_id":"...","reply":"full text","model":"..."}
        {"type":"error","msg":"..."}
    Every exit path ends with a literal ``data: [DONE]`` line.

    Args:
        request: The incoming request object (not read by this function).
        params_kw: Request parameters. Reads ``model``, ``userid``,
            ``session_id`` and the content fields consumed by
            ``_build_user_content``; also passed on for endpoint resolution.
        **kw: Accepted for interface compatibility; unused.
    """
    import json
    import logging

    done_marker = 'data: [DONE]\n\n'
    ns = params_kw or {}
    # Honour the model named in the request; fall back to the model this
    # endpoint used when none is supplied so such callers keep working.
    model = ns.get('model') or 'deepseek-v4-pro'

    userid = ns.get('userid') or await get_user()
    if not userid:
        yield _sse_event({'type': 'error', 'msg': '未找到用户'})
        yield done_marker
        return

    user_content = _build_user_content(ns)
    if not user_content:
        yield _sse_event({'type': 'error', 'msg': '请输入文本,或提供图片/文档参数'})
        yield done_marker
        return

    # Multi-part content (a list) is stored as JSON and tagged 'mixed'.
    content_type = 'mixed' if isinstance(user_content, list) else 'text'
    if content_type == 'mixed':
        store_content = json.dumps(user_content, ensure_ascii=False)
    else:
        store_content = str(user_content)

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            session_id = ns.get('session_id')
            if not session_id:
                # No session given: open a new one titled from the message.
                session_id = uuid()
                await sor.C('chat_session', {
                    'id': session_id,
                    'userid': userid,
                    'model': model,
                    'title': _title_from_message(ns),
                })
            else:
                # The lookup includes userid, so another user's session id is
                # reported as not existing.
                sessions = await sor.R('chat_session', {'id': session_id, 'userid': userid})
                if not sessions:
                    yield _sse_event({'type': 'error', 'msg': '会话不存在'})
                    yield done_marker
                    return

            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'user',
                'content': store_content,
                'content_type': content_type,
            })
            # Loaded after the insert above, so the history sent upstream is
            # read back from the stored rows of this session.
            history = await _load_session_messages(sor, session_id)

            api_url, api_key = await _resolve_chat_config(ns, sor)
            if not api_key:
                yield _sse_event({'type': 'error', 'msg': '未找到 API Key'})
                yield done_marker
                return

            yield _sse_event({
                'type': 'meta',
                'session_id': session_id,
                'model': model,
                'stream': True,
            })

            parts = []
            upstream_request = {
                'model': model,
                'messages': [{'role': 'system', 'content': MODEL_INFO_PROMPT}] + history,
            }
            async for evt in _iter_upstream_stream(api_url, api_key, upstream_request):
                kind = evt.get('type')
                if kind == 'error':
                    yield _sse_event(evt)
                    yield done_marker
                    return
                if kind == 'content':
                    parts.append(evt['content'])
                    yield _sse_event(evt)

            reply = ''.join(parts)
            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'assistant',
                'content': reply,
                'content_type': 'text',
            })
            await sor.sqlExe(
                "UPDATE chat_session SET updated_at = NOW() WHERE id = '%s';"
                % _escape(session_id),
                {},
            )
            yield _sse_event({
                'type': 'done',
                'session_id': session_id,
                'reply': reply,
                'model': model,
            })
            yield done_marker
        except Exception:
            # Keep the stack trace on the server: it can expose SQL text, file
            # paths and configuration. The client only gets a generic message.
            logging.getLogger('cntoai.chat_send_stream').exception(
                'chat_send_stream failed'
            )
            yield _sse_event({'type': 'error', 'msg': '服务异常,请稍后重试'})
            yield done_marker
async def inference(request, *args, params_kw=None, **kw):
    """Return a ``text/event-stream`` response fed by ``inference_generator``.

    The generator is bound to ``request``, ``params_kw`` and ``**kw`` and
    handed to ``stream_response`` taken from a copy of ``request._run_ns``.
    Extra positional arguments are accepted and ignored.
    """
    import functools

    run_ns = request._run_ns.copy()
    event_source = functools.partial(
        inference_generator,
        request,
        params_kw=params_kw,
        **kw
    )
    return await run_ns.stream_response(
        request,
        event_source,
        content_type='text/event-stream',
    )
# Script entry point: these statements sit outside every function defined in
# this file. They build the SSE response for the current request and return it.
# NOTE(review): `request` and `params_kw` are not defined in this file —
# presumably injected by the host that executes this .dspy script, which must
# also allow top-level `await`/`return`; confirm against the framework.
ret = await inference(request, params_kw=params_kw)
return ret