178 lines
6.0 KiB
Plaintext
178 lines
6.0 KiB
Plaintext
def _escape(value):
|
||
if value is None:
|
||
return None
|
||
return str(value).replace("'", "''")
|
||
|
||
|
||
def _parse_bool(value, default=True):
|
||
if value is None or value == '':
|
||
return default
|
||
if isinstance(value, bool):
|
||
return value
|
||
return str(value).lower() in ('1', 'true', 'yes', 'on')
|
||
|
||
|
||
def _title_from_message(ns):
|
||
text = ns.get('message') or ns.get('text') or ''
|
||
text = str(text).strip().replace('\n', ' ')
|
||
if not text:
|
||
return '新对话'
|
||
return text[:30] + ('...' if len(text) > 30 else '')
|
||
|
||
|
||
def _build_user_content(ns):
|
||
text_parts = []
|
||
if ns.get('message'):
|
||
text_parts.append(str(ns.get('message')))
|
||
if ns.get('text'):
|
||
text_parts.append(str(ns.get('text')))
|
||
if ns.get('document_text'):
|
||
text_parts.append(str(ns.get('document_text')))
|
||
|
||
parts = []
|
||
merged_text = '\n'.join([p for p in text_parts if p]).strip()
|
||
if merged_text:
|
||
parts.append({'type': 'text', 'text': merged_text})
|
||
if ns.get('image_url'):
|
||
parts.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}})
|
||
if ns.get('image_base64'):
|
||
mime = ns.get('image_mime') or 'image/jpeg'
|
||
b64 = ns.get('image_base64')
|
||
if not str(b64).startswith('data:'):
|
||
b64 = 'data:%s;base64,%s' % (mime, b64)
|
||
parts.append({'type': 'image_url', 'image_url': {'url': b64}})
|
||
if ns.get('document_url'):
|
||
parts.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}})
|
||
if not parts:
|
||
return ''
|
||
if len(parts) == 1 and parts[0]['type'] == 'text':
|
||
return parts[0]['text']
|
||
return parts
|
||
|
||
|
||
async def _load_session_messages(sor, session_id):
    """Load a session's stored messages as a chat ``messages`` list.

    Args:
        sor: open database handle; only its ``sqlExe`` coroutine is used.
        session_id: id whose ``chat_message`` rows are read, ordered by
            ``created_at`` ascending.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.  Rows whose
        ``content_type`` is ``'mixed'`` are JSON-decoded back into their
        list-of-parts form; when decoding fails the raw stored value is kept.
        An empty/missing ``content`` becomes ``''``.
    """
    import json

    # session_id is embedded as a quoted literal after _escape() doubles single
    # quotes.  NOTE(review): prefer bound parameters if this helper is ever
    # called with an id that has not already been validated by the caller.
    # NOTE(review): rows sharing the same created_at value have no defined
    # relative order -- confirm the column's precision is fine enough.
    sql = """
        SELECT role, content, content_type
        FROM chat_message
        WHERE session_id = '%s'
        ORDER BY created_at ASC;
    """ % _escape(session_id)
    rows = await sor.sqlExe(sql, {}) or []

    messages = []
    for row in rows:
        content = row.get('content') or ''
        if row.get('content_type') == 'mixed':
            try:
                content = json.loads(content)
            except (TypeError, ValueError):
                # Best effort: keep the stored value if it is not valid JSON
                # (json.JSONDecodeError subclasses ValueError; TypeError covers
                # values that are not str/bytes).
                pass
        messages.append({'role': row['role'], 'content': content})
    return messages
