diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py
index ee62d86..958ded9 100644
--- a/llmengine/base_chat_llm.py
+++ b/llmengine/base_chat_llm.py
@@ -107,14 +107,14 @@ class BaseChatLLM:
         inputs = self._messages2inputs(messages)
         input_len = inputs["input_ids"].shape[-1]
-        streamer = self.get_streamer()
-        kwargs = self.build_kwargs(inputs, streamer)
-        thread = threading.Thread(target=self.model.generate,
-                                  kwargs=kwargs)
-        thread.start()
-        for d in self.output_generator(streamer):
-            if d['choices'][0]['finish_reason'] == 'stop':
-                d['input_tokens'] = input_len
-            yield d
+        with self.get_streamer() as streamer:
+            kwargs = self.build_kwargs(inputs, streamer)
+            thread = threading.Thread(target=self.model.generate,
+                                      kwargs=kwargs)
+            thread.start()
+            for d in self.output_generator(streamer):
+                if d['choices'][0]['finish_reason'] == 'stop':
+                    d['input_tokens'] = input_len
+                yield d
 
     async def async_gen(self, messages):
         async for d in stream.iterate(self._gen(messages)):