kboss/b/cntoai/llm_chat_completions_stream.dspy
2026-05-23 14:22:30 +08:00

242 lines
8.2 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

def _escape(value):
    """Return ``value`` as text with every single quote doubled.

    ``None`` is passed through unchanged; any other value is converted with
    ``str()`` before the quotes are doubled. Used when a value is spliced into
    a single-quoted SQL literal.
    """
    if value is not None:
        return str(value).replace("'", "''")
    return None
def _parse_bool(value, default=True):
    """Interpret a loosely-typed flag as a boolean.

    Real booleans are returned as-is. ``None`` and the empty string yield
    ``default``. Anything else is stringified, lower-cased and considered true
    only when it is one of ``'1'``, ``'true'``, ``'yes'`` or ``'on'``.
    """
    if isinstance(value, bool):
        return value
    if value is None or value == '':
        return default
    truthy_words = ('1', 'true', 'yes', 'on')
    lowered = str(value).lower()
    return lowered in truthy_words
def _parse_messages(ns):
    """Return the chat history carried in ``ns['messages']`` as a list.

    Args:
        ns: Mapping of request parameters. ``messages`` may be a list or a
            JSON-encoded string.

    Returns:
        The list itself, or the list decoded from the JSON string. An empty
        list is returned when the key is missing or empty, when the string is
        not valid JSON, or when the value (decoded or not) is not a list.
    """
    import json

    raw = ns.get('messages')
    if not raw:
        return []
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except (ValueError, RecursionError):
            # Malformed history is treated as "no history" rather than an error.
            return []
    # Valid JSON can still be a dict/str/number; only a list is a usable
    # message history (the caller copies it with list(...)).
    return raw if isinstance(raw, list) else []
def build_user_content(ns):
    """Assemble the ``content`` value of the user message from request fields.

    Text comes from ``message``, ``text`` and ``document_text`` (in that
    order, joined by newlines and stripped). ``image_url``, ``image_base64``
    (turned into a ``data:`` URL using ``image_mime`` or ``image/jpeg`` unless
    it already is one) and ``document_url`` each add a structured part.

    Returns:
        ``''`` when nothing was supplied, the plain text when text is the only
        part, otherwise the list of part dictionaries.
    """
    texts = [
        str(ns.get(field))
        for field in ('message', 'text', 'document_text')
        if ns.get(field)
    ]
    merged = '\n'.join(t for t in texts if t).strip()

    content = []
    if merged:
        content.append({'type': 'text', 'text': merged})

    image_url = ns.get('image_url')
    if image_url:
        content.append({'type': 'image_url', 'image_url': {'url': image_url}})

    image_data = ns.get('image_base64')
    if image_data:
        if not str(image_data).startswith('data:'):
            mime_type = ns.get('image_mime') or 'image/jpeg'
            image_data = 'data:%s;base64,%s' % (mime_type, image_data)
        content.append({'type': 'image_url', 'image_url': {'url': image_data}})

    document_url = ns.get('document_url')
    if document_url:
        content.append({'type': 'file', 'file': {'file_url': document_url}})

    if not content:
        return ''
    first = content[0]
    if len(content) == 1 and first['type'] == 'text':
        # A lone text part is sent as a plain string.
        return first['text']
    return content
async def _resolve_chat_config(ns, sor):
    """Work out which chat-completions endpoint and API key to use.

    Endpoint, first match wins:
      1. ``ns['api_url']``;
      2. the ``api_url`` column of ``model_api_doc`` for ``ns['model_id']``,
         with ``/chat/completions`` appended when it does not already end so;
      3. the ``cntoai_llm_chat_url`` row of ``params``;
      4. the ``cntoai_domain`` row of ``params`` plus
         ``/llmage/v1/chat/completions``;
      5. a hard-coded default URL.

    API key, first match wins:
      1. ``ns['api_key']``;
      2. ``opc_apikey`` from ``user_api_keys`` for the user (``ns['userid']``
         or ``get_user()``), looked up with ``ns['apikey_action']`` (default
         ``user_self_create``) and then with action ``sync``;
      3. the ``cntoai_llm_api_key`` row of ``params``.

    Args:
        ns: Mapping of request parameters.
        sor: Database handle exposing ``sqlExe`` and ``R``.

    Returns:
        ``(api_url, api_key)``; ``api_key`` may still be falsy when no source
        provided one.
    """
    completions_suffix = '/chat/completions'
    api_url = ns.get('api_url')
    api_key = ns.get('api_key')

    # --- endpoint -----------------------------------------------------------
    model_id = ns.get('model_id')
    if not api_url and model_id:
        query = (
            "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;"
            % _escape(model_id)
        )
        doc_rows = await sor.sqlExe(query, {})
        if doc_rows and doc_rows[0].get('api_url'):
            api_url = doc_rows[0]['api_url']
            if not str(api_url).endswith(completions_suffix):
                api_url = str(api_url).rstrip('/') + completions_suffix

    if not api_url:
        url_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'})
        if not url_rows:
            domain_rows = await sor.R('params', {'pname': 'cntoai_domain'})
            if not domain_rows:
                api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions'
            else:
                api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions'
        else:
            api_url = url_rows[0]['pvalue']

    # --- API key ------------------------------------------------------------
    if not api_key:
        userid = ns.get('userid') or await get_user()
        if userid:
            action = ns.get('apikey_action') or 'user_self_create'
            key_rows = await sor.R('user_api_keys', {'userid': userid, 'action': action})
            if not key_rows:
                # Fall back to the key stored under the 'sync' action.
                key_rows = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'})
            if key_rows:
                api_key = key_rows[0].get('opc_apikey')

    if not api_key:
        shared_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'})
        if shared_rows:
            api_key = shared_rows[0]['pvalue']

    return api_url, api_key
