kboss/b/cntoai/model_management_search_doc.dspy
2026-05-27 14:25:10 +08:00

154 lines
6.0 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 可写入/更新的字段(不含 id、created_at、updated_at
_MODEL_FIELDS = (
'llmid', 'provider', 'model_name', 'display_name', 'model_type',
'context_length', 'input_token_price', 'output_token_price',
'cache_hit_input_price', 'billing_method', 'billing_unit',
'capabilities', 'limitations', 'highlights', 'is_active',
'description', 'listing_status', 'sort_order',
)
def _escape(value):
if value is None:
return None
return str(value).replace("'", "''")
def _build_model_dict(ns, include_listing_status=False):
    """Collect the writable model fields that ``ns`` actually supplies.

    Only keys listed in ``_MODEL_FIELDS`` are copied. Values that are
    ``None`` or the empty string are skipped, so blank inputs are dropped
    rather than written. Falsy-but-meaningful values such as ``0`` or
    ``False`` are kept.

    :param ns: mapping of request parameters (anything providing ``.get``).
    :param include_listing_status: when True, guarantee a ``listing_status``
        key in the result, defaulting it to ``0`` whenever ``ns`` did not
        supply a usable (non-None, non-blank) value.
    :return: a new dict of field name -> value.
    """
    data = {}
    for field in _MODEL_FIELDS:
        value = ns.get(field)
        if value is not None and value != '':
            data[field] = value
    if include_listing_status:
        # ns.get('listing_status', 0) would hand back None/'' when the key is
        # present but blank; the intended fallback in every such case is 0.
        data.setdefault('listing_status', 0)
    return data
def _build_where_conditions(ns, table_alias='m'):
    """Return the SQL WHERE body that filters model_management rows.

    Every column is qualified with ``table_alias`` so the clause can be used
    in JOIN queries. The result always starts with ``1=1``, and each
    supplied filter is appended with ``AND``:

    * ``display_name`` -> substring ``LIKE`` match (skipped when falsy);
    * ``model_type`` / ``provider`` -> exact match (skipped when falsy);
    * ``listing_status`` -> exact match, applied for any value other than
      ``None`` or ``''`` (so ``0`` still filters).

    All user values go through ``_escape`` before interpolation.
    """
    col = '%s.' % table_alias
    clauses = ['1=1']

    name_kw = ns.get('display_name')
    if name_kw:
        # '%%%%' renders as '%%' in the SQL text. Within LIKE, '%%' matches
        # the same as a single '%'. NOTE(review): it presumably also survives
        # a later %-formatting pass by the driver — confirm.
        clauses.append(
            "%sdisplay_name LIKE '%%%%%s%%%%'" % (col, _escape(name_kw))
        )

    for column in ('model_type', 'provider'):
        wanted = ns.get(column)
        if wanted:
            clauses.append("%s%s = '%s'" % (col, column, _escape(wanted)))

    status = ns.get('listing_status')
    if status is not None and status != '':
        clauses.append("%slisting_status = '%s'" % (col, _escape(status)))

    return ' AND '.join(clauses)
def _attach_api_doc(row):
"""将 JOIN 出的 API 文档字段整理为 api_doc 子对象"""
api_doc_id = row.pop('api_doc_id', None)
api_url = row.pop('api_url', None)
curl_code = row.pop('curl_code', None)
python_code = row.pop('python_code', None)
api_doc_created_at = row.pop('api_doc_created_at', None)
api_doc_updated_at = row.pop('api_doc_updated_at', None)
if api_doc_id:
row['api_doc'] = {
'id': api_doc_id,
'model_id': str(row.get('id', '')),
'api_url': api_url,
'curl_code': curl_code,
'python_code': python_code,
'created_at': api_doc_created_at,
'updated_at': api_doc_updated_at,
}
else:
row['api_doc'] = None
return row
def _parse_positive_int(value, default):
    """Return ``value`` as an int >= 1, or ``default`` when it is unusable.

    ``None``, ``''`` and anything that converts to an int below 1 fall back
    to ``default``. Non-numeric input propagates ``int()``'s ``ValueError``
    / ``TypeError`` so the caller can report it.
    """
    if value is None or value == '':
        return default
    parsed = int(value)
    return parsed if parsed >= 1 else default


async def model_management_search_doc(ns=None):
    """
    Paginated model list (with API docs), filterable by display_name /
    model_type / provider / listing_status.

    Rows come from ``model_management m LEFT JOIN model_api_doc d
    (d.model_id = CAST(m.id AS CHAR))``. Besides the requested page, the
    result carries counters over the WHOLE model_management table:

    * total;
    * pending (listing_status = 0);
    * listed (listing_status = 1).

    It also carries the distinct provider list and the distinct model_type
    list. Each entry of ``model_list`` has an ``api_doc`` sub-dict, or
    ``None`` when no doc row joined.

    :param ns: request parameters. ``page_size`` (default 1000) and
        ``current_page`` (1-based, default 1) control paging. Missing, blank
        or non-positive values fall back to those defaults. The filter keys
        are read by ``_build_where_conditions``.
    :return: ``{'status': True, 'msg': ..., 'data': {...}}`` on success, or
        ``{'status': False, 'msg': <message + traceback>}`` on failure.
    """
    import traceback

    if ns is None:
        ns = {}
    try:
        page_size = _parse_positive_int(ns.get('page_size'), 1000)
        current_page = _parse_positive_int(ns.get('current_page'), 1)
    except (TypeError, ValueError):
        # Report non-numeric paging input with the same failure envelope as
        # query errors instead of letting it escape unhandled.
        return {
            'status': False,
            'msg': 'search model with api doc failed, %s' % traceback.format_exc(),
        }
    offset = (current_page - 1) * page_size
    # Filter values are passed through _escape inside _build_where_conditions.
    # page_size/offset are validated ints, so interpolating them is safe.
    where_clause = _build_where_conditions(ns)
    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            # Global counters: intentionally NOT restricted by where_clause.
            stats_sql = """
            SELECT COUNT(*) AS total_count,
            SUM(CASE WHEN listing_status = 0 THEN 1 ELSE 0 END) AS pending_count,
            SUM(CASE WHEN listing_status = 1 THEN 1 ELSE 0 END) AS listed_count
            FROM model_management;
            """
            stats_li = await sor.sqlExe(stats_sql, {})
            stats = stats_li[0] if stats_li else {}
            provider_sql = """
            SELECT DISTINCT provider FROM model_management
            WHERE provider IS NOT NULL AND provider != ''
            ORDER BY provider;
            """
            model_type_sql = """
            SELECT DISTINCT model_type FROM model_management
            WHERE model_type IS NOT NULL AND model_type != ''
            ORDER BY model_type;
            """
            count_sql = """
            SELECT COUNT(*) AS total_count
            FROM model_management m
            WHERE %s;
            """ % where_clause
            # m.id breaks ties in sort_order so LIMIT/OFFSET pages are stable
            # (no row repeated or skipped between pages).
            # NOTE(review): if one model can have several model_api_doc rows,
            # the LEFT JOIN yields one row per doc — confirm model_id is unique.
            find_sql = """
            SELECT m.*,
            d.id AS api_doc_id,
            d.api_url,
            d.curl_code,
            d.python_code,
            d.created_at AS api_doc_created_at,
            d.updated_at AS api_doc_updated_at
            FROM model_management m
            LEFT JOIN model_api_doc d ON d.model_id = CAST(m.id AS CHAR)
            WHERE %s
            ORDER BY m.sort_order ASC, m.id ASC
            LIMIT %s OFFSET %s;
            """ % (where_clause, page_size, offset)
            provider_rows = await sor.sqlExe(provider_sql, {})
            model_type_rows = await sor.sqlExe(model_type_sql, {})
            filter_total = (await sor.sqlExe(count_sql, {}))[0]['total_count']
            model_rows = await sor.sqlExe(find_sql, {})
            model_list = [_attach_api_doc(row) for row in model_rows]
            return {
                'status': True,
                'msg': 'search model with api doc success',
                'data': {
                    'total_count': stats.get('total_count', 0),
                    'pending_count': int(stats.get('pending_count') or 0),
                    'listed_count': int(stats.get('listed_count') or 0),
                    'provider_list': [r['provider'] for r in provider_rows],
                    'model_type_list': [r['model_type'] for r in model_type_rows],
                    'filter_total': filter_total,
                    'page_size': page_size,
                    'current_page': current_page,
                    'model_list': model_list,
                },
            }
        except Exception:
            # API boundary: report any query failure in the response envelope.
            return {
                'status': False,
                'msg': 'search model with api doc failed, %s' % traceback.format_exc(),
            }
ret = await model_management_search_doc(params_kw)
return ret