Compare commits

..

2 Commits

Author SHA1 Message Date
b2e4c0befb Merge branch 'main' of https://git.opencomputing.cn/yumoqing/llmengine 2025-09-16 14:31:00 +08:00
cb8c4d44c3 bugfix 2025-09-16 14:30:04 +08:00
3 changed files with 18 additions and 18 deletions

View File

@@ -10,22 +10,22 @@ from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
class BaichuanM2LLM(BaseChatLLM): class BaichuanM2LLM(BaseChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_id, model_id,
torch_dtype="auto", torch_dtype="auto",
device_map="auto" device_map="auto"
) )
self.model_id = model_id self.model_id = model_id
def build_kwargs(self, inputs, streamer): def build_kwargs(self, inputs, streamer):
generate_kwargs = dict( generate_kwargs = dict(
**inputs, **inputs,
streamer=streamer, streamer=streamer,
max_new_tokens=4096 max_new_tokens=4096
) )
return generate_kwargs return generate_kwargs
def _messages2inputs(self, messages): def _messages2inputs(self, messages):
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(

View File

@@ -150,14 +150,14 @@ class BaseChatLLM:
for d in self.output_generator(streamer): for d in self.output_generator(streamer):
if i == 0: if i == 0:
i = 1 i = 1
t1 = time() t2 = time()
if d['choices'][0]['finish_reason'] != 'stop': if d['choices'][0]['finish_reason'] != 'stop':
txt += d['choices'][0]['delta']['content'] txt += d['choices'][0]['delta']['content']
else: else:
o_tokens = d.get('output_tokens', 0) o_tokens = d.get('output_tokens', 0)
i_tokens = input_len i_tokens = input_len
t2 = time() t3 = time()
return { return {
'id': f'chatcmpl-{getID()}', 'id': f'chatcmpl-{getID()}',
"object":"chat.completion", "object":"chat.completion",

View File

@@ -8,7 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
from llmengine.medgemma3_it import MedgemmaLLM from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.qwen3 import Qwen3LLM from llmengine.qwen3 import Qwen3LLM
from llmengine.qwen3coder import Qwen3CoderLLM from llmengine.qwen3coder import Qwen3CoderLLM
from llmengine.baiduanm2 import BaichuanM2LLM from llmengine.baichuanm2 import BaichuanM2LLM
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception from appPublic.log import debug, exception