From 864bfdf18cbbccf60e239814ae6617f07cf92643 Mon Sep 17 00:00:00 2001
From: yumoqing
Date: Thu, 24 Jul 2025 11:48:56 +0800
Subject: [PATCH] bugfix

---
 llmengine/base_chat_llm.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py
index ee62d86..958ded9 100644
--- a/llmengine/base_chat_llm.py
+++ b/llmengine/base_chat_llm.py
@@ -107,14 +107,14 @@ class BaseChatLLM:
         inputs = self._messages2inputs(messages)
         input_len = inputs["input_ids"].shape[-1]
-        streamer = self.get_streamer()
-        kwargs = self.build_kwargs(inputs, streamer)
-        thread = threading.Thread(target=self.model.generate,
-                                  kwargs=kwargs)
-        thread.start()
-        for d in self.output_generator(streamer):
-            if d['choices'][0]['finish_reason'] == 'stop':
-                d['input_tokens'] = input_len
-            yield d
+        with self.get_streamer() as streamer:
+            kwargs = self.build_kwargs(inputs, streamer)
+            thread = threading.Thread(target=self.model.generate,
+                                      kwargs=kwargs)
+            thread.start()
+            for d in self.output_generator(streamer):
+                if d['choices'][0]['finish_reason'] == 'stop':
+                    d['input_tokens'] = input_len
+                yield d
 
     async def async_gen(self, messages):
         async for d in stream.iterate(self._gen(messages)):
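
For context, this hunk follows the common threaded-streaming idiom for transformers models: model.generate() blocks until decoding finishes, so it runs on a worker thread while the caller drains a streamer object chunk by chunk. The new `with` block implies get_streamer() returns a context-managed streamer whose cleanup should run even if the consumer stops iterating early. Below is a minimal, self-contained sketch of that idea against stock Hugging Face transformers; managed_generate() is a hypothetical stand-in for BaseChatLLM's unshown get_streamer()/build_kwargs()/output_generator() helpers, not the project's actual API.

    # Sketch of the threaded streaming pattern behind this hunk.
    # TextIteratorStreamer is real transformers API; managed_generate() is
    # a hypothetical wrapper added only to show why a `with` block helps:
    # the worker thread is joined even if the consumer abandons the loop.
    import threading
    from contextlib import contextmanager

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    @contextmanager
    def managed_generate(model, tokenizer, prompt, **gen_kwargs):
        inputs = tokenizer(prompt, return_tensors="pt")
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        # generate() blocks until decoding is done, so it runs on a worker
        # thread while the caller consumes decoded chunks from the streamer.
        thread = threading.Thread(
            target=model.generate,
            kwargs=dict(**inputs, streamer=streamer, **gen_kwargs),
        )
        thread.start()
        try:
            yield streamer
        finally:
            thread.join()  # runs on normal exit and on early/exceptional exit

    if __name__ == "__main__":
        tok = AutoTokenizer.from_pretrained("gpt2")
        mdl = AutoModelForCausalLM.from_pretrained("gpt2")
        with managed_generate(mdl, tok, "Hello,", max_new_tokens=16) as streamer:
            for chunk in streamer:  # yields decoded text pieces as they arrive
                print(chunk, end="", flush=True)
        print()

In the patched _gen() itself, the chunks come from output_generator() as dicts, and the prompt length is attached as 'input_tokens' on the final chunk whose finish_reason is 'stop', mirroring OpenAI-style streaming payloads.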