322 lines
11 KiB
Plaintext
322 lines
11 KiB
Plaintext
def _escape(value):
    """Prepare a value for embedding inside a single-quoted SQL literal.

    ``None`` is passed through unchanged.  Any other value is converted with
    ``str()`` and every single quote in it is doubled.
    """
    if value is not None:
        return str(value).replace("'", "''")
    return None
|
||
|
||
|
||
def _parse_bool(value, default=True):
    """Interpret a loosely typed flag as a boolean.

    Args:
        value: Raw flag; a bool, ``None``, or anything ``str()`` accepts.
        default: Result used when ``value`` is ``None`` or the empty string.

    Returns:
        ``value`` itself when it already is a bool; ``default`` for
        ``None`` / ``''``; otherwise whether the lower-cased string form of
        ``value`` is one of ``1``, ``true``, ``yes`` or ``on``.
    """
    if isinstance(value, bool):
        return value
    if value is None or value == '':
        return default
    truthy_words = ('1', 'true', 'yes', 'on')
    return str(value).lower() in truthy_words
|
||
|
||
|
||
def _title_from_message(ns):
    """Derive a short session title from the request parameters.

    The first truthy value of ``ns['message']`` / ``ns['text']`` is used.
    It is stripped, newlines become spaces, and anything longer than 30
    characters is cut to 30 characters followed by ``...``.  When no text is
    available the default title literal is returned.
    """
    raw = ns.get('message') or ns.get('text') or ''
    title = str(raw).strip().replace('\n', ' ')
    if title == '':
        return '新对话'
    if len(title) > 30:
        return title[:30] + '...'
    return title
|
||
|
||
|
||
# System prompt that inference_generator prepends to every upstream request.
# The Chinese text instructs the model to answer any question about model
# information with a fixed statement that it is a large model provided by the
# platform (approximate translation of the literal below).
MODEL_INFO_PROMPT = '凡是提问模型信息的,都回答:这是开元平台提供的大模型。'
|
||
|
||
|
||
def _build_user_content(ns):
    """Assemble the ``content`` value of the user chat message.

    Text comes from ``message``, ``text`` and ``document_text`` (in that
    order, joined with newlines and stripped).  Attachments come from
    ``image_url``, ``image_base64`` (wrapped into a ``data:`` URL using
    ``image_mime`` or ``image/jpeg`` unless it already is one) and
    ``document_url``.

    Returns:
        ``''`` when nothing was supplied, a plain string when only text was
        supplied, otherwise a list of typed part dicts (``text``,
        ``image_url``, ``file``).
    """
    text_chunks = [
        str(ns.get(key))
        for key in ('message', 'text', 'document_text')
        if ns.get(key)
    ]
    merged_text = '\n'.join(chunk for chunk in text_chunks if chunk).strip()

    content = []
    if merged_text:
        content.append({'type': 'text', 'text': merged_text})

    image_url = ns.get('image_url')
    if image_url:
        content.append({'type': 'image_url', 'image_url': {'url': image_url}})

    encoded = ns.get('image_base64')
    if encoded:
        if not str(encoded).startswith('data:'):
            mime = ns.get('image_mime') or 'image/jpeg'
            encoded = 'data:%s;base64,%s' % (mime, encoded)
        content.append({'type': 'image_url', 'image_url': {'url': encoded}})

    document_url = ns.get('document_url')
    if document_url:
        content.append({'type': 'file', 'file': {'file_url': document_url}})

    if not content:
        return ''
    only_text = len(content) == 1 and content[0]['type'] == 'text'
    return content[0]['text'] if only_text else content
