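"""FastChat model wrapper for ahserver.

Loads a local model through FastChat, exposes a blocking _generate() plus an
awaitable generate(), and registers the instance on ServerEnv as `fschat`.
Running the module directly starts a simple interactive test loop.
"""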
import time

from fastchat.serve.inference import ChatIO
from fastchat.model.model_adapter import (
    load_model,
    # get_conversation_template,
    get_generate_stream_function,
)
from ahserver.serverenv import ServerEnv
from appPublic.worker import awaitify

class FastChatModel:
    def __init__(self, model_path, device='cpu',
                 temperature=1.0,
                 context_len=100000,
                 debug=False
                 ):
        self.model_path = model_path
        self.device = device
        self.debug = debug
        self.temperature = temperature
        self.context_len = context_len
        self.model, self.tokenizer = load_model(
            model_path,
            device=device,
            debug=debug
        )
        self.generate_stream_func = get_generate_stream_function(self.model, self.model_path)

    def _generate(self, prompt):
        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": self.temperature,
            "max_new_tokens": self.context_len,
            "stream": False,
            "echo": False,
        }

        output_stream = self.generate_stream_func(
            self.model,
            self.tokenizer,
            gen_params,
            self.device,
            context_len=self.context_len
        )
        # The stream yields dicts whose 'text' field holds the output produced
        # so far; drain it and keep the last chunk as the complete answer.
        output = ''
        for i, s in enumerate(output_stream):
            if self.debug:
                print(i, ':', s['text'], '\n')
            output = s['text']
        return output

    # appPublic.worker.awaitify presumably wraps the blocking _generate so it
    # can be awaited from async code (assumption: it offloads the call, e.g.
    # to a thread executor, rather than blocking the event loop).
    generate = awaitify(_generate)

g = ServerEnv()
# m = FastChatModel('./vicuna-7b-v1.5', device='mps')
# m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13b-Instruct-hf')
m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13-hf')
g.fschat = m

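# ServerEnv() appears to act as a process-wide registry, so other modules can
# reach the shared model and await it, e.g. (hypothetical caller):
#     answer = await ServerEnv().fschat.generate(prompt)
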
if __name__ == '__main__':
    import asyncio

    # Simple interactive test loop: read a prompt from stdin, print the answer.
    async def main():
        while True:
            print('input prompt:')
            p = input()
            x = await m.generate(p)
            print(f'answer:\n{x}')

    asyncio.run(main())