import time
from fastchat.serve.inference import ChatIO
from fastchat.model.model_adapter import (
    load_model,
    # get_conversation_template,
    get_generate_stream_function,
)
from ahserver.serverenv import ServerEnv
from appPublic.worker import awaitify


class FastChatModel:
    def __init__(self, model_path,
                 device='cpu',
                 temperature=1.0,
                 context_len=100000,
                 debug=False):
        self.model_path = model_path
        self.device = device
        self.debug = debug
        self.temperature = temperature
        self.context_len = context_len
        self.model, self.tokenizer = load_model(
            model_path,
            device=device,
            debug=debug
        )
        self.generate_stream_func = get_generate_stream_function(self.model, self.model_path)

    def _generate(self, prompt):
        gen_params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": self.temperature,
            "max_new_tokens": self.context_len,
            "stream": False,
            "echo": False,
        }
        output_stream = self.generate_stream_func(
            self.model,
            self.tokenizer,
            gen_params,
            self.device,
            context_len=self.context_len
        )
        t = time.time()
        output = ''
        # generate_stream yields cumulative text; keep the latest chunk and
        # return the full answer once the stream is exhausted.
        for i, s in enumerate(output_stream):
            if self.debug:
                print(i, ':', s['text'], '\n')
            output = s['text']
        return output

    # expose an awaitable wrapper so the model can be called from async code
    generate = awaitify(_generate)


g = ServerEnv()
# m = FastChatModel('./vicuna-7b-v1.5', device='mps')
# m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13b-Instruct-hf')
m = FastChatModel('/Users/ymq/models/hub/CodeLlama-13-hf')
g.fschat = m

if __name__ == '__main__':
    import asyncio

    async def main():
        while True:
            print('input prompt:')
            p = input()
            x = await m.generate(p)
            print(f'answer:\n{x}')

    asyncio.run(main())