This commit is contained in:
ymq1 2025-09-16 14:30:04 +08:00
parent 0399ca60dd
commit cb8c4d44c3
3 changed files with 18 additions and 18 deletions

View File

@@ -10,22 +10,22 @@ from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
class BaichuanM2LLM(BaseChatLLM): class BaichuanM2LLM(BaseChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_id, model_id,
torch_dtype="auto", torch_dtype="auto",
device_map="auto" device_map="auto"
) )
self.model_id = model_id self.model_id = model_id
def build_kwargs(self, inputs, streamer): def build_kwargs(self, inputs, streamer):
generate_kwargs = dict( generate_kwargs = dict(
**inputs, **inputs,
streamer=streamer, streamer=streamer,
max_new_tokens=4096 max_new_tokens=4096
) )
return generate_kwargs return generate_kwargs
def _messages2inputs(self, messages): def _messages2inputs(self, messages):
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(

View File

@@ -150,14 +150,14 @@ class BaseChatLLM:
for d in self.output_generator(streamer): for d in self.output_generator(streamer):
if i == 0: if i == 0:
i = 1 i = 1
t1 = time() t2 = time()
if d['choices'][0]['finish_reason'] != 'stop': if d['choices'][0]['finish_reason'] != 'stop':
txt += d['choices'][0]['delta']['content'] txt += d['choices'][0]['delta']['content']
else: else:
o_tokens = d.get('output_tokens', 0) o_tokens = d.get('output_tokens', 0)
i_tokens = input_len i_tokens = input_len
t2 = time() t3 = time()
return { return {
'id': f'chatcmpl-{getID()}', 'id': f'chatcmpl-{getID()}',
"object":"chat.completion", "object":"chat.completion",

View File

@@ -8,7 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
from llmengine.medgemma3_it import MedgemmaLLM from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.qwen3 import Qwen3LLM from llmengine.qwen3 import Qwen3LLM
from llmengine.qwen3coder import Qwen3CoderLLM from llmengine.qwen3coder import Qwen3CoderLLM
from llmengine.baiduanm2 import BaichuanM2LLM from llmengine.baichuanm2 import BaichuanM2LLM
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception from appPublic.log import debug, exception