|
||
|
||
|
||
async def _load_session_messages(sor, session_id):
    """Load the stored messages of a session in chronological order.

    Rows are read from ``chat_message`` ordered by ``created_at``.  Content
    stored with ``content_type == 'mixed'`` is decoded from JSON; if decoding
    fails the raw stored value is kept.

    Args:
        sor: Database handle exposing an awaitable ``sqlExe(sql, params)``.
        session_id: Session whose messages are loaded.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.
    """
    import json

    query = """
        SELECT role, content, content_type
        FROM chat_message
        WHERE session_id = '%s'
        ORDER BY created_at ASC;
    """ % _escape(session_id)
    rows = await sor.sqlExe(query, {})

    def _decode(row):
        body = row.get('content') or ''
        if row.get('content_type') != 'mixed':
            return body
        try:
            return json.loads(body)
        except Exception:
            # Best effort: keep the raw stored value when it is not JSON.
            return body

    return [{'role': row['role'], 'content': _decode(row)} for row in rows]
|
||
|
||
|
||
async def _resolve_chat_config(ns, sor):
    """Resolve the upstream chat-completions endpoint and API key.

    Neither value is hard-coded and neither is taken directly from ``ns``:
    both come from server-side configuration, so a caller cannot redirect a
    stored key to an endpoint of its choosing.

    NOTE(review): an API key used to be embedded in this function.  It should
    be treated as leaked and rotated; configure the replacement through the
    ``params`` / ``user_api_keys`` tables read below.

    Endpoint resolution order:
      1. ``model_api_doc.api_url`` for ``ns['model_id']`` (``/chat/completions``
         is appended when the stored URL does not already end with it);
      2. ``params`` row ``cntoai_llm_chat_url``;
      3. ``params`` row ``cntoai_domain`` + ``/llmage/v1/chat/completions``;
      4. the built-in default endpoint.

    Key resolution order:
      1. ``user_api_keys`` row for the user with action ``ns['apikey_action']``
         (default ``user_self_create``);
      2. ``user_api_keys`` row for the user with action ``sync``;
      3. ``params`` row ``cntoai_llm_api_key``.

    Args:
        ns: Request parameters; ``model_id``, ``userid`` and ``apikey_action``
            are read.  When ``userid`` is missing ``get_user()`` is awaited.
        sor: Database handle exposing awaitable ``sqlExe(sql, params)`` and
            ``R(table, filters)``.

    Returns:
        ``(api_url, api_key)``.  ``api_key`` is ``None`` when no key is
        configured anywhere; the caller is expected to handle that.
    """
    api_url = None
    api_key = None

    # --- endpoint -------------------------------------------------------
    model_id = ns.get('model_id')
    if model_id:
        doc_rows = await sor.sqlExe(
            "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;"
            % _escape(model_id),
            {},
        )
        if doc_rows and doc_rows[0].get('api_url'):
            api_url = str(doc_rows[0]['api_url'])
            if not api_url.endswith('/chat/completions'):
                api_url = api_url.rstrip('/') + '/chat/completions'

    if not api_url:
        param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'})
        if param_rows:
            api_url = param_rows[0]['pvalue']
        else:
            domain_rows = await sor.R('params', {'pname': 'cntoai_domain'})
            if domain_rows:
                api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions'
            else:
                api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions'

    # --- credential -----------------------------------------------------
    userid = ns.get('userid') or await get_user()
    if userid:
        action = ns.get('apikey_action') or 'user_self_create'
        keys = await sor.R('user_api_keys', {'userid': userid, 'action': action})
        if not keys:
            keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'})
        if keys:
            api_key = keys[0].get('opc_apikey')

    if not api_key:
        # Last resort: the platform-wide key stored in the params table.
        key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'})
        if key_rows:
            api_key = key_rows[0]['pvalue']

    return api_url, api_key
|
||
|
||
|
||
def _extract_stream_piece(payload):
    """Pull the text fragment out of one decoded upstream chunk.

    Only the first entry of ``payload['choices']`` is inspected.  The first
    truthy value among, in order, ``delta.content``,
    ``delta.reasoning_content``, ``message.content``, ``choice.text`` and the
    top-level ``payload['content']`` is returned as a string; ``''`` is
    returned when none of them is set.
    """
    choices = payload.get('choices') or [{}]
    first = choices[0]
    delta = first.get('delta') or {}
    message = first.get('message') or {}

    # (container, key) pairs in priority order; evaluated lazily so later
    # containers are never touched once a fragment has been found.
    lookups = (
        (delta, 'content'),
        (delta, 'reasoning_content'),
        (message, 'content'),
        (first, 'text'),
        (payload, 'content'),
    )
    for source, key in lookups:
        found = source.get(key)
        if found:
            return str(found)
    return ''
