yumoqing 2025-07-18 15:52:53 +08:00
parent 3d2f799eee
commit f84f4f14f9
8 changed files with 101 additions and 203 deletions

View File

@@ -3,6 +3,7 @@ import asyncio
 import json
 import torch
 from time import time
+from aiostream import stream
 from transformers import TextIteratorStreamer
 from appPublic.log import debug
 from appPublic.worker import awaitify
@@ -26,18 +27,6 @@ class BaseChatLLM:
             device = torch.device("mps")
             self.model = self.model.to(device)
-    def get_session_key(self):
-        return self.model_id + ':messages'
-    def _get_session_messages(self, session):
-        key = self.get_session_key()
-        messages = session.get(key) or []
-        return messages
-    def _set_session_messages(self, session, messages):
-        key = self.get_session_key()
-        session[key] = messages
     def get_streamer(self):
         return TextIteratorStreamer(
             tokenizer=self.tokenizer,
@@ -49,7 +38,7 @@ class BaseChatLLM:
         all_txt = ''
         t1 = time()
         i = 0
-        id = f'chatllm-{getID}'
+        id = f'chatllm-{getID()}'
         for txt in streamer:
             if txt == '':
                 continue
@@ -60,15 +49,15 @@ class BaseChatLLM:
             yield {
                 "id":id,
                 "object":"chat.completion.chunk",
-                "created":time(),
+                "created": t1,
                 "model":self.model_id,
                 "choices":[
                     {
                         "index":0,
                         "delta":{
-                            "role": "assistant",
                             "content":txt
                         },
+                        "logprobs":None,
                         "finish_reason":None
                     }
                 ]
@@ -80,7 +69,7 @@ class BaseChatLLM:
         yield {
             "id":id,
             "object":"chat.completion.chunk",
-            "created":time(),
+            "created": t1,
             "model":self.model_id,
             "response_time": t2 - t1,
             "finish_time": t3 - t1,
@@ -91,69 +80,11 @@ class BaseChatLLM:
                     "delta":{
                         "content":""
                     },
+                    "logprobs":None,
                     "finish_reason":"stop"
                 }
             ]
         }
-    def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
-        messages = self._get_session_messages(session)
-        if sys_prompt:
-            messages.append(self._build_sys_message(sys_prompt))
-        messages.append(self._build_user_message(prompt, image_path=image_path))
-        # debug(f'{messages=}')
-        all_txt = ''
-        for d in self._gen(messages):
-            if d['choices'][0]['finish_reason'] == 'stop':
-                messages.append(self._build_assistant_message(all_txt))
-            else:
-                all_txt += d['choices'][0]['delta']['content']
-            yield d
-        self._set_session_messages(session, messages)
-    async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
-        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-            await asyncio.sleep(0)
-            yield d
-    def generate(self, session, prompt,
-            image_path=None,
-            video_path=None,
-            audio_path=None,
-            sys_prompt=None):
-        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-            if d['choices'][0]['finish_reason'] == 'stop':
-                return d
-    def stream_generate(self, session, prompt,
-            image_path=None,
-            video_path=None,
-            audio_path=None,
-            sys_prompt=None):
-        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-            s = f'data: {json.dumps(d)}\n'
-            yield s
-    async def async_generate(self, session, prompt,
-            image_path=None,
-            video_path=None,
-            audio_path=None,
-            sys_prompt=None):
-        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-            await asyncio.sleep(0)
-            if d['choices'][0]['finish_reason'] == 'stop':
-                return d
-    async def async_stream_generate(self, session, prompt,
-            image_path=None,
-            video_path=None,
-            audio_path=None,
-            sys_prompt=None):
-        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-            s = f'data: {json.dumps(d)}\n'
-            yield s
-        yield 'data: [DONE]'
     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
             **inputs,
@@ -165,6 +96,7 @@ class BaseChatLLM:
         return generate_kwargs
     def _messages2inputs(self, messages):
+        debug(f'{messages=}')
        return self.processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True,
@@ -184,63 +116,70 @@ class BaseChatLLM:
             d['input_tokens'] = input_len
             yield d
-class T2TChatLLM(BaseChatLLM):
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":prompt
-        }
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content": prompt
-        }
-    def _build_user_message(self, prompt, **kw):
-        return {
-            "role":"user",
-            "content": prompt
-        }
-class MMChatLLM(BaseChatLLM):
-    """ multiple modal chat LLM """
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":[{"type": "text", "text": prompt}]
-        }
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content":[{"type": "text", "text": prompt}]
-        }
-    def _build_user_message(self, prompt, image_path=None,
-            video_path=None, audio_path=None):
-        contents = [
-            {
-                "type":"text", "text": prompt
-            }
-        ]
-        if image_path:
-            contents.append({
-                "type": "image",
-                "image": image_path
-            })
-        if video_path:
-            contents.append({
-                "type": "video",
-                "video":video_path
-            })
-        if audio_path:
-            contents.append({
-                "tyoe": "audio",
-                "audio": audio_path
-            })
-        return {
-            "role": "user",
-            "content": contents
-        }
+    async def async_gen(self, messages):
+        async for d in stream.iterate(self._gen(messages)):
+            yield d
+    async def chat_completion_stream(self, messages):
+        async for d in self.async_gen(messages):
+            if d['choices'][0]['finish_reason']:
+                d['usage'] = {
+                    'prompt_tokens': d['input_tokens'],
+                    'completion_tokens': d['output_tokens'],
+                    'total_tokens': d['input_tokens'] + d['output_tokens']
+                }
+            s = f'data: {json.dumps(d)}\n'
+            yield s
+        yield 'data: [DONE]\n'
+    def reference(self, messages):
+        t1 = time()
+        inputs = self._messages2inputs(messages)
+        input_len = inputs["input_ids"].shape[-1]
+        streamer = self.get_streamer()
+        kwargs = self.build_kwargs(inputs, streamer)
+        thread = threading.Thread(target=self.model.generate,
+            kwargs=kwargs)
+        thread.start()
+        txt = ''
+        i = 0
+        for d in self.output_generator(streamer):
+            if i == 0:
+                i = 1
+                t1 = time()
+            if d['choices'][0]['finish_reason'] != 'stop':
+                txt += d['choices'][0]['delta']['content']
+            else:
+                i_tokens = d['input_tokens']
+                o_tokens = d['output_tokens']
+        t2 = time()
+        return {
+            'id': f'chatcmpl-{getID()}',
+            "object":"chat.completion",
+            "created":t1,
+            "model":self.model_id,
+            "response_time": t2 - t1,
+            "finish_time": t3 - t1,
+            "output_token": output_tokens,
+            "choices":[
+                {
+                    "index":0,
+                    "message":{
+                        "role": "assistant",
+                        "content": txt
+                    },
+                    "finish_reason":"stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": i_tokens,
+                "completion_tokens": o_tokens,
+                "total_tokens": i_tokens + o_tokens
+            }
+        }
+    async def chat_completion(self, messages):
+        f = awaitify(self.reference)
+        return await f(messages)
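
Note: the new messages-based API replaces the old session/prompt helpers; callers now pass an OpenAI-style message list and read back either SSE chunk lines (chat_completion_stream) or a single chat.completion dict (chat_completion). A minimal consumer sketch, assuming engine is an instance of one of the LLM classes registered below; the message contents and the parsing loop are illustrative, not part of this commit:

    import asyncio
    import json

    async def demo(engine):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ]
        # Streaming: each item is an SSE-style line 'data: {...}\n', terminated by 'data: [DONE]\n'
        async for line in engine.chat_completion_stream(messages):
            if line.startswith('data: ') and '[DONE]' not in line:
                chunk = json.loads(line[6:])
                print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)
        # Non-streaming: one chat.completion-shaped dict with choices[0].message and usage
        resp = await engine.chat_completion(messages)
        print(resp['choices'][0]['message']['content'])

    # e.g. asyncio.run(demo(engine)) once an engine instance has been constructed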

View File

@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 model_pathMap = {
 }
@@ -20,14 +21,19 @@ class BaseEmbedding:
             device = torch.device("mps")
             self.model = self.model.to(device)
-    def embeddings(self, input):
+    def encode(self, input):
         es = self.model.encode(input)
+    def embeddings(self, input):
+        es = self.encode(input)
         data = []
         for i, e in enumerate(es):
+            if isinstance(e, np.ndarray):
+                r = e.tolist()
             d = {
                 "object": "embedding",
                 "index": i,
-                "embedding": e.tolist()
+                "embedding": e
             }
             data.append(d)
         return {
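
Note: encode() now isolates the model call while embeddings() keeps building the per-item records in the OpenAI embeddings shape. A rough sketch of how a caller might read the result, assuming a concrete BaseEmbedding subclass and that the return dict (truncated above) exposes the items under a "data" key; the class name and model path are illustrative:

    emb = Qwen3Embedding('/share/models/Qwen/Qwen3-Embedding-0.6B')   # hypothetical class and path
    result = emb.embeddings(["hello world", "llm engine"])
    for item in result["data"]:   # assumes the truncated return dict carries a "data" list
        print(item["index"], item["object"], len(item["embedding"]))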

View File

@@ -3,6 +3,7 @@ import os
 import sys
 import argparse
 from llmengine.qwen3embedding import *
+from llmengine.bgeembedding import *
 from llmengine.base_embedding import get_llm_class
 from appPublic.registerfunction import RegisterFunction

View File

@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import MMChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
-class Gemma3LLM(MMChatLLM):
+class Gemma3LLM(BaseChatLLM):
     def __init__(self, model_id):
         self.model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, device_map="auto"
@@ -24,6 +24,10 @@ class Gemma3LLM(MMChatLLM):
 llm_register("gemma-3", Gemma3LLM)
 if __name__ == '__main__':
+    def get_stream_text(chunk):
+        chunk = chunk[6:]
+        d = json.loads(chunk)
+        return d['choices'][0]['delta']['content']
     gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
     session = {}
     while True:
@@ -35,7 +39,7 @@ if __name__ == '__main__':
         print('input image path')
         imgpath=input()
         for d in gemma3.stream_generate(session, p, image_path=imgpath):
-            if not d['done']:
+            if not d['DONE']:
                 print(d['text'], end='', flush=True)
             else:
                 x = {k:v for k,v in d.items() if k != 'text'}
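
Note: with MMChatLLM removed, callers build multimodal messages themselves and hand them to the shared chat API. A sketch of a message list in the content-parts format the deleted _build_user_message used to assemble; the image path and the commented call are illustrative:

    import asyncio

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image", "image": "/path/to/picture.jpg"},   # illustrative path
        ]},
    ]
    # resp = asyncio.run(gemma3.chat_completion(messages))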

View File

@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import MMChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
 model_id = "google/medgemma-4b-it"
-class MedgemmaLLM(MMChatLLM):
+class MedgemmaLLM(BaseChatLLM):
     def __init__(self, model_id):
         self.model = AutoModelForImageTextToText.from_pretrained(
             model_id,

View File

@@ -4,12 +4,12 @@
 from appPublic.worker import awaitify
 from appPublic.log import debug
 from ahserver.serverenv import get_serverenv
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
-from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
+from transformers import AutoModelForCausalLM, AutoTokenizer
-class Qwen3LLM(T2TChatLLM):
+class Qwen3LLM(BaseChatLLM):
     def __init__(self, model_id):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -17,9 +17,6 @@ class Qwen3LLM(T2TChatLLM):
             torch_dtype="auto",
             device_map="auto"
         )
-        if torch.backends.mps.is_available():
-            device = torch.device("mps")
-            self.model = self.model.to(device)
         self.model_id = model_id
     def build_kwargs(self, inputs, streamer):
@@ -33,7 +30,7 @@ class Qwen3LLM(T2TChatLLM):
         return generate_kwargs
     def _messages2inputs(self, messages):
-        debug(f'{messages=}')
+        debug(f'-----------{messages=}-----------')
         text = self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
@@ -43,26 +40,3 @@ class Qwen3LLM(T2TChatLLM):
         return self.tokenizer([text], return_tensors="pt").to(self.model.device)
 llm_register("Qwen/Qwen3", Qwen3LLM)
-if __name__ == '__main__':
-    import sys
-    model_path = sys.argv[1]
-    q3 = Qwen3LLM(model_path)
-    session = {}
-    while True:
-        print('input prompt')
-        p = input()
-        if p:
-            if p == 'q':
-                break;
-            for d in q3.stream_generate(session, p):
-                print(d)
-                """
-                if not d['done']:
-                    print(d['text'], end='', flush=True)
-                else:
-                    x = {k:v for k,v in d.items() if k != 'text'}
-                    print(f'\n{x}\n')
-                """

View File

@@ -21,23 +21,17 @@ def init():
     rf.register('chat_completions', chat_completions)
 async def chat_completions(request, params_kw, *params, **kw):
+    se = ServerEnv()
+    engine = se.engine
     async def gor():
-        se = ServerEnv()
-        engine = se.chat_engine
-        session = await get_session(request)
-        kwargs = {
-        }
-        if params_kw.image_path:
-            kwargs['image_path'] = fs.reapPath(params_kw.image_path)
-        if params_kw.video_path:
-            kwargs['video_path'] = fs.reapPath(params_kw.video_path)
-        if params_kw.audio_path:
-            kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
-        async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
+        async for d in engine.chat_completion_stream(params_kw.messages):
             debug(f'{d=}')
             yield d
-    return await stream_response(request, gor)
+    if params_kw.stream:
+        return await stream_response(request, gor)
+    else:
+        return await engine.chat_completion(params_kw.messages)
 def main():
     parser = argparse.ArgumentParser(prog="Sage")
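
Note: on the wire the endpoint now expects an OpenAI-style body with messages and a stream flag instead of prompt plus media paths. A hedged client example; the host, route, and model name are illustrative and not defined by this commit:

    import json
    import requests

    url = "http://localhost:8080/v1/chat/completions"   # illustrative host and route
    payload = {
        "model": "qwen3",                                # illustrative model name
        "stream": True,
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
    }
    with requests.post(url, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line or line == 'data: [DONE]':
                continue
            chunk = json.loads(line[len('data: '):])
            print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)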

View File

@@ -1,23 +1,3 @@
-[project]
-name="llmengine"
-version = "0.0.1"
-description = "Your project description"
-authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
-readme = "README.md"
-requires-python = ">=3.8"
-license = {text = "MIT"}
-dependencies = [
-    "torch",
-    "transformers",
-    "sentence-transformers>=2.7.0",
-#    "flash_attention_2",
-    "mistral-common",
-    "accelerate"
-]
-[project.optional-dependencies]
-dev = ["pytest", "black", "mypy"]
 [build-system]
 requires = ["setuptools>=61", "wheel"]
 build-backend = "setuptools.build_meta"