yumoqing 2025-07-18 15:52:53 +08:00
parent 3d2f799eee
commit f84f4f14f9
8 changed files with 101 additions and 203 deletions

View File

@@ -3,6 +3,7 @@ import asyncio
import json
import torch
from time import time
from aiostream import stream
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
@@ -26,18 +27,6 @@ class BaseChatLLM:
device = torch.device("mps")
self.model = self.model.to(device)
def get_session_key(self):
return self.model_id + ':messages'
def _get_session_messages(self, session):
key = self.get_session_key()
messages = session.get(key) or []
return messages
def _set_session_messages(self, session, messages):
key = self.get_session_key()
session[key] = messages
def get_streamer(self):
return TextIteratorStreamer(
tokenizer=self.tokenizer,
@@ -49,7 +38,7 @@ class BaseChatLLM:
all_txt = ''
t1 = time()
i = 0
id = f'chatllm-{getID}'
id = f'chatllm-{getID()}'
for txt in streamer:
if txt == '':
continue
@@ -60,15 +49,15 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(),
"created": t1,
"model":self.model_id,
"choices":[
{
"index":0,
"delta":{
"role": "assistant",
"content":txt
},
"logprobs":None,
"finish_reason":None
}
]
@@ -80,7 +69,7 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(),
"created": t1,
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
@@ -91,69 +80,11 @@ class BaseChatLLM:
"delta":{
"content":""
},
"logprobs":None,
"finish_reason":"stop"
}
]
}
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
messages = self._get_session_messages(session)
if sys_prompt:
messages.append(self._build_sys_message(sys_prompt))
messages.append(self._build_user_message(prompt, image_path=image_path))
# debug(f'{messages=}')
all_txt = ''
for d in self._gen(messages):
if d['choices'][0]['finish_reason'] == 'stop':
messages.append(self._build_assistant_message(all_txt))
else:
all_txt += d['choices'][0]['delta']['content']
yield d
self._set_session_messages(session, messages)
async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
yield d
def generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
if d['choices'][0]['finish_reason'] == 'stop':
return d
def stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
async def async_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
if d['choices'][0]['finish_reason'] == 'stop':
return d
async def async_stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]'
def build_kwargs(self, inputs, streamer):
generate_kwargs = dict(
**inputs,
@@ -165,6 +96,7 @@ class BaseChatLLM:
return generate_kwargs
def _messages2inputs(self, messages):
debug(f'{messages=}')
return self.processor.apply_chat_template(
messages, add_generation_prompt=True,
tokenize=True,
@@ -184,63 +116,70 @@
d['input_tokens'] = input_len
yield d
class T2TChatLLM(BaseChatLLM):
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":prompt
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content": prompt
}
def _build_user_message(self, prompt, **kw):
return {
"role":"user",
"content": prompt
}
class MMChatLLM(BaseChatLLM):
""" multiple modal chat LLM """
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":[{"type": "text", "text": prompt}]
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content":[{"type": "text", "text": prompt}]
}
def _build_user_message(self, prompt, image_path=None,
video_path=None, audio_path=None):
contents = [
{
"type":"text", "text": prompt
}
]
if image_path:
contents.append({
"type": "image",
"image": image_path
})
if video_path:
contents.append({
"type": "video",
"video":video_path
})
if audio_path:
contents.append({
"tyoe": "audio",
"audio": audio_path
})
return {
"role": "user",
"content": contents
}
async def async_gen(self, messages):
async for d in stream.iterate(self._gen(messages)):
yield d
async def chat_completion_stream(self, messages):
async for d in self.async_gen(messages):
if d['choices'][0]['finish_reason']:
d['usage'] = {
'prompt_tokens': d['input_tokens'],
'completion_tokens': d['output_tokens'],
'total_tokens': d['input_tokens'] + d['output_tokens']
}
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]\n'
def reference(self, messages):
t1 = time()
inputs = self._messages2inputs(messages)
input_len = inputs["input_ids"].shape[-1]
streamer = self.get_streamer()
kwargs = self.build_kwargs(inputs, streamer)
thread = threading.Thread(target=self.model.generate,
kwargs=kwargs)
thread.start()
txt = ''
i = 0
for d in self.output_generator(streamer):
if i == 0:
i = 1
t1 = time()
if d['choices'][0]['finish_reason'] != 'stop':
txt += d['choices'][0]['delta']['content']
else:
i_tokens = d['input_tokens']
o_tokens = d['output_tokens']
t2 = time()
return {
'id': f'chatcmpl-{getID()}',
"object":"chat.completion",
"created":t1,
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
"output_token": output_tokens,
"choices":[
{
"index":0,
"message":{
"role": "assistant",
"content": txt
},
"finish_reason":"stop"
}
],
"usage": {
"prompt_tokens": i_tokens,
"completion_tokens": o_tokens,
"total_tokens": i_tokens + o_tokens
}
}
async def chat_completion(self, messages):
f = awaitify(self.reference)
return await f(messages)
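For orientation, a minimal usage sketch of the new chat_completion / chat_completion_stream API added above (not part of the commit; the subclass and model path are illustrative):

import asyncio
import json

async def demo(llm):
    # llm is any BaseChatLLM subclass instance, e.g. Qwen3LLM or Gemma3LLM
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ]

    # Non-streaming: chat_completion() wraps reference() via awaitify and
    # returns one OpenAI-style chat.completion dict.
    resp = await llm.chat_completion(messages)
    print(resp["choices"][0]["message"]["content"])

    # Streaming: chat_completion_stream() yields SSE-style 'data: {...}\n'
    # lines and finishes with 'data: [DONE]'.
    async for line in llm.chat_completion_stream(messages):
        if line.strip() == "data: [DONE]":
            break
        chunk = json.loads(line[len("data: "):])
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)

# asyncio.run(demo(Qwen3LLM("/path/to/model")))  # path is illustrative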

View File

@@ -1,4 +1,5 @@
import torch
import numpy as np
model_pathMap = {
}
@@ -20,14 +21,19 @@ class BaseEmbedding:
device = torch.device("mps")
self.model = self.model.to(device)
def embeddings(self, input):
def encode(self, input):
es = self.model.encode(input)
return es
def embeddings(self, input):
es = self.encode(input)
data = []
for i, e in enumerate(es):
if isinstance(e, np.ndarray):
e = e.tolist()
d = {
"object": "embedding",
"index": i,
"embedding": e.tolist()
"embedding": e
}
data.append(d)
return {
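A minimal sketch of the refactor above: encode() returns the raw vectors and embeddings() converts numpy arrays to lists before building the response. The envelope keys are an assumption (OpenAI embeddings format); the hunk truncates before them.

import numpy as np

class SketchEmbedding:
    model_id = "example-embedding-model"  # illustrative

    def encode(self, input):
        # stand-in for self.model.encode(input): one vector per input text
        return [np.zeros(4) for _ in input]

    def embeddings(self, input):
        es = self.encode(input)
        data = []
        for i, e in enumerate(es):
            if isinstance(e, np.ndarray):
                e = e.tolist()
            data.append({"object": "embedding", "index": i, "embedding": e})
        # envelope keys below are assumed (OpenAI-style), not shown in the diff
        return {"object": "list", "data": data, "model": self.model_id}

print(len(SketchEmbedding().embeddings(["a", "b"])["data"]))  # 2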

View File

@@ -3,6 +3,7 @@ import os
import sys
import argparse
from llmengine.qwen3embedding import *
from llmengine.bgeembedding import *
from llmengine.base_embedding import get_llm_class
from appPublic.registerfunction import RegisterFunction

View File

@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
class Gemma3LLM(MMChatLLM):
class Gemma3LLM(BaseChatLLM):
def __init__(self, model_id):
self.model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
@@ -24,6 +24,10 @@ class Gemma3LLM(MMChatLLM):
llm_register("gemma-3", Gemma3LLM)
if __name__ == '__main__':
def get_stream_text(chunk):
chunk = chunk[6:]
d = json.loads(chunk)
return d['choices'][0]['delta']['content']
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
session = {}
while True:
@@ -35,7 +39,7 @@ if __name__ == '__main__':
print('input image path')
imgpath=input()
for d in gemma3.stream_generate(session, p, image_path=imgpath):
if not d['done']:
if not d['DONE']:
print(d['text'], end='', flush=True)
else:
x = {k:v for k,v in d.items() if k != 'text'}
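For reference, an illustrative chunk line of the kind the new get_stream_text() helper above parses; the 'data: ' prefix is 6 characters, hence chunk[6:]:

import json

line = 'data: {"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}]}\n'
d = json.loads(line[6:])                    # strip the 'data: ' prefix, as get_stream_text() does
print(d["choices"][0]["delta"]["content"])  # -> Hi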

View File

@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
model_id = "google/medgemma-4b-it"
class MedgemmaLLM(MMChatLLM):
class MedgemmaLLM(BaseChatLLM):
def __init__(self, model_id):
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,

View File

@@ -4,12 +4,12 @@
from appPublic.worker import awaitify
from appPublic.log import debug
from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer
class Qwen3LLM(T2TChatLLM):
class Qwen3LLM(BaseChatLLM):
def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(
@@ -17,9 +17,6 @@ class Qwen3LLM(T2TChatLLM):
torch_dtype="auto",
device_map="auto"
)
if torch.backends.mps.is_available():
device = torch.device("mps")
self.model = self.model.to(device)
self.model_id = model_id
def build_kwargs(self, inputs, streamer):
@@ -33,7 +30,7 @@ class Qwen3LLM(T2TChatLLM):
return generate_kwargs
def _messages2inputs(self, messages):
debug(f'{messages=}')
debug(f'-----------{messages=}-----------')
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
@@ -43,26 +40,3 @@
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
llm_register("Qwen/Qwen3", Qwen3LLM)
if __name__ == '__main__':
import sys
model_path = sys.argv[1]
q3 = Qwen3LLM(model_path)
session = {}
while True:
print('input prompt')
p = input()
if p:
if p == 'q':
break;
for d in q3.stream_generate(session, p):
print(d)
"""
if not d['done']:
print(d['text'], end='', flush=True)
else:
x = {k:v for k,v in d.items() if k != 'text'}
print(f'\n{x}\n')
"""

View File

@@ -21,23 +21,17 @@ def init():
rf.register('chat_completions', chat_completions)
async def chat_completions(request, params_kw, *params, **kw):
se = ServerEnv()
engine = se.engine
async def gor():
se = ServerEnv()
engine = se.chat_engine
session = await get_session(request)
kwargs = {
}
if params_kw.image_path:
kwargs['image_path'] = fs.reapPath(params_kw.image_path)
if params_kw.video_path:
kwargs['video_path'] = fs.reapPath(params_kw.video_path)
if params_kw.audio_path:
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
async for d in engine.chat_completion_stream(params_kw.messages):
debug(f'{d=}')
yield d
return await stream_response(request, gor)
if params_kw.stream:
return await stream_response(request, gor)
else:
return await engine.chat_completion(params_kw.messages)
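An illustrative request body for the reworked handler above; only the messages and stream fields are taken from the code, and how the body is parsed (JSON vs. form) is not shown in this diff:

payload = {
    "stream": True,   # True  -> stream_response() over chat_completion_stream()
                      # False -> a single chat.completion dict from chat_completion()
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
}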
def main():
parser = argparse.ArgumentParser(prog="Sage")

View File

@@ -1,23 +1,3 @@
[project]
name="llmengine"
version = "0.0.1"
description = "Your project description"
authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
dependencies = [
"torch",
"transformers",
"sentence-transformers>=2.7.0",
# "flash_attention_2",
"mistral-common",
"accelerate"
]
[project.optional-dependencies]
dev = ["pytest", "black", "mypy"]
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"