|
||
|
||
|
||
def _sse_event(obj):
    """Serialise ``obj`` as one Server-Sent Events ``data:`` frame.

    The JSON is produced with ``ensure_ascii=False`` and the frame is
    terminated by a blank line.
    """
    import json

    encoded = json.dumps(obj, ensure_ascii=False)
    return 'data: ' + encoded + '\n\n'
|
||
|
||
|
||
async def _iter_upstream_stream(api_url, api_key, payload):
    """POST a chat request upstream and yield its reply as event dicts.

    The request is always sent with ``stream=True``.  The response is parsed
    line by line as Server-Sent Events: each ``data:`` line is decoded as
    JSON and the text found by ``_extract_stream_piece`` is emitted; every
    other line (blank lines, ``:`` comments, other fields) is skipped and
    ``data: [DONE]`` ends the stream.

    Text left over after the response ends without a trailing newline is
    handled as well: either as a final ``data:`` line, or as a plain
    (non-streamed) JSON completion body.

    Args:
        api_url: Chat-completions endpoint to POST to.
        api_key: Bearer token sent in the ``Authorization`` header.
        payload: Request body; it is copied, the caller's dict is not changed.

    Yields:
        ``{'type': 'content', 'content': str}`` for every non-empty fragment,
        or a single ``{'type': 'error', 'msg': str}`` when the upstream status
        is not 200, after which the generator stops.
    """
    import aiohttp
    import json

    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer %s' % api_key,
    }
    request_body = dict(payload)
    request_body['stream'] = True

    def _piece_from_data(data):
        # Best effort: one malformed or unexpected frame must not abort the
        # whole stream, so it simply contributes no text.
        try:
            return _extract_stream_piece(json.loads(data))
        except Exception:
            return ''

    buffer = ''
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=headers, json=request_body) as response:
            if response.status != 200:
                err_text = await response.text()
                yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])}
                return

            async for raw in response.content:
                buffer += raw.decode('utf-8', errors='ignore')
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    line = line.strip()
                    # Blank lines and ':' comments never start with 'data:',
                    # so this single test covers every line to skip.
                    if not line.startswith('data:'):
                        continue
                    data = line[5:].strip()
                    if data == '[DONE]':
                        return
                    piece = _piece_from_data(data)
                    if piece:
                        yield {'type': 'content', 'content': piece}

    tail = buffer.strip()
    if not tail:
        return

    if tail.startswith('data:'):
        # Final SSE line that arrived without its terminating newline;
        # previously it was fed to json.loads with the prefix and lost.
        data = tail[5:].strip()
        if data and data != '[DONE]':
            piece = _piece_from_data(data)
            if piece:
                yield {'type': 'content', 'content': piece}
        return

    # Otherwise treat the leftover as a non-streamed JSON completion body.
    try:
        body = json.loads(tail)
        choice = (body.get('choices') or [{}])[0]
        message = choice.get('message') or {}
        piece = message.get('content') or choice.get('text') or ''
    except Exception:
        piece = ''
    if piece:
        yield {'type': 'content', 'content': str(piece)}
