# sage/chatglm3_6b/load_model.py
# 2025-07-16 14:28:41 +08:00
# 24 lines, 667 B, Python

from appPublic.worker import awaitify
from ahserver.serverEnv import ServerEnv
from transformers import AutoTokenizer, AutoModel
class ChatGLM3:
    """Thin wrapper around a local ChatGLM3-6B checkpoint.

    Loads tokenizer and model from ``model_path`` and exposes both a
    synchronous ``_generate`` and an awaitable ``generate`` (wrapped via
    ``awaitify``) for use inside the async server environment.
    """

    def __init__(self, model_path, gpu=False):
        """Load tokenizer and model from *model_path*.

        Args:
            model_path: Local directory (or hub id) of the ChatGLM3 checkpoint.
            gpu: If True, move the model to CUDA; otherwise run on CPU.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        if gpu:
            model = model.cuda()
        else:
            # CPU path: cast to float32 — the checkpoint's half-precision
            # weights are not supported on CPU.
            model = model.float()
        model = model.eval()
        self.model = model
        self.tokenizer = tokenizer

    def _generate(self, prompt, history=None):
        """Run one chat turn; returns ``(response, updated_history)``.

        Args:
            prompt: User message to send to the model.
            history: Prior conversation turns; defaults to an empty history.
        """
        # NOTE: original code used a mutable default (history=[]), which is
        # shared across calls and can leak conversation state; use None
        # sentinel instead.
        if history is None:
            history = []
        response, history = self.model.chat(self.tokenizer, prompt, history=history)
        return response, history

    # Awaitable variant of _generate for async request handlers.
    generate = awaitify(_generate)