bugfix
This commit is contained in:
parent
c32c16512e
commit
864bfdf18c
@ -107,14 +107,15 @@ class BaseChatLLM:
|
||||
inputs = self._messages2inputs(messages)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
streamer = self.get_streamer()
|
||||
kwargs = self.build_kwargs(inputs, streamer)
|
||||
thread = threading.Thread(target=self.model.generate,
|
||||
kwargs=kwargs)
|
||||
thread.start()
|
||||
for d in self.output_generator(streamer):
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
d['input_tokens'] = input_len
|
||||
yield d
|
||||
with self.get_streamer() as streamer:
|
||||
kwargs = self.build_kwargs(inputs, streamer)
|
||||
thread = threading.Thread(target=self.model.generate,
|
||||
kwargs=kwargs)
|
||||
thread.start()
|
||||
for d in self.output_generator(streamer):
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
d['input_tokens'] = input_len
|
||||
yield d
|
||||
|
||||
async def async_gen(self, messages):
|
||||
async for d in stream.iterate(self._gen(messages)):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user