bugfix
This commit is contained in:
parent 3d2f799eee
commit f84f4f14f9
@@ -3,6 +3,7 @@ import asyncio
 import json
 import torch
 from time import time
+from aiostream import stream
 from transformers import TextIteratorStreamer
 from appPublic.log import debug
 from appPublic.worker import awaitify
@@ -26,18 +27,6 @@ class BaseChatLLM:
 			device = torch.device("mps")
 			self.model = self.model.to(device)
 
-	def get_session_key(self):
-		return self.model_id + ':messages'
-
-	def _get_session_messages(self, session):
-		key = self.get_session_key()
-		messages = session.get(key) or []
-		return messages
-
-	def _set_session_messages(self, session, messages):
-		key = self.get_session_key()
-		session[key] = messages
-
 	def get_streamer(self):
 		return TextIteratorStreamer(
 			tokenizer=self.tokenizer,
@@ -49,7 +38,7 @@ class BaseChatLLM:
 		all_txt = ''
 		t1 = time()
 		i = 0
-		id = f'chatllm-{getID}'
+		id = f'chatllm-{getID()}'
 		for txt in streamer:
 			if txt == '':
 				continue
@@ -60,15 +49,15 @@ class BaseChatLLM:
 			yield {
 				"id":id,
 				"object":"chat.completion.chunk",
-				"created":time(),
+				"created": t1,
 				"model":self.model_id,
 				"choices":[
 					{
 						"index":0,
 						"delta":{
+							"role": "assistant",
 							"content":txt
 						},
-						"logprobs":None,
 						"finish_reason":None
 					}
 				]
@@ -80,7 +69,7 @@ class BaseChatLLM:
 		yield {
 			"id":id,
 			"object":"chat.completion.chunk",
-			"created":time(),
+			"created": t1,
 			"model":self.model_id,
 			"response_time": t2 - t1,
 			"finish_time": t3 - t1,
@@ -91,69 +80,11 @@ class BaseChatLLM:
 					"delta":{
 						"content":""
 					},
-					"logprobs":None,
 					"finish_reason":"stop"
 				}
 			]
 		}
 
-	def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
-		messages = self._get_session_messages(session)
-		if sys_prompt:
-			messages.append(self._build_sys_message(sys_prompt))
-		messages.append(self._build_user_message(prompt, image_path=image_path))
-		# debug(f'{messages=}')
-		all_txt = ''
-		for d in self._gen(messages):
-			if d['choices'][0]['finish_reason'] == 'stop':
-				messages.append(self._build_assistant_message(all_txt))
-			else:
-				all_txt += d['choices'][0]['delta']['content']
-			yield d
-		self._set_session_messages(session, messages)
-
-	async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
-		for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-			await asyncio.sleep(0)
-			yield d
-
-	def generate(self, session, prompt,
-					image_path=None,
-					video_path=None,
-					audio_path=None,
-					sys_prompt=None):
-		for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-			if d['choices'][0]['finish_reason'] == 'stop':
-				return d
-	def stream_generate(self, session, prompt,
-					image_path=None,
-					video_path=None,
-					audio_path=None,
-					sys_prompt=None):
-		for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-			s = f'data: {json.dumps(d)}\n'
-			yield s
-
-	async def async_generate(self, session, prompt,
-					image_path=None,
-					video_path=None,
-					audio_path=None,
-					sys_prompt=None):
-		async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-			await asyncio.sleep(0)
-			if d['choices'][0]['finish_reason'] == 'stop':
-				return d
-
-	async def async_stream_generate(self, session, prompt,
-					image_path=None,
-					video_path=None,
-					audio_path=None,
-					sys_prompt=None):
-		async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
-			s = f'data: {json.dumps(d)}\n'
-			yield s
-		yield 'data: [DONE]'
-
 	def build_kwargs(self, inputs, streamer):
 		generate_kwargs = dict(
 			**inputs,
@@ -165,6 +96,7 @@ class BaseChatLLM:
 		return generate_kwargs
 
 	def _messages2inputs(self, messages):
+		debug(f'{messages=}')
 		return self.processor.apply_chat_template(
 			messages, add_generation_prompt=True,
 			tokenize=True,
@@ -184,63 +116,70 @@ class BaseChatLLM:
 			d['input_tokens'] = input_len
 			yield d
 
-class T2TChatLLM(BaseChatLLM):
-	def _build_assistant_message(self, prompt):
-		return {
-			"role":"assistant",
-			"content":prompt
-		}
-
-	def _build_sys_message(self, prompt):
-		return {
-			"role":"system",
-			"content": prompt
-		}
-
-	def _build_user_message(self, prompt, **kw):
-		return {
-			"role":"user",
-			"content": prompt
-		}
-
-class MMChatLLM(BaseChatLLM):
-	""" multiple modal chat LLM """
-	def _build_assistant_message(self, prompt):
-		return {
-			"role":"assistant",
-			"content":[{"type": "text", "text": prompt}]
-		}
-
-	def _build_sys_message(self, prompt):
-		return {
-			"role":"system",
-			"content":[{"type": "text", "text": prompt}]
-		}
-
-	def _build_user_message(self, prompt, image_path=None,
-				video_path=None, audio_path=None):
-		contents = [
-			{
-				"type":"text", "text": prompt
-			}
-		]
-		if image_path:
-			contents.append({
-				"type": "image",
-				"image": image_path
-			})
-		if video_path:
-			contents.append({
-				"type": "video",
-				"video":video_path
-			})
-		if audio_path:
-			contents.append({
-				"tyoe": "audio",
-				"audio": audio_path
-			})
-		return {
-			"role": "user",
-			"content": contents
-		}
+	async def async_gen(self, messages):
+		async for d in stream.iterate(self._gen(messages)):
+			yield d
+
+	async def chat_completion_stream(self, messages):
+		async for d in self.async_gen(messages):
+			if d['choices'][0]['finish_reason']:
+				d['usage'] = {
+					'prompt_tokens': d['input_tokens'],
+					'completion_tokens': d['output_tokens'],
+					'total_tokens': d['input_tokens'] + d['output_tokens']
+				}
+			s = f'data: {json.dumps(d)}\n'
+			yield s
+		yield 'data: [DONE]\n'
+
+	def reference(self, messages):
+		t1 = time()
+		inputs = self._messages2inputs(messages)
+		input_len = inputs["input_ids"].shape[-1]
+		streamer = self.get_streamer()
+		kwargs = self.build_kwargs(inputs, streamer)
+		thread = threading.Thread(target=self.model.generate,
+					kwargs=kwargs)
+		thread.start()
+		txt = ''
+		i = 0
+		for d in self.output_generator(streamer):
+			if i == 0:
+				i = 1
+				t1 = time()
+			if d['choices'][0]['finish_reason'] != 'stop':
+				txt += d['choices'][0]['delta']['content']
+			else:
+				i_tokens = d['input_tokens']
+				o_tokens = d['output_tokens']
+
+		t2 = time()
+		return {
+			'id': f'chatcmpl-{getID()}',
+			"object":"chat.completion",
+			"created":t1,
+			"model":self.model_id,
+			"response_time": t2 - t1,
+			"finish_time": t3 - t1,
+			"output_token": output_tokens,
+			"choices":[
+				{
+					"index":0,
+					"message":{
+						"role": "assistant",
+						"content": txt
+					},
+					"finish_reason":"stop"
+				}
+			],
+			"usage": {
+				"prompt_tokens": i_tokens,
+				"completion_tokens": o_tokens,
+				"total_tokens": i_tokens + o_tokens
+			}
+		}
+
+	async def chat_completion(self, messages):
+		f = awaitify(self.reference)
+		return await f(messages)

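Note: the new chat_completion_stream frames every chunk as an SSE-style "data: <json>" line and ends with "data: [DONE]". A minimal, self-contained sketch of a consumer for that framing follows; the sample lines are illustrative values in the shape the diff builds, not output captured from the engine.

import json

def iter_delta_text(lines):
	# Parse "data: {...}" lines as produced by chat_completion_stream;
	# stop at the "data: [DONE]" terminator and yield each chunk's delta text.
	for line in lines:
		if not line.startswith('data: '):
			continue
		payload = line[6:].strip()
		if payload == '[DONE]':
			break
		d = json.loads(payload)
		choice = d['choices'][0]
		if choice.get('finish_reason') != 'stop':
			yield choice['delta'].get('content', '')

# Illustrative chunks only (field values are assumptions).
sample = [
	'data: {"id": "chatllm-1", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": null}]}\n',
	'data: {"id": "chatllm-1", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "stop"}]}\n',
	'data: [DONE]\n',
]
print(''.join(iter_delta_text(sample)))  # -> "Hel"
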
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 
 model_pathMap = {
 }
@@ -20,14 +21,19 @@ class BaseEmbedding:
 			device = torch.device("mps")
 			self.model = self.model.to(device)
 
-	def embeddings(self, input):
+	def encode(self, input):
 		es = self.model.encode(input)
+
+	def embeddings(self, input):
+		es = self.encode(input)
 		data = []
 		for i, e in enumerate(es):
+			if isinstance(e, np.ndarray):
+				r = e.tolist()
 			d = {
 				"object": "embedding",
 				"index": i,
-				"embedding": e.tolist()
+				"embedding": e
 			}
 			data.append(d)
 		return {

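Note: the embedding change is about returning OpenAI-style per-item embedding objects and converting numpy vectors to plain lists before serialization. A small sketch of that response shape, independent of any model; the top-level keys other than "data" are an assumption following the OpenAI layout, since the diff only shows the per-item dicts.

import numpy as np

def to_embedding_response(vectors, model_id):
	# Wrap raw vectors in the {"object": "embedding", "index", "embedding"} item shape
	# used by BaseEmbedding.embeddings(); ndarray values are converted with tolist()
	# so the result is JSON-serializable.
	data = []
	for i, e in enumerate(vectors):
		if isinstance(e, np.ndarray):
			e = e.tolist()
		data.append({"object": "embedding", "index": i, "embedding": e})
	return {"object": "list", "data": data, "model": model_id}

print(to_embedding_response([np.array([0.1, 0.2]), [0.3, 0.4]], "demo-model"))
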
@@ -3,6 +3,7 @@ import os
 import sys
 import argparse
 from llmengine.qwen3embedding import *
+from llmengine.bgeembedding import *
 from llmengine.base_embedding import get_llm_class
 
 from appPublic.registerfunction import RegisterFunction

@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import MMChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
 
-class Gemma3LLM(MMChatLLM):
+class Gemma3LLM(BaseChatLLM):
 	def __init__(self, model_id):
 		self.model = Gemma3ForConditionalGeneration.from_pretrained(
 			model_id, device_map="auto"
@@ -24,6 +24,10 @@ class Gemma3LLM(MMChatLLM):
 llm_register("gemma-3", Gemma3LLM)
 
 if __name__ == '__main__':
+	def get_stream_text(chunk):
+		chunk = chunk[6:]
+		d = json.loads(chunk)
+		return d['choices'][0]['delta']['content']
 	gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
 	session = {}
 	while True:
@@ -35,7 +39,7 @@ if __name__ == '__main__':
 		print('input image path')
 		imgpath=input()
 		for d in gemma3.stream_generate(session, p, image_path=imgpath):
-			if not d['done']:
+			if not d['DONE']:
 				print(d['text'], end='', flush=True)
 			else:
 				x = {k:v for k,v in d.items() if k != 'text'}

@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import MMChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
 
 model_id = "google/medgemma-4b-it"
 
-class MedgemmaLLM(MMChatLLM):
+class MedgemmaLLM(BaseChatLLM):
 	def __init__(self, model_id):
 		self.model = AutoModelForImageTextToText.from_pretrained(
 			model_id,

@@ -4,12 +4,12 @@
 from appPublic.worker import awaitify
 from appPublic.log import debug
 from ahserver.serverenv import get_serverenv
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
-from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
+from llmengine.base_chat_llm import BaseChatLLM, llm_register
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-class Qwen3LLM(T2TChatLLM):
+class Qwen3LLM(BaseChatLLM):
 	def __init__(self, model_id):
 		self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 		self.model = AutoModelForCausalLM.from_pretrained(
@@ -17,9 +17,6 @@ class Qwen3LLM(T2TChatLLM):
 			torch_dtype="auto",
 			device_map="auto"
 		)
-		if torch.backends.mps.is_available():
-			device = torch.device("mps")
-			self.model = self.model.to(device)
 		self.model_id = model_id
 
 	def build_kwargs(self, inputs, streamer):
@@ -33,7 +30,7 @@ class Qwen3LLM(T2TChatLLM):
 		return generate_kwargs
 
 	def _messages2inputs(self, messages):
-		debug(f'{messages=}')
+		debug(f'-----------{messages=}-----------')
 		text = self.tokenizer.apply_chat_template(
 			messages,
 			tokenize=False,
@@ -43,26 +40,3 @@ class Qwen3LLM(T2TChatLLM):
 		return self.tokenizer([text], return_tensors="pt").to(self.model.device)
 
 llm_register("Qwen/Qwen3", Qwen3LLM)
-
-if __name__ == '__main__':
-	import sys
-	model_path = sys.argv[1]
-	q3 = Qwen3LLM(model_path)
-	session = {}
-	while True:
-		print('input prompt')
-		p = input()
-		if p:
-			if p == 'q':
-				break;
-			for d in q3.stream_generate(session, p):
-				print(d)
-				"""
-				if not d['done']:
-					print(d['text'], end='', flush=True)
-				else:
-					x = {k:v for k,v in d.items() if k != 'text'}
-					print(f'\n{x}\n')
-				"""
-

@@ -21,23 +21,17 @@ def init():
 	rf.register('chat_completions', chat_completions)
 
 async def chat_completions(request, params_kw, *params, **kw):
-	async def gor():
-		se = ServerEnv()
-		engine = se.chat_engine
-		session = await get_session(request)
-		kwargs = {
-		}
-		if params_kw.image_path:
-			kwargs['image_path'] = fs.reapPath(params_kw.image_path)
-		if params_kw.video_path:
-			kwargs['video_path'] = fs.reapPath(params_kw.video_path)
-		if params_kw.audio_path:
-			kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
-		async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
+	se = ServerEnv()
+	engine = se.engine
+	async def gor():
+		async for d in engine.chat_completion_stream(params_kw.messages):
 			debug(f'{d=}')
 			yield d
-	return await stream_response(request, gor)
+	if params_kw.stream:
+		return await stream_response(request, gor)
+	else:
+		return await engine.chat_completion(params_kw.messages)
 
 def main():
 	parser = argparse.ArgumentParser(prog="Sage")

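Note: the handler now takes an OpenAI-style messages list plus a stream flag from params_kw, streaming through stream_response when stream is set and otherwise awaiting engine.chat_completion(). A hedged sketch of what a client request might look like; the host, port, and route are assumptions and depend on how the 'chat_completions' registered function is mounted in the server config.

import json
import requests  # assumes the requests package is available

# Hypothetical endpoint; replace with the route the service actually exposes.
url = 'http://127.0.0.1:8080/v1/chat/completions'
payload = {
	"stream": True,
	"messages": [
		{"role": "system", "content": "You are a helpful assistant."},
		{"role": "user", "content": "Hello"},
	],
}
with requests.post(url, json=payload, stream=True) as resp:
	for raw in resp.iter_lines(decode_unicode=True):
		# Each streamed line is "data: <json>"; the stream ends with "data: [DONE]".
		if not raw or not raw.startswith('data: '):
			continue
		body = raw[6:]
		if body == '[DONE]':
			break
		chunk = json.loads(body)
		print(chunk['choices'][0]['delta'].get('content', ''), end='', flush=True)
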
@@ -1,23 +1,3 @@
-[project]
-name="llmengine"
-version = "0.0.1"
-description = "Your project description"
-authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
-readme = "README.md"
-requires-python = ">=3.8"
-license = {text = "MIT"}
-dependencies = [
-	"torch",
-	"transformers",
-	"sentence-transformers>=2.7.0",
-#	"flash_attention_2",
-	"mistral-common",
-	"accelerate"
-]
-
-[project.optional-dependencies]
-dev = ["pytest", "black", "mypy"]
-
 [build-system]
 requires = ["setuptools>=61", "wheel"]
 build-backend = "setuptools.build_meta"