def _extract_stream_piece(payload):
    """Pull the text fragment out of one decoded upstream response object.

    The first non-empty value among these is used, in order:
    ``choices[0].delta.content``, ``choices[0].delta.reasoning_content``,
    ``choices[0].message.content``, ``choices[0].text`` and the top-level
    ``content``.

    Args:
        payload: The JSON-decoded object of one stream frame (or a whole
            non-streaming response body).

    Returns:
        The fragment as ``str``, or ``''`` when none is present or when
        ``payload`` (or its first choice) does not have the expected mapping
        shape.
    """
    def _as_dict(value):
        # Anything that is not a mapping is treated as "absent".
        return value if isinstance(value, dict) else {}

    if not isinstance(payload, dict):
        return ''

    choices = payload.get('choices')
    first_choice = choices[0] if isinstance(choices, (list, tuple)) and choices else None
    choice = _as_dict(first_choice)
    delta = _as_dict(choice.get('delta'))
    message = _as_dict(choice.get('message'))

    piece = (
        delta.get('content')
        or delta.get('reasoning_content')
        or message.get('content')
        or choice.get('text')
        or payload.get('content')
        or ''
    )
    return str(piece)
def _sse_event(obj):
    """Serialize ``obj`` as one Server-Sent-Events ``data:`` frame.

    The object is JSON-encoded without ASCII escaping and terminated by the
    blank line that ends an SSE event.
    """
    import json

    encoded = json.dumps(obj, ensure_ascii=False)
    return 'data: ' + encoded + '\n\n'
def _piece_from_sse_data(data):
    """Decode the JSON text of one SSE ``data:`` field; return its piece or ''."""
    import json

    try:
        return _extract_stream_piece(json.loads(data))
    except (ValueError, TypeError, AttributeError, LookupError, RecursionError):
        # Best effort: a malformed or unexpected frame is skipped, not fatal.
        return ''


def _piece_from_json_body(text):
    """Extract the reply text from a plain (non-SSE) JSON response body, or ''."""
    import json

    try:
        body = json.loads(text)
        choice = (body.get('choices') or [{}])[0]
        msg = choice.get('message') or {}
        piece = msg.get('content') or choice.get('text') or ''
    except (ValueError, TypeError, AttributeError, LookupError, RecursionError):
        return ''
    return str(piece) if piece else ''


async def _iter_upstream_stream(api_url, api_key, payload):
    """Send a streaming request upstream and yield the reply piece by piece.

    Args:
        api_url: Chat-completions endpoint to POST to.
        api_key: Sent as ``Authorization: Bearer <api_key>``.
        payload: Request body; it is copied and ``stream`` is forced to True.

    Yields:
        ``{'type': 'content', 'content': <str>}`` for every non-empty text
        fragment, or a single ``{'type': 'error', 'msg': ...}`` when the
        upstream status is not 200 (the generator then stops).
    """
    import aiohttp
    import codecs

    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer %s' % api_key,
    }
    payload = dict(payload)
    payload['stream'] = True

    # Incremental decoding keeps a multi-byte character intact even if it is
    # split across two chunks; undecodable bytes are still dropped.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
    buffer = ''
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, headers=headers, json=payload) as response:
            if response.status != 200:
                err_text = await response.text()
                yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])}
                return
            async for raw in response.content:
                buffer += decoder.decode(raw)
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    line = line.strip()
                    # Blank lines, ':' comments and other SSE fields are ignored.
                    if not line.startswith('data:'):
                        continue
                    data = line[5:].strip()
                    if data == '[DONE]':
                        return
                    # Parsing is done outside any try/except around the yield,
                    # so exceptions raised at the yield are never swallowed.
                    piece = _piece_from_sse_data(data)
                    if piece:
                        yield {'type': 'content', 'content': piece}
            buffer += decoder.decode(b'', final=True)

    # Whatever is left had no trailing newline: either a last SSE data line or
    # a plain JSON body from an upstream that did not stream.
    tail = buffer.strip()
    if not tail:
        return
    if tail.startswith('data:'):
        data = tail[5:].strip()
        piece = '' if data == '[DONE]' else _piece_from_sse_data(data)
    else:
        piece = _piece_from_json_body(tail)
    if piece:
        yield {'type': 'content', 'content': piece}
