Compare commits
2 Commits: 6b0a9e9cd0...b2e4c0befb

| Author | SHA1 | Date |
|---|---|---|
|  | b2e4c0befb |  |
|  | cb8c4d44c3 |  |
@@ -10,22 +10,22 @@ from llmengine.base_chat_llm import BaseChatLLM, llm_register
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 class BaichuanM2LLM(BaseChatLLM):
     def __init__(self, model_id):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype="auto",
             device_map="auto"
         )
         self.model_id = model_id
 
     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
             **inputs,
             streamer=streamer,
             max_new_tokens=4096
         )
         return generate_kwargs
 
     def _messages2inputs(self, messages):
         text = tokenizer.apply_chat_template(
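Note that the last context line above calls a bare `tokenizer` inside a method, while `__init__` stores it as `self.tokenizer`; as written it would raise `NameError` at runtime. A minimal sketch of how the method might continue, assuming the standard transformers chat-template flow (everything past `apply_chat_template(` is cut off in this hunk and is an assumption):

```python
# Sketch only -- the hunk is truncated, so the body below is an assumption
# based on the usual transformers chat-template usage, with the bare
# `tokenizer` corrected to `self.tokenizer`.
def _messages2inputs(self, messages):
    # Render the chat messages into a single prompt string.
    text = self.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize the rendered prompt and move the tensors to the model's device.
    return self.tokenizer([text], return_tensors="pt").to(self.model.device)
```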
@@ -150,14 +150,14 @@ class BaseChatLLM:
         for d in self.output_generator(streamer):
             if i == 0:
                 i = 1
-                t1 = time()
+                t2 = time()
             if d['choices'][0]['finish_reason'] != 'stop':
                 txt += d['choices'][0]['delta']['content']
             else:
                 o_tokens = d.get('output_tokens', 0)
                 i_tokens = input_len
 
-        t2 = time()
+        t3 = time()
         return {
             'id': f'chatcmpl-{getID()}',
             "object":"chat.completion",
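The renames in this hunk (`t1` → `t2` at the first streamed chunk, `t2` → `t3` after the loop) read as making room for a `t1` taken before generation starts, so the three timestamps can separate prefill latency from decode time. A hedged sketch of the metrics such timestamps would support; the function and field names below are assumptions, not code from the repo:

```python
# Hypothetical helper: t1 (before generation starts, assumed), t2 (first
# streamed chunk) and t3 (stream finished) follow the renamed variables above.
def stream_metrics(t1, t2, t3, i_tokens, o_tokens):
    ttft = t2 - t1      # time to first token (prefill latency)
    decode = t3 - t2    # wall time spent streaming output tokens
    tps = o_tokens / decode if decode > 0 else 0.0
    return {
        'input_tokens': i_tokens,
        'output_tokens': o_tokens,
        'time_to_first_token': ttft,
        'tokens_per_second': tps,
    }
```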
@@ -8,7 +8,7 @@ from llmengine.gemma3_it import Gemma3LLM
 from llmengine.medgemma3_it import MedgemmaLLM
 from llmengine.qwen3 import Qwen3LLM
 from llmengine.qwen3coder import Qwen3CoderLLM
-from llmengine.baiduanm2 import BaichuanM2LLM
+from llmengine.baichuanm2 import BaichuanM2LLM
 
 from appPublic.registerfunction import RegisterFunction
 from appPublic.log import debug, exception
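This hunk fixes the misspelled module name (`baiduanm2` → `baichuanm2`); the import only resolves if the module file is itself named `llmengine/baichuanm2.py`. A usage sketch under that assumption, with a placeholder model id:

```python
# Assumes llmengine/baichuanm2.py exists and exports BaichuanM2LLM.
from llmengine.baichuanm2 import BaichuanM2LLM

llm = BaichuanM2LLM("baichuan-inc/Baichuan-M2-32B")  # placeholder model id
```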