49 lines
1.7 KiB
Plaintext
49 lines
1.7 KiB
Plaintext
# Fields that may be written/updated (excludes id, created_at, updated_at).
_MODEL_FIELDS = (
    'llmid',
    'provider',
    'model_name',
    'display_name',
    'model_type',
    'context_length',
    'input_token_price',
    'output_token_price',
    'cache_hit_input_price',
    'billing_method',
    'billing_unit',
    'capabilities',
    'limitations',
    'highlights',
    'is_active',
    'description',
    'listing_status',
)
|
||
|
||
|
||
def _escape(value):
    """Return ``str(value)`` with every single quote doubled.

    ``None`` is passed through unchanged rather than being stringified, so a
    caller can tell a missing value apart from the text ``'None'``.

    NOTE(review): no caller is visible in this section. Doubling quotes is
    the SQL-standard literal escape, but it is not sufficient on servers that
    also treat backslash as an escape character (e.g. MySQL's default mode).
    Prefer parameterized queries over building SQL text with this helper.
    """
    return None if value is None else str(value).replace("'", "''")
|
||
|
||
|
||
def _build_model_dict(ns, include_listing_status=False):
    """Collect the writable model fields present in *ns* into a new dict.

    Only keys listed in ``_MODEL_FIELDS`` are copied. Values that are ``None``
    or the empty string are skipped; other falsy values such as ``0`` or
    ``False`` are kept.

    Args:
        ns: Mapping of incoming parameters (anything supporting ``.get``).
        include_listing_status: When true, the result is guaranteed to
            contain ``listing_status``, defaulting to ``0`` when *ns* did not
            supply a usable (non-None, non-empty) value.

    Returns:
        A new dict mapping field name to value; *ns* is not modified.
    """
    data = {}
    for field in _MODEL_FIELDS:
        value = ns.get(field)  # None for absent keys as well as explicit None
        if value is not None and value != '':
            data[field] = value
    if include_listing_status and 'listing_status' not in data:
        # Any usable listing_status was already copied by the loop, so here
        # the caller's value is absent, None or ''. Re-reading it from ns
        # (ns.get('listing_status', 0)) would copy a blank value through
        # instead of applying the default, so use 0 directly.
        data['listing_status'] = 0
    return data
|
||
|
||
|
||
async def model_management_add(ns=None):
    """Create a new row in the ``model_management`` table.

    New models default to ``listing_status=0`` (pending listing) and
    ``is_active=1`` when the caller does not supply those fields.

    Args:
        ns: Request parameters. ``provider`` and ``model_name`` are required;
            only the fields selected by ``_build_model_dict`` are written.
            ``None`` (the default) is treated as an empty mapping.

    Returns:
        ``{'status': True, 'msg': ..., 'data': <inserted dict>}`` on success,
        otherwise ``{'status': False, 'msg': ...}`` for a validation or
        database failure.
    """
    # A ``{}`` default would be one dict shared by every call; use None instead.
    if ns is None:
        ns = {}
    if not ns.get('provider') or not ns.get('model_name'):
        return {'status': False, 'msg': 'provider and model_name are required'}

    ns_dic = _build_model_dict(ns, include_listing_status=True)
    # The helper always sets the key, so check the value rather than key
    # presence: a blank/None listing_status falls back to 0 (pending listing).
    if ns_dic.get('listing_status') in (None, ''):
        ns_dic['listing_status'] = 0
    if 'is_active' not in ns_dic:
        ns_dic['is_active'] = 1

    db = DBPools()
    async with db.sqlorContext('kboss') as sor:
        try:
            await sor.C('model_management', ns_dic)
        except Exception as e:
            # The exception is swallowed here, so the surrounding context
            # manager never sees it; roll back explicitly.
            await sor.rollback()
            return {'status': False, 'msg': 'create model failed, %s' % str(e)}
        else:
            return {'status': True, 'msg': 'create model success', 'data': ns_dic}
|
||
|
||
|
||
ret = await model_management_add(params_kw)
|
||
return ret |