sage/fastchat/load_model.py

import time

from fastchat.serve.inference import ChatIO
from fastchat.model.model_adapter import (
    load_model,
    # get_conversation_template,
    get_generate_stream_function,
)
from ahserver.serverenv import ServerEnv
from appPublic.worker import awaitify
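

# Wraps FastChat's load_model()/get_generate_stream_function() behind a
# small class exposing a blocking _generate() and an awaitable generate().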
class FastChatModel:
    def __init__(self, model_path, device='cpu',
                 temperature=1.0,
                 context_len=100000,
                 debug=False):
        self.model_path = model_path
        self.device = device
        self.debug = debug
        self.temperature = temperature
        self.context_len = context_len
        self.model, self.tokenizer = load_model(
            model_path,
            device=device,
            debug=debug,
        )
        self.generate_stream_func = get_generate_stream_function(
            self.model, self.model_path
        )
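
    # Run one blocking generation pass and return the final decoded text.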
    def _generate(self, prompt):
        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": self.temperature,
            "max_new_tokens": self.context_len,
            "stream": False,
            "echo": False,
        }
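        # The stream function yields dicts whose 'text' field holds the
        # cumulative output so far, so the last yield is the full answer.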
        output_stream = self.generate_stream_func(
            self.model,
            self.tokenizer,
            gen_params,
            self.device,
            context_len=self.context_len,
        )
        t = time.time()
        output = ''
        for i, s in enumerate(output_stream):
            if self.debug:
                print(i, ':', s['text'], '\n')
            output = s['text']
        if self.debug:
            print('generation took', time.time() - t, 'seconds')
        return output

    # awaitify() turns the blocking _generate into an awaitable so async
    # handlers can call it without blocking the event loop.
    generate = awaitify(_generate)
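

# Load the model once at import time and publish it on the shared
# ServerEnv so request handlers can reach it as g.fschat.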
g = ServerEnv()
# m = FastChatModel('./vicuna-7b-v1.5', device='mps')
# m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13b-Instruct-hf')
m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13-hf')
g.fschat = m
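

# Minimal interactive REPL for manual testing when run as a script.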
if __name__ == '__main__':
    import asyncio

    async def main():
        while True:
            print('input prompt:')
            p = input()
            x = await m.generate(p)
            print(f'answer:\n{x}')

    asyncio.run(main())
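
# A hypothetical async caller sketch (the handler name and request shape
# below are assumptions, not part of this module):
#
#     async def handle(request):
#         answer = await g.fschat.generate(request['prompt'])
#         return answer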