def _escape(value): if value is None: return None return str(value).replace("'", "''") def _parse_bool(value, default=True): if value is None or value == '': return default if isinstance(value, bool): return value return str(value).lower() in ('1', 'true', 'yes', 'on') def _parse_messages(ns): raw = ns.get('messages') if not raw: return [] if isinstance(raw, list): return raw if isinstance(raw, str): import json try: return json.loads(raw) except Exception: return [] return [] def build_user_content(ns): text_parts = [] if ns.get('message'): text_parts.append(str(ns.get('message'))) if ns.get('text'): text_parts.append(str(ns.get('text'))) if ns.get('document_text'): text_parts.append(str(ns.get('document_text'))) parts = [] merged_text = '\n'.join([p for p in text_parts if p]).strip() if merged_text: parts.append({'type': 'text', 'text': merged_text}) if ns.get('image_url'): parts.append({'type': 'image_url', 'image_url': {'url': ns.get('image_url')}}) if ns.get('image_base64'): mime = ns.get('image_mime') or 'image/jpeg' b64 = ns.get('image_base64') if not str(b64).startswith('data:'): b64 = 'data:%s;base64,%s' % (mime, b64) parts.append({'type': 'image_url', 'image_url': {'url': b64}}) if ns.get('document_url'): parts.append({'type': 'file', 'file': {'file_url': ns.get('document_url')}}) if not parts: return '' if len(parts) == 1 and parts[0]['type'] == 'text': return parts[0]['text'] return parts async def _resolve_chat_config(ns, sor): api_url = ns.get('api_url') api_key = ns.get('api_key') if not api_url and ns.get('model_id'): doc_rows = await sor.sqlExe( "SELECT api_url FROM model_api_doc WHERE model_id = '%s' LIMIT 1;" % _escape(ns.get('model_id')), {}, ) if doc_rows and doc_rows[0].get('api_url'): api_url = doc_rows[0]['api_url'] if not str(api_url).endswith('/chat/completions'): api_url = str(api_url).rstrip('/') + '/chat/completions' if not api_url: param_rows = await sor.R('params', {'pname': 'cntoai_llm_chat_url'}) if param_rows: api_url = param_rows[0]['pvalue'] else: domain_rows = await sor.R('params', {'pname': 'cntoai_domain'}) if domain_rows: api_url = domain_rows[0]['pvalue'].rstrip('/') + '/llmage/v1/chat/completions' else: api_url = 'https://ai.atvoe.com/llmage/v1/chat/completions' if not api_key: userid = ns.get('userid') or await get_user() if userid: action = ns.get('apikey_action') or 'user_self_create' keys = await sor.R('user_api_keys', {'userid': userid, 'action': action}) if not keys: keys = await sor.R('user_api_keys', {'userid': userid, 'action': 'sync'}) if keys: api_key = keys[0].get('opc_apikey') if not api_key: key_rows = await sor.R('params', {'pname': 'cntoai_llm_api_key'}) if key_rows: api_key = key_rows[0]['pvalue'] return api_url, api_key def _extract_stream_piece(payload): choice = (payload.get('choices') or [{}])[0] delta = choice.get('delta') or {} message = choice.get('message') or {} piece = ( delta.get('content') or delta.get('reasoning_content') or message.get('content') or choice.get('text') or payload.get('content') or '' ) if piece is None: return '' return str(piece) def _sse_event(obj): import json return 'data: %s\n\n' % json.dumps(obj, ensure_ascii=False) async def _iter_upstream_stream(api_url, api_key, payload): """向上游发起流式请求,逐片 yield 文本""" import aiohttp import json headers = { 'Content-Type': 'application/json', 'Authorization': 'Bearer %s' % api_key, } payload = dict(payload) payload['stream'] = True buffer = '' async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600)) as session: async with session.post(api_url, headers=headers, json=payload) as response: if response.status != 200: err_text = await response.text() yield {'type': 'error', 'msg': 'HTTP %s: %s' % (response.status, err_text[:500])} return async for raw in response.content: buffer += raw.decode('utf-8', errors='ignore') while '\n' in buffer: line, buffer = buffer.split('\n', 1) line = line.strip() if not line or line.startswith(':') or not line.startswith('data:'): continue data = line[5:].strip() if data == '[DONE]': return try: payload_obj = json.loads(data) piece = _extract_stream_piece(payload_obj) if piece: yield {'type': 'content', 'content': piece} except Exception: continue tail = buffer.strip() if tail: try: body = json.loads(tail) choice = (body.get('choices') or [{}])[0] msg = choice.get('message') or {} piece = msg.get('content') or choice.get('text') or '' if piece: yield {'type': 'content', 'content': str(piece)} except Exception: pass async def inference_generator(request, params_kw=None, **kw): """ SSE 流式输出,事件格式: {"type":"meta","model":"..."} {"type":"content","content":"片段"} {"type":"done","reply":"完整文本"} {"type":"error","msg":"..."} 结束:data: [DONE] """ import traceback ns = params_kw or {} model = ns.get('model') if not model: yield _sse_event({'type': 'error', 'msg': 'model is required'}) yield 'data: [DONE]\n\n' return history = _parse_messages(ns) user_content = build_user_content(ns) if not user_content and not history: yield _sse_event({'type': 'error', 'msg': 'message is required'}) yield 'data: [DONE]\n\n' return messages = list(history) if user_content: messages.append({'role': 'user', 'content': user_content}) db = DBPools() async with db.sqlorContext('kboss') as sor: try: api_url, api_key = await _resolve_chat_config(ns, sor) if not api_key: yield _sse_event({'type': 'error', 'msg': '未找到 API Key'}) yield 'data: [DONE]\n\n' return yield _sse_event({'type': 'meta', 'model': model, 'stream': True}) parts = [] async for evt in _iter_upstream_stream(api_url, api_key, { 'model': model, 'messages': messages, }): if evt.get('type') == 'error': yield _sse_event(evt) yield 'data: [DONE]\n\n' return if evt.get('type') == 'content': parts.append(evt['content']) yield _sse_event(evt) reply = ''.join(parts) yield _sse_event({'type': 'done', 'reply': reply, 'model': model}) yield 'data: [DONE]\n\n' except Exception: yield _sse_event({'type': 'error', 'msg': traceback.format_exc()}) yield 'data: [DONE]\n\n' async def inference(request, *args, params_kw=None, **kw): from functools import partial env = request._run_ns.copy() f = partial(inference_generator, request, params_kw=params_kw, **kw) return await env.stream_response(request, f, content_type='text/event-stream') ret = await inference(request, params_kw=params_kw) return ret