|
||
|
||
|
||
async def chat_send(ns=None):
    """Send one user turn to the LLM and persist the multi-turn conversation.

    Requires the tables created by ``chat_tables.sql`` (``chat_session`` and
    ``chat_message``).

    Parameters (read from *ns*):
        model: currently ignored -- the model is pinned to ``'deepseek-v4-pro'``.
        message, text, document_text: user text (concatenated).
        image_url, image_base64 (+ image_mime), document_url: attachments.
        stream: forwarded to ``llm_chat_completions.dspy``; default true.
        session_id: continue an existing session owned by the user; a new
            session is created when omitted.
        with_chunks: when true (default) the upstream SSE chunk list is
            returned, which makes it easy to confirm streaming happened.
        userid: falls back to ``get_user()`` when absent.
        api_url, api_key, model_id: forwarded unchanged.

    ``stream`` and ``with_chunks`` accept booleans or strings such as
    ``'true'``/``'false'``/``'1'``/``'0'``.

    This endpoint (``chat_send.dspy``) returns a single JSON result.  For
    real-time streaming in the browser call ``chat_send_stream.dspy`` (SSE).

    Returns:
        ``{'status': True, 'msg': 'send success', 'data': {...}}`` on success;
        otherwise ``{'status': False, 'msg': ...}`` or the failing result from
        ``llm_chat_completions.dspy`` unchanged.
    """
    import json
    import traceback

    # Avoid a shared mutable default argument.
    if ns is None:
        ns = {}

    # The request's `model` parameter is currently ignored and the model is
    # pinned here (previously: ns.get('model')).
    model = 'deepseek-v4-pro'
    if not model:
        return {'status': False, 'msg': 'model is required'}

    # NOTE(review): a caller-supplied `userid` takes precedence over
    # get_user(); if ns comes straight from the HTTP request this lets a
    # client act as another user -- confirm only trusted callers pass it.
    userid = ns.get('userid') or await get_user()
    if not userid:
        return {'status': False, 'msg': '未找到用户'}  # "user not found"

    user_content = _build_user_content(ns)
    if not user_content:
        # "Please enter text, or provide image/document parameters"
        return {'status': False, 'msg': '请输入文本,或提供图片/文档参数'}

    # Multimodal (list) content is stored as JSON text; plain text verbatim.
    content_type = 'mixed' if isinstance(user_content, list) else 'text'
    store_content = (
        json.dumps(user_content, ensure_ascii=False)
        if content_type == 'mixed'
        else str(user_content)
    )

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            session_id = ns.get('session_id')
            if not session_id:
                session_id = uuid()
                await sor.C('chat_session', {
                    'id': session_id,
                    'userid': userid,
                    'model': model,
                    'title': _title_from_message(ns),
                })
            else:
                # Only continue a session that belongs to this user.
                sessions = await sor.R('chat_session', {'id': session_id, 'userid': userid})
                if not sessions:
                    return {'status': False, 'msg': '会话不存在'}  # "session does not exist"

            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'user',
                'content': store_content,
                'content_type': content_type,
            })

            # Loaded after the insert above, so the new user turn is included.
            history = await _load_session_messages(sor, session_id)
            stream_val = _parse_bool(ns.get('stream'), True)
            # Parse like `stream`: request parameters arrive as strings, and a
            # raw 'false' / '0' would otherwise be truthy.
            with_chunks = _parse_bool(ns.get('with_chunks'), True)
            chat_result = await path_call('llm_chat_completions.dspy', {
                'model': model,
                'messages': history,
                'stream': stream_val,
                'userid': userid,
                'api_url': ns.get('api_url'),
                'api_key': ns.get('api_key'),
                'model_id': ns.get('model_id'),
                'with_chunks': with_chunks,
            })
            if not chat_result.get('status'):
                # NOTE(review): the user message (and any new session) written
                # above is not undone here; whether it is committed depends on
                # how sqlorContext exits -- confirm this is intended.
                return chat_result

            data = chat_result['data']
            reply = data['reply']
            chunks = data.get('chunks') or []
            chunk_count = data.get('chunk_count', 0)

            await sor.C('chat_message', {
                'id': uuid(),
                'session_id': session_id,
                'role': 'assistant',
                'content': reply,
                'content_type': 'text',
            })
            await sor.sqlExe(
                "UPDATE chat_session SET updated_at = NOW() WHERE id = '%s';"
                % _escape(session_id),
                {},
            )

            return {
                'status': True,
                'msg': 'send success',
                'data': {
                    'session_id': session_id,
                    'reply': reply,
                    'model': model,
                    'stream': stream_val,
                    'chunk_count': chunk_count,
                    'chunks': chunks if with_chunks else None,
                },
            }
        except Exception:
            # NOTE(review): the full traceback is returned to the caller, which
            # may expose internal details to clients.
            return {'status': False, 'msg': 'send failed, %s' % traceback.format_exc()}
|
||
|
||
|
||
ret = await chat_send(params_kw)
|
||
return ret
|