from appPublic.worker import awaitify
from ahserver.serverEnv import ServerEnv
from transformers import AutoTokenizer, AutoModel
class ChatGLM3:
    """ChatGLM3 chat model with a synchronous core and an awaitable wrapper.

    Loads a ChatGLM3 checkpoint via the transformers Auto* APIs and exposes
    `generate`, an awaitable version of `_generate` produced by the project
    helper `awaitify` (presumably runs the blocking call off the event
    loop — TODO confirm against appPublic.worker).
    """

    def __init__(self, model_path, gpu=False):
        """Load tokenizer and model from *model_path*.

        Args:
            model_path: HuggingFace model id or local checkpoint directory.
            gpu: if True, move the model to CUDA; otherwise cast to
                float32 for CPU inference (the checkpoint's half-precision
                weights are not usable on CPU).
        """
        # trust_remote_code is required: ChatGLM ships custom model code.
        # NOTE(review): this executes code from the checkpoint — only load
        # trusted model paths.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        if gpu:
            model = model.cuda()
        else:
            model = model.float()
        # Inference only: disable dropout/batch-norm training behavior.
        self.model = model.eval()
        self.tokenizer = tokenizer

    def _generate(self, prompt, history=None):
        """Run one chat turn.

        Args:
            prompt: user message for this turn.
            history: prior conversation turns; defaults to a fresh empty
                list (was a mutable default `[]`, which is shared across
                calls — fixed to the None-sentinel idiom).

        Returns:
            (response, history): the model's reply and the updated history.
        """
        if history is None:
            history = []
        response, history = self.model.chat(self.tokenizer, prompt, history=history)
        return response, history

    # Awaitable entry point used by async callers; wraps the blocking chat.
    generate = awaitify(_generate)