diff --git a/llmengine/baichuanm2.py b/llmengine/baichuanm2.py new file mode 100644 index 0000000..9112389 --- /dev/null +++ b/llmengine/baichuanm2.py @@ -0,0 +1,40 @@ +#!/share/vllm-0.8.5/bin/python + +# pip install accelerate +from appPublic.worker import awaitify +from appPublic.log import debug +from ahserver.serverenv import get_serverenv +from PIL import Image +import torch +from llmengine.base_chat_llm import BaseChatLLM, llm_register +from transformers import AutoModelForCausalLM, AutoTokenizer + +class BaichuanM2LLM(BaseChatLLM): + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="auto", + device_map="auto" + ) + self.model_id = model_id + + def build_kwargs(self, inputs, streamer): + generate_kwargs = dict( + **inputs, + streamer=streamer, + max_new_tokens=4096 + ) + return generate_kwargs + + def _messages2inputs(self, messages): + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + thinking_mode='on' # on/off/auto + ) + model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + return model_inputs + +llm_register('Baichuan-M2', BaichuanM2LLM) diff --git a/llmengine/qwen3coder.py b/llmengine/qwen3coder.py index 359c9bf..b22e166 100644 --- a/llmengine/qwen3coder.py +++ b/llmengine/qwen3coder.py @@ -23,6 +23,7 @@ class Qwen3CoderLLM(BaseChatLLM): generate_kwargs = dict( **inputs, streamer=streamer, + num_return_sequences=2, # 并行生成 4 个结果 # do_sample=True, # eos_token_id=self.tokenizer.eos_token_id max_new_tokens=65536 diff --git a/llmengine/server.py b/llmengine/server.py index b7906a7..2741a07 100644 --- a/llmengine/server.py +++ b/llmengine/server.py @@ -8,6 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM from llmengine.medgemma3_it import MedgemmaLLM from llmengine.qwen3 import Qwen3LLM from llmengine.qwen3coder import Qwen3CoderLLM +from llmengine.baiduanm2 import BaichuanM2LLM from appPublic.registerfunction import RegisterFunction from appPublic.log import debug, exception