|
||
|
||
|
||
async def inference_generator(request, params_kw=None, **kw):
    """Streaming chat_send: store the user message, stream the reply, persist it.

    The user message is written first, assistant fragments are pushed to the
    client as SSE frames while they arrive, and the complete reply is stored
    once the upstream stream has finished.  Every stream, successful or not,
    ends with a ``data: [DONE]`` frame.

    SSE events:
        {"type":"meta","session_id":"...","model":"...","stream":true}
        {"type":"content","content":"fragment"}
        {"type":"done","session_id":"...","reply":"full text","model":"..."}
        {"type":"error","msg":"..."}

    Args:
        request: Incoming request object (not read here).
        params_kw: Request parameters; ``userid``, ``session_id`` and the
            content fields consumed by ``_build_user_content`` are read.
        **kw: Accepted for call compatibility; unused.

    Yields:
        SSE frames as strings.
    """
    import json
    import logging

    def _failure(msg):
        # Every failure path ends the stream the same way: one error event
        # followed by the [DONE] sentinel.
        return (_sse_event({'type': 'error', 'msg': msg}), 'data: [DONE]\n\n')

    ns = params_kw or {}
    # NOTE(review): the model is pinned here and ns['model'] is ignored, which
    # leaves the guard below unreachable until the parameter is honoured again.
    model = 'deepseek-v4-pro'

    if not model:
        for frame in _failure('model is required'):
            yield frame
        return

    userid = ns.get('userid') or await get_user()
    if not userid:
        for frame in _failure('未找到用户'):
            yield frame
        return

    user_content = _build_user_content(ns)
    if not user_content:
        for frame in _failure('请输入文本,或提供图片/文档参数'):
            yield frame
        return

    # Multi-part content (a list) is stored as JSON and tagged 'mixed' so it
    # can be decoded again when the history is loaded.
    content_type = 'mixed' if isinstance(user_content, list) else 'text'
    if content_type == 'mixed':
        store_content = json.dumps(user_content, ensure_ascii=False)
    else:
        store_content = str(user_content)

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            session_id = ns.get('session_id')
            if not session_id:
                session_id = uuid()
                await sor.C('chat_session', {
                    'id': session_id,
                    'userid': userid,
                    'model': model,
                    'title': _title_from_message(ns),
                })
            else:
                # An existing session must belong to the requesting user.
                sessions = await sor.R('chat_session', {'id': session_id, 'userid': userid})
                if not sessions:
                    for frame in _failure('会话不存在'):
                        yield frame
                    return

            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'user',
                'content': store_content,
                'content_type': content_type,
            })

            # Loaded after the insert so the history includes the new message.
            history = await _load_session_messages(sor, session_id)
            api_url, api_key = await _resolve_chat_config(ns, sor)
            if not api_key:
                for frame in _failure('未找到 API Key'):
                    yield frame
                return

            yield _sse_event({
                'type': 'meta',
                'session_id': session_id,
                'model': model,
                'stream': True,
            })

            parts = []
            upstream_request = {
                'model': model,
                'messages': [{'role': 'system', 'content': MODEL_INFO_PROMPT}] + history,
            }
            async for evt in _iter_upstream_stream(api_url, api_key, upstream_request):
                if evt.get('type') == 'error':
                    yield _sse_event(evt)
                    yield 'data: [DONE]\n\n'
                    return
                if evt.get('type') == 'content':
                    parts.append(evt['content'])
                    yield _sse_event(evt)

            reply = ''.join(parts)
            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'assistant',
                'content': reply,
                'content_type': 'text',
            })
            await sor.sqlExe(
                "UPDATE chat_session SET updated_at = NOW() WHERE id = '%s';"
                % _escape(session_id),
                {},
            )

            yield _sse_event({
                'type': 'done',
                'session_id': session_id,
                'reply': reply,
                'model': model,
            })
            yield 'data: [DONE]\n\n'
        except Exception:
            # Keep the stack trace on the server; sending it to the client
            # would disclose internals such as SQL text and file paths.
            logging.getLogger('inference_generator').exception('chat stream failed')
            # Message text: "Internal service error, please try again later".
            for frame in _failure('服务内部错误,请稍后重试'):
                yield frame
|
||
|
||
|
||
async def inference(request, *args, params_kw=None, **kw):
    """Serve ``inference_generator`` as a ``text/event-stream`` response.

    A copy of ``request._run_ns`` provides ``stream_response``, which is
    handed the request and a zero-argument factory for the generator.
    Positional ``args`` are accepted but not used.
    """
    from functools import partial

    run_ns = request._run_ns.copy()
    stream_factory = partial(inference_generator, request, params_kw=params_kw, **kw)
    response = await run_ns.stream_response(
        request,
        stream_factory,
        content_type='text/event-stream',
    )
    return response
|
||
|
||
|
||
ret = await inference(request, params_kw=params_kw)
|
||
return ret
|