From cb8c4d44c3c3312345390ed4ffe3c177da864aeb Mon Sep 17 00:00:00 2001 From: ymq1 Date: Tue, 16 Sep 2025 14:30:04 +0800 Subject: [PATCH] bugfix --- llmengine/baichuanm2.py | 30 +++++++++++++++--------------- llmengine/base_chat_llm.py | 4 ++-- llmengine/server.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/llmengine/baichuanm2.py b/llmengine/baichuanm2.py index 9112389..89ac9e0 100644 --- a/llmengine/baichuanm2.py +++ b/llmengine/baichuanm2.py @@ -10,22 +10,22 @@ from llmengine.base_chat_llm import BaseChatLLM, llm_register from transformers import AutoModelForCausalLM, AutoTokenizer class BaichuanM2LLM(BaseChatLLM): - def __init__(self, model_id): - self.tokenizer = AutoTokenizer.from_pretrained(model_id) - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype="auto", - device_map="auto" - ) - self.model_id = model_id + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="auto", + device_map="auto" + ) + self.model_id = model_id - def build_kwargs(self, inputs, streamer): - generate_kwargs = dict( - **inputs, - streamer=streamer, - max_new_tokens=4096 - ) - return generate_kwargs + def build_kwargs(self, inputs, streamer): + generate_kwargs = dict( + **inputs, + streamer=streamer, + max_new_tokens=4096 + ) + return generate_kwargs def _messages2inputs(self, messages): text = tokenizer.apply_chat_template( diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index 7c3419a..ee62268 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -150,14 +150,14 @@ class BaseChatLLM: for d in self.output_generator(streamer): if i == 0: i = 1 - t1 = time() + t2 = time() if d['choices'][0]['finish_reason'] != 'stop': txt += d['choices'][0]['delta']['content'] else: o_tokens = d.get('output_tokens', 0) i_tokens = input_len - t2 = time() + t3 = time() return { 'id': f'chatcmpl-{getID()}', "object":"chat.completion", diff --git a/llmengine/server.py b/llmengine/server.py index 2741a07..4cd712c 100644 --- a/llmengine/server.py +++ b/llmengine/server.py @@ -8,7 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM from llmengine.medgemma3_it import MedgemmaLLM from llmengine.qwen3 import Qwen3LLM from llmengine.qwen3coder import Qwen3CoderLLM -from llmengine.baiduanm2 import BaichuanM2LLM +from llmengine.baichuanm2 import BaichuanM2LLM from appPublic.registerfunction import RegisterFunction from appPublic.log import debug, exception