async def inference_generator(request, params_kw=None, **kw):
    """
    Stream a chat completion to the client as Server-Sent Events.

    Every event is a ``data: <json>`` frame carrying one of:
        {"type":"meta","model":"...","stream":true}
        {"type":"content","content":"<fragment>"}
        {"type":"done","reply":"<full text>","model":"..."}
        {"type":"error","msg":"..."}
    The stream always finishes with a literal ``data: [DONE]`` frame.

    Args:
        request: The incoming request object (not used directly here).
        params_kw: Mapping of request parameters; ``model`` is required and at
            least one of the user-content fields or ``messages`` must be set.
        **kw: Accepted for interface compatibility; unused.

    Yields:
        SSE frames as ``str``.
    """
    import traceback

    done_frame = 'data: [DONE]\n\n'
    ns = params_kw or {}

    model = ns.get('model')
    if not model:
        yield _sse_event({'type': 'error', 'msg': 'model is required'})
        yield done_frame
        return

    history = _parse_messages(ns)
    user_content = build_user_content(ns)
    if not user_content and not history:
        yield _sse_event({'type': 'error', 'msg': 'message is required'})
        yield done_frame
        return

    messages = list(history)
    if user_content:
        messages.append({'role': 'user', 'content': user_content})

    # Only the configuration lookup needs the database, so the context is
    # closed before streaming starts instead of being held for the whole
    # (potentially minutes-long) upstream response.
    api_url = api_key = None
    config_error = None
    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            api_url, api_key = await _resolve_chat_config(ns, sor)
        except Exception:
            config_error = traceback.format_exc()

    # NOTE(review): error events carry the full traceback to the client, which
    # exposes internals; kept as-is because consumers may depend on it.
    if config_error is not None:
        yield _sse_event({'type': 'error', 'msg': config_error})
        yield done_frame
        return

    if not api_key:
        yield _sse_event({'type': 'error', 'msg': '未找到 API Key'})
        yield done_frame
        return

    try:
        yield _sse_event({'type': 'meta', 'model': model, 'stream': True})
        parts = []
        upstream = _iter_upstream_stream(api_url, api_key, {
            'model': model,
            'messages': messages,
        })
        async for evt in upstream:
            evt_type = evt.get('type')
            if evt_type == 'error':
                yield _sse_event(evt)
                yield done_frame
                return
            if evt_type == 'content':
                parts.append(evt['content'])
                yield _sse_event(evt)
        # The final event repeats the whole reply so clients need not
        # reassemble the fragments themselves.
        yield _sse_event({'type': 'done', 'reply': ''.join(parts), 'model': model})
        yield done_frame
    except Exception:
        yield _sse_event({'type': 'error', 'msg': traceback.format_exc()})
        yield done_frame
async def inference(request, *args, params_kw=None, **kw):
    """Return a ``text/event-stream`` response driven by ``inference_generator``.

    A copy of the request's run namespace supplies ``stream_response``; the
    generator is handed over as a zero-argument callable with ``request``,
    ``params_kw`` and ``**kw`` already bound. Positional ``*args`` are accepted
    but not used.
    """
    from functools import partial

    run_env = request._run_ns.copy()
    make_stream = partial(inference_generator, request, params_kw=params_kw, **kw)
    response = await run_env.stream_response(
        request,
        make_stream,
        content_type='text/event-stream',
    )
    return response
# Script trailer: run the streaming handler for the current request and hand
# its response back to whatever executes this script.
# NOTE(review): `request` and `params_kw` are not defined in this file and the
# top-level `await`/`return` are not valid in a plain module — presumably the
# .dspy runtime wraps this body in an async function and injects those names;
# confirm against the loader.
ret = await inference(request, params_kw=params_kw)
return ret