kboss/b/cntoai/model_management_search.dspy

98 lines
4.2 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 可写入/更新的字段(不含 id、created_at、updated_at
_MODEL_FIELDS = (
'llmid', 'provider', 'model_name', 'display_name', 'model_logo', 'model_type',
'context_length', 'input_token_price', 'output_token_price',
'cache_hit_input_price', 'billing_method', 'billing_unit',
'capabilities', 'limitations', 'highlights', 'is_active',
'description', 'listing_status', 'sort_order',
)
def _escape(value):
if value is None:
return None
return str(value).replace("'", "''")
def _build_model_dict(ns, include_listing_status=False):
    """Collect the writable model fields that carry a usable value in *ns*.

    Args:
        ns: Mapping of incoming parameters (must support ``.get``).
        include_listing_status: When true, the result is guaranteed to have a
            ``listing_status`` entry, defaulting to ``0``.

    Returns:
        A new dict with one entry per name in ``_MODEL_FIELDS`` whose value in
        *ns* is neither ``None`` nor the empty string.
    """
    data = {}
    for field in _MODEL_FIELDS:
        value = ns.get(field)
        if value is not None and value != '':
            data[field] = value
    if include_listing_status and 'listing_status' not in data:
        # Reaching this point means listing_status was missing, None or ''.
        # ``ns.get('listing_status', 0)`` would hand back None / '' for the
        # latter two (the default only applies to a missing key), so assign
        # the default explicitly.
        data['listing_status'] = 0
    return data
def _bounded_int(value, default, minimum):
    """Parse a paging parameter: *default* when missing/blank, never below *minimum*."""
    if value is None or value == '':
        return default
    return max(int(value), minimum)


async def model_management_search(ns=None):
    """
    Query the model list page by page with optional filters.

    Supported filters: a fuzzy match on ``display_name`` (``model_name`` is
    accepted as an alias and takes precedence when both are given) and exact
    matches on ``model_type``, ``provider`` and ``listing_status``.

    Args:
        ns: Mapping of request parameters. Recognised keys: ``page_size``
            (default 1000), ``current_page`` (default 1), ``model_name``,
            ``display_name``, ``model_type``, ``provider``,
            ``listing_status``. The mapping is not modified.

    Returns:
        On success ``{'status': True, 'msg': ..., 'data': {...}}`` where
        ``data`` holds ``total_count``, ``pending_count``, ``listed_count``
        (whole-table statistics), ``provider_list``, ``model_type_list``
        (distinct non-empty values), ``filter_total`` (rows matching the
        filters), ``page_size``, ``current_page`` and ``model_list`` (the
        requested page ordered by ``sort_order``).
        On a database error ``{'status': False, 'msg': ...}``.

    Raises:
        ValueError, TypeError: if ``page_size`` or ``current_page`` is given
            but cannot be converted with ``int()``.
    """
    import traceback

    if ns is None:
        ns = {}

    # Blank values fall back to the defaults; a page below 1 or a negative
    # page size would otherwise produce a negative OFFSET / LIMIT.
    page_size = _bounded_int(ns.get('page_size'), 1000, 0)
    current_page = _bounded_int(ns.get('current_page'), 1, 1)
    offset = (current_page - 1) * page_size

    conditions = ['1=1']

    # model_name is an alias for display_name; keep it in a local instead of
    # writing it back into the caller's mapping.
    keyword = ns.get('model_name') or ns.get('display_name')
    if keyword:
        # Fuzzy match. The value is escaped like every other filter below.
        # NOTE(review): the quadruple percent signs are kept exactly as they
        # were; presumably a later %-formatting pass in the DB layer collapses
        # them -- confirm before changing.
        conditions.append(f"display_name LIKE '%%%%{_escape(keyword)}%%%%'")
    if ns.get('model_type'):
        conditions.append("model_type = '%s'" % _escape(ns.get('model_type')))
    if ns.get('provider'):
        conditions.append("provider = '%s'" % _escape(ns.get('provider')))
    listing_status = ns.get('listing_status')
    # 0 is a valid status, so test against None / '' rather than truthiness.
    if listing_status is not None and listing_status != '':
        conditions.append("listing_status = '%s'" % _escape(listing_status))
    where_clause = ' AND '.join(conditions)

    # DBPools is not defined in this file; presumably supplied by the hosting
    # framework -- confirm.
    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            # Whole-table statistics, independent of the filters.
            stats_sql = """SELECT COUNT(*) AS total_count, SUM(CASE WHEN listing_status = 0 THEN 1 ELSE 0 END) AS pending_count, SUM(CASE WHEN listing_status = 1 THEN 1 ELSE 0 END) AS listed_count FROM model_management;"""
            stats_li = await sor.sqlExe(stats_sql, {})
            stats = stats_li[0] if stats_li else {}

            provider_sql = """
            SELECT DISTINCT provider FROM model_management
            WHERE provider IS NOT NULL AND provider != ''
            ORDER BY provider;
            """
            model_type_sql = """
            SELECT DISTINCT model_type FROM model_management
            WHERE model_type IS NOT NULL AND model_type != ''
            ORDER BY model_type;
            """

            count_sql = """SELECT COUNT(*) AS total_count FROM model_management WHERE %s;""" % where_clause
            filter_total = (await sor.sqlExe(count_sql, {}))[0]['total_count']

            find_sql = """SELECT * FROM model_management WHERE %s ORDER BY sort_order ASC LIMIT %s OFFSET %s;""" % (where_clause, page_size, offset)
            provider_rows = await sor.sqlExe(provider_sql, {})
            model_type_rows = await sor.sqlExe(model_type_sql, {})
            model_list = await sor.sqlExe(find_sql, {})

            return {
                'status': True,
                'msg': 'search model success',
                'data': {
                    'total_count': stats.get('total_count', 0),
                    # SUM() yields NULL on an empty table, hence the `or 0`.
                    'pending_count': int(stats.get('pending_count') or 0),
                    'listed_count': int(stats.get('listed_count') or 0),
                    'provider_list': [r['provider'] for r in provider_rows],
                    'model_type_list': [r['model_type'] for r in model_type_rows],
                    'filter_total': filter_total,
                    'page_size': page_size,
                    'current_page': current_page,
                    'model_list': model_list,
                },
            }
        except Exception:
            # NOTE(review): the full traceback is returned to the caller;
            # consider logging it server-side and returning a shorter message.
            return {'status': False, 'msg': 'search model failed, %s' % traceback.format_exc()}
ret = await model_management_search(params_kw)
return ret