This commit is contained in:
ymq1 2025-08-22 10:59:58 +08:00
parent 6b56718d2b
commit b8c52aa77a
3 changed files with 42 additions and 0 deletions

40
llmengine/baichuanm2.py Normal file
View File

@ -0,0 +1,40 @@
#!/share/vllm-0.8.5/bin/python
# pip install accelerate
from appPublic.worker import awaitify
from appPublic.log import debug
from ahserver.serverenv import get_serverenv
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer
class BaichuanM2LLM(BaseChatLLM):
    """Chat LLM backed by a Baichuan-M2 checkpoint loaded via transformers.

    The tokenizer and model are loaded once in ``__init__`` and kept on the
    instance; the other methods turn chat messages into model inputs and
    assemble the keyword arguments used for generation.
    """

    def __init__(self, model_id):
        """Load the tokenizer and causal-LM weights for ``model_id``.

        Args:
            model_id: Identifier handed unchanged to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForCausalLM.from_pretrained``.
        """
        # NOTE(review): BaseChatLLM.__init__ is not invoked here (same as the
        # original) -- confirm the base class does not require it.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # dtype and device placement are both left to transformers ("auto").
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model_id = model_id

    def build_kwargs(self, inputs, streamer):
        """Assemble the keyword arguments for a generation call.

        Args:
            inputs: Mapping of model inputs; it is unpacked into the result.
            streamer: Object stored under the ``streamer`` key.

        Returns:
            A dict containing every entry of ``inputs`` plus ``streamer`` and
            ``max_new_tokens=4096``.
        """
        # Presumably consumed by the base class as model.generate(**kwargs)
        # -- confirm against BaseChatLLM.
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        """Convert chat ``messages`` into tokenized tensors on the model device.

        Args:
            messages: Chat messages passed to the tokenizer's chat template.

        Returns:
            The tokenizer output for the rendered prompt (a batch of one,
            ``return_tensors="pt"``), moved to ``self.model.device``.
        """
        # The tokenizer and model live on the instance; referencing the bare
        # names ``tokenizer`` / ``model`` would raise NameError.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            thinking_mode='on'  # on/off/auto
        )
        model_inputs = self.tokenizer(
            [text], return_tensors="pt"
        ).to(self.model.device)
        return model_inputs
# Register the class under the name 'Baichuan-M2' at import time
# (presumably so the engine can select it by model name -- confirm
# against llm_register in llmengine.base_chat_llm).
llm_register('Baichuan-M2', BaichuanM2LLM)

View File

@ -23,6 +23,7 @@ class Qwen3CoderLLM(BaseChatLLM):
generate_kwargs = dict( generate_kwargs = dict(
**inputs, **inputs,
streamer=streamer, streamer=streamer,
num_return_sequences=2, # 并行生成 4 个结果
# do_sample=True, # do_sample=True,
# eos_token_id=self.tokenizer.eos_token_id # eos_token_id=self.tokenizer.eos_token_id
max_new_tokens=65536 max_new_tokens=65536

View File

@ -8,6 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
from llmengine.medgemma3_it import MedgemmaLLM from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.qwen3 import Qwen3LLM from llmengine.qwen3 import Qwen3LLM
from llmengine.qwen3coder import Qwen3CoderLLM from llmengine.qwen3coder import Qwen3CoderLLM
from llmengine.baiduanm2 import BaichuanM2LLM
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception from appPublic.log import debug, exception