This commit is contained in:
yumoqing 2025-07-24 11:48:56 +08:00
parent c32c16512e
commit 864bfdf18c

View File

@ -107,14 +107,15 @@ class BaseChatLLM:
inputs = self._messages2inputs(messages) inputs = self._messages2inputs(messages)
input_len = inputs["input_ids"].shape[-1] input_len = inputs["input_ids"].shape[-1]
streamer = self.get_streamer() streamer = self.get_streamer()
kwargs = self.build_kwargs(inputs, streamer) with self.get_streamer() as streamer:
thread = threading.Thread(target=self.model.generate, kwargs = self.build_kwargs(inputs, streamer)
kwargs=kwargs) thread = threading.Thread(target=self.model.generate,
thread.start() kwargs=kwargs)
for d in self.output_generator(streamer): thread.start()
if d['choices'][0]['finish_reason'] == 'stop': for d in self.output_generator(streamer):
d['input_tokens'] = input_len if d['choices'][0]['finish_reason'] == 'stop':
yield d d['input_tokens'] = input_len
yield d
async def async_gen(self, messages): async def async_gen(self, messages):
async for d in stream.iterate(self._gen(messages)): async for d in stream.iterate(self._gen(messages)):