Compare commits


No commits in common. "b2e4c0befb2a359b6eeec370b6911a118dff8e43" and "6b0a9e9cd06efbfbaa0efde7bd062f4eab492d5b" have entirely different histories.

3 changed files with 18 additions and 18 deletions


@@ -10,22 +10,22 @@ from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer

class BaichuanM2LLM(BaseChatLLM):
    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model_id = model_id

    def build_kwargs(self, inputs, streamer):
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        text = tokenizer.apply_chat_template(
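The hunk above is cut off inside _messages2inputs, and the visible fragment calls tokenizer rather than self.tokenizer. For orientation only, here is a minimal sketch of how such a method typically builds generation inputs with the standard Hugging Face apply_chat_template flow; everything beyond the method name is an assumption, not code from either commit.

# Hypothetical completion of _messages2inputs, assuming the standard
# Hugging Face chat-template workflow; not taken from either commit.
def _messages2inputs(self, messages):
    # Render the chat history into a prompt string with the model's
    # built-in chat template, ending with the assistant prefix.
    text = self.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize and move the tensors to the device picked by device_map="auto".
    return self.tokenizer([text], return_tensors="pt").to(self.model.device)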


@@ -150,14 +150,14 @@ class BaseChatLLM:
        for d in self.output_generator(streamer):
            if i == 0:
                i = 1
-               t2 = time()
+               t1 = time()
            if d['choices'][0]['finish_reason'] != 'stop':
                txt += d['choices'][0]['delta']['content']
            else:
                o_tokens = d.get('output_tokens', 0)
                i_tokens = input_len
-               t3 = time()
+               t2 = time()
        return {
            'id': f'chatcmpl-{getID()}',
            "object":"chat.completion",


@@ -8,7 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.qwen3 import Qwen3LLM
from llmengine.qwen3coder import Qwen3CoderLLM
-from llmengine.baichuanm2 import BaichuanM2LLM
+from llmengine.baiduanm2 import BaichuanM2LLM
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
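The last hunk only changes the module path in the import of BaichuanM2LLM (baichuanm2 versus baiduanm2); the surrounding imports pull in one class per supported backend, presumably so each can be exposed behind a common interface such as the llm_register named in the first hunk's context line. Purely as an illustration of that kind of pattern, here is a hypothetical name-to-class registry; it is not the actual llmengine or appPublic API.

# Hypothetical sketch of a model registry; llm_register in the real code base
# may work differently.
_LLM_REGISTRY = {}

def llm_register(name):
    """Register an LLM class under a model-name prefix (assumed semantics)."""
    def wrapper(cls):
        _LLM_REGISTRY[name] = cls
        return cls
    return wrapper

@llm_register("Baichuan-M2")
class BaichuanM2LLM:          # stand-in for the class in the first hunk
    def __init__(self, model_id):
        self.model_id = model_id

def get_llm_class(model_id):
    # Pick the registered class whose key appears in the model id.
    for key, cls in _LLM_REGISTRY.items():
        if key in model_id:
            return cls
    raise KeyError(f"no LLM registered for {model_id}")

# Example model id, used here only for illustration.
print(get_llm_class("baichuan-inc/Baichuan-M2-32B"))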