bugfix

This commit is contained in:
parent 6b56718d2b
commit b8c52aa77a
llmengine/baichuanm2.py (new file, 40 lines)
@@ -0,0 +1,40 @@
#!/share/vllm-0.8.5/bin/python
# pip install accelerate
from appPublic.worker import awaitify
from appPublic.log import debug
from ahserver.serverenv import get_serverenv
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer

class BaichuanM2LLM(BaseChatLLM):
    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model_id = model_id

    def build_kwargs(self, inputs, streamer):
        # assemble the kwargs passed to model.generate()
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        # render the chat template, then tokenize onto the model's device
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            thinking_mode='on'  # on/off/auto
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        return model_inputs

llm_register('Baichuan-M2', BaichuanM2LLM)
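For reference, a minimal sketch of how this class would typically be driven, assuming BaseChatLLM (not shown in this diff) feeds _messages2inputs and build_kwargs into a threaded generate call with a Hugging Face streamer; the model id and the driver loop below are illustrative assumptions, not part of the commit:

from threading import Thread
from transformers import TextIteratorStreamer

llm = BaichuanM2LLM('baichuan-inc/Baichuan-M2-32B')  # hypothetical model id
messages = [{'role': 'user', 'content': 'hello'}]
inputs = llm._messages2inputs(messages)
streamer = TextIteratorStreamer(llm.tokenizer, skip_prompt=True, skip_special_tokens=True)
kwargs = llm.build_kwargs(inputs, streamer)
Thread(target=llm.model.generate, kwargs=kwargs).start()
for piece in streamer:
    print(piece, end='', flush=True)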
llmengine/qwen3coder.py
@@ -23,6 +23,7 @@ class Qwen3CoderLLM(BaseChatLLM):
         generate_kwargs = dict(
             **inputs,
             streamer=streamer,
+            num_return_sequences=2,  # generate 2 candidate sequences in parallel
             # do_sample=True,
             # eos_token_id=self.tokenizer.eos_token_id
             max_new_tokens=65536
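One caveat with this change: Hugging Face text streamers are documented to support batch size 1 only, so combining num_return_sequences=2 with streamer= may interleave or reject the second sequence, and transformers rejects num_return_sequences > 1 under greedy decoding unless sampling or beam search is enabled (note that do_sample=True is commented out above). A minimal non-streaming sketch that generates two candidates, with self.tokenizer/self.model assumed as in the class:

outputs = self.model.generate(
    **inputs,
    do_sample=True,           # required for num_return_sequences > 1 without beam search
    num_return_sequences=2,
    max_new_tokens=1024
)
prompt_len = inputs['input_ids'].shape[1]
texts = self.tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)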
@@ -8,6 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
 from llmengine.medgemma3_it import MedgemmaLLM
 from llmengine.qwen3 import Qwen3LLM
 from llmengine.qwen3coder import Qwen3CoderLLM
+from llmengine.baichuanm2 import BaichuanM2LLM

 from appPublic.registerfunction import RegisterFunction
 from appPublic.log import debug, exception
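llm_register itself is not part of this diff; a minimal sketch of the registry pattern it presumably implements in llmengine.base_chat_llm (the structure and lookup rule here are assumptions):

# hypothetical registry sketch, not the actual implementation
_llm_engines = {}

def llm_register(name, klass):
    # remember which engine class serves models whose id mentions `name`
    _llm_engines[name] = klass

def get_llm_class(model_id):
    # pick the first registered engine whose name appears in the model id (assumed rule)
    for name, klass in _llm_engines.items():
        if name in model_id:
            return klass
    return None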