diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py
index 268327b..73da050 100644
--- a/llmengine/base_reranker.py
+++ b/llmengine/base_reranker.py
@@ -2,83 +2,86 @@ import torch
 model_pathMap = {
 }
+
+
 def llm_register(model_key, Klass):
-	model_pathMap[model_key] = Klass
+    model_pathMap[model_key] = Klass
+
 
 def get_llm_class(model_path):
-	for k,klass in model_pathMap.items():
-		if len(model_path.split(k)) > 1:
-			return klass
-	print(f'{model_pathMap=}')
-	return None
+    for k, klass in model_pathMap.items():
+        if len(model_path.split(k)) > 1:
+            return klass
+    print(f'{model_pathMap=}')
+    return None
+
 
 class BaseReranker:
-	def __init__(self, model_id, **kw):
-		self.model_id = model_id
-
-	def use_mps_if_prosible(self):
-		if torch.backends.mps.is_available():
-			device = torch.device("mps")
-			self.model = self.model.to(device)
+    def __init__(self, model_id, **kw):
+        self.model_id = model_id
 
-	def process_inputs(self, pairs):
-		inputs = self.tokenizer(
-			pairs, padding=False, truncation='longest_first',
-			return_attention_mask=False, max_length=self.max_length
-		)
-		inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
-		for key in inputs:
-			inputs[key] = inputs[key].to(self.model.device)
-		return inputs
+    def use_mps_if_prosible(self):
+        # despite the name, this now requires CUDA and casts the model to FP16
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            self.model = self.model.to(device, dtype=torch.float16)
+        else:
+            raise Exception("GPU not available, but required for FP16 inference")
 
-	def build_sys_prompt(self, sys_prompt):
-		return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
+    def process_inputs(self, pairs):
+        inputs = self.tokenizer(
+            pairs, padding=False, truncation='longest_first',
+            return_attention_mask=False, max_length=self.max_length
+        )
+        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
+        for key in inputs:
+            inputs[key] = inputs[key].to(self.model.device)
+        return inputs
 
-	def build_user_prompt(self, query, doc, instruct=''):
-		return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'
-
-	def build_assistant_prompt(self):
-		return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+    def build_sys_prompt(self, sys_prompt):
+        return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
 
-	def compute_logits(self, inputs, **kwargs):
-		batch_scores = self.model(**inputs).logits[:, -1, :]
-		# true_vector = batch_scores[:, token_true_id]
-		# false_vector = batch_scores[:, token_false_id]
-		# batch_scores = torch.stack([false_vector, true_vector], dim=1)
-		batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-		scores = batch_scores[:, 1].exp().tolist()
-		return scores
+    def build_user_prompt(self, query, doc, instruct=''):
+        return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'
 
-	def build_pairs(self, query, docs, sys_prompt="", task=""):
-		sys_str = self.build_sys_prompt(sys_prompt)
-		ass_str = self.build_assistant_prompt()
-		pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
-		return pairs
+    def build_assistant_prompt(self):
+        return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
 
-	def rerank(self, query, docs, top_n, sys_prompt="", task=""):
-		pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
-		with torch.no_grad():
-			inputs = self.process_inputs(pairs)
-			scores = self.compute_logits(inputs)
-		data = []
-		for i, s in enumerate(scores):
-			d = {
-				'index':i,
-				'relevance_score': s
-			}
-			data.append(d)
-		data = sorted(data,
-			key=lambda x: x["relevance_score"],
-			reverse=True)
-		if len(data) > top_n:
-			data = data[:top_n]
-		ret = {
-			"data": data,
-			"object": "rerank.result",
-			"model": self.model_name,
-			"usage": {
-				"prompt_tokens": 0,
-				"total_tokens": 0
-			}
-		}
-		return ret
+    def compute_logits(self, inputs, **kwargs):
+        batch_scores = self.model(**inputs).logits[:, -1, :]
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def build_pairs(self, query, docs, sys_prompt="", task=""):
+        sys_str = self.build_sys_prompt(sys_prompt)
+        ass_str = self.build_assistant_prompt()
+        pairs = [sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs]
+        return pairs
+
+    def rerank(self, query, docs, top_n, sys_prompt="", task=""):
+        pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
+        with torch.no_grad():
+            inputs = self.process_inputs(pairs)
+            scores = self.compute_logits(inputs)
+        data = []
+        for i, s in enumerate(scores):
+            d = {
+                'index': i,
+                'relevance_score': s
+            }
+            data.append(d)
+        data = sorted(data,
+                      key=lambda x: x["relevance_score"],
+                      reverse=True)
+        if len(data) > top_n:
+            data = data[:top_n]
+        ret = {
+            "data": data,
+            "object": "rerank.result",
+            "model": self.model_name,
+            "usage": {
+                "prompt_tokens": 0,
+                "total_tokens": 0
+            }
+        }
+        return ret
\ No newline at end of file
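
Note on the registry above: get_llm_class picks the reranker class by substring-matching each registered key against the model path, so the served class is determined entirely by the directory name passed on the command line. A minimal sketch of the lookup, using the model paths that appear in test/reranker/start.sh (the imports are assumed to be enough to trigger the llm_register calls):

    from llmengine.base_reranker import get_llm_class
    import llmengine.bge_reranker    # registers the 'bge-reranker' key
    import llmengine.qwen3_reranker  # registers the 'Qwen3-Reranker' key

    # '/share/models/BAAI/bge-reranker-v2-m3' contains 'bge-reranker' -> BgeReranker
    klass = get_llm_class('/share/models/BAAI/bge-reranker-v2-m3')
    reranker = klass('/share/models/BAAI/bge-reranker-v2-m3')
    # a path matching no key returns None, so callers must handle that case
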
diff --git a/llmengine/bge_reranker.py b/llmengine/bge_reranker.py
index 38486c4..3c1bd0b 100644
--- a/llmengine/bge_reranker.py
+++ b/llmengine/bge_reranker.py
@@ -1,31 +1,36 @@
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from llmengine.base_reranker import BaseReranker, llm_register
+from torch.cuda.amp import autocast
 
 class BgeReranker(BaseReranker):
-	def __init__(self, model_id, max_length=8096):
-		if 'bge-reranker' not in model_id:
-			e = Exception(f'{model_id} is not a bge-reranker')
-			raise e
-		self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-		model = AutoModelForSequenceClassification.from_pretrained(model_id)
-		model.eval()
-		self.model = model
-		self.model_id = model_id
-		self.model_name = model_id.split('/')[-1]
+    def __init__(self, model_id, max_length=8096):
+        if 'bge-reranker' not in model_id:
+            e = Exception(f'{model_id} is not a bge-reranker')
+            raise e
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+        model.eval()
+        if torch.cuda.is_available():
+            model = model.to('cuda')
+        self.model = model
+        self.model_id = model_id
+        self.model_name = model_id.split('/')[-1]
 
-	def build_pairs(self, query, docs, **kw):
-		return [[query, doc] for doc in docs]
+    def build_pairs(self, query, docs, **kw):
+        return [[query, doc] for doc in docs]
 
-	def process_inputs(self, pairs):
-		inputs = self.tokenizer(pairs, padding=True,
-			truncation=True, return_tensors='pt', max_length=512)
-		return inputs
+    def process_inputs(self, pairs):
+        inputs = self.tokenizer(pairs, padding=True,
+                                truncation=True, return_tensors='pt', max_length=512)
+        if torch.cuda.is_available():
+            inputs = {k: v.to('cuda') for k, v in inputs.items()}
+        return inputs
 
-	def compute_logits(self, inputs):
-		scores = self.model(**inputs,
-			return_dict=True).logits.view(-1, ).float()
-		scores = [ s.item() for s in scores ]
-		return scores
+    def compute_logits(self, inputs):
+        with autocast():
+            scores = self.model(**inputs, return_dict=True).logits.view(-1,)
+        scores = [s.item() for s in scores]
+        return scores
 
-llm_register('bge-reranker', BgeReranker)
+llm_register('bge-reranker', BgeReranker)
\ No newline at end of file
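
The new compute_logits wraps the forward pass in torch.cuda.amp.autocast, while __init__ now loads the checkpoint in bfloat16 and moves it to CUDA. Newer PyTorch releases deprecate torch.cuda.amp.autocast in favor of torch.autocast; a sketch of an equivalent compute_logits body that also pins the autocast dtype to the bfloat16 weights (instead of the float16 default) would be:

    import torch

    # equivalent on newer PyTorch; dtype matches the bfloat16 checkpoint loaded in __init__
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits = self.model(**inputs, return_dict=True).logits.view(-1,)
    scores = [s.item() for s in logits]
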
diff --git a/llmengine/qwen3_reranker.py b/llmengine/qwen3_reranker.py
index d72ef9d..b5f6c59 100644
--- a/llmengine/qwen3_reranker.py
+++ b/llmengine/qwen3_reranker.py
@@ -1,16 +1,122 @@
 import torch
-from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from llmengine.base_reranker import BaseReranker, llm_register
 
-class Qwen3Reranker(BaseReranker):
-	def __init__(self, model_id, max_length=8096):
-		if 'Qwen3-Reranker' not in model_id:
-			e = Exception(f'{model_id} is not a Qwen3-Reranker')
-			raise e
-		self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
-		self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()
-		self.model_id = model_id
-		self.model_name = model_id.split('/')[-1]
-		self.max_length = 8192
-llm_register('Qwen3-Reranker', Qwen3Reranker)
+class Qwen3Reranker(BaseReranker):
+    def __init__(self, model_id, max_length=1024):  # set max_length to 1024
+        if 'Qwen3-Reranker' not in model_id:
+            raise Exception(f'{model_id} is not a Qwen3-Reranker')
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
+
+        # load the model in FP16 (GPU)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16
+        ).eval()
+
+        self.model_id = model_id
+        self.model_name = model_id.split('/')[-1]
+        self.max_length = max_length
+
+        # initialize the prefix and suffix tokens
+        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
+
+    def format_instruction(self, instruction, query, doc):
+        if instruction is None:
+            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+        output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+            instruction=instruction, query=query, doc=doc
+        )
+        return output
+
+    def process_inputs(self, pairs, batch_size=8):
+        all_inputs = []
+        for i in range(0, len(pairs), batch_size):
+            batch_pairs = pairs[i:i + batch_size]
+            inputs = self.tokenizer(
+                batch_pairs,
+                padding=False,
+                truncation='longest_first',
+                return_attention_mask=False,
+                max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
+            )
+            for j, ele in enumerate(inputs['input_ids']):
+                inputs['input_ids'][j] = self.prefix_tokens + ele + self.suffix_tokens
+            inputs = self.tokenizer.pad(
+                inputs,
+                padding=True,
+                return_tensors="pt",
+                max_length=self.max_length
+            )
+            for key in inputs:
+                inputs[key] = inputs[key].to(self.model.device)
+            all_inputs.append(inputs)
+            # free cached GPU memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+        return all_inputs
+
+    def compute_logits(self, inputs, **kwargs):
+        batch_scores = self.model(**inputs).logits[:, -1, :]
+        true_vector = batch_scores[:, self.token_true_id]
+        false_vector = batch_scores[:, self.token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def build_pairs(self, query, docs, sys_prompt="", task=""):
+        pairs = [self.format_instruction(task, query, doc) for doc in docs]
+        return pairs
+
+    def rerank(self, query, docs, top_n, sys_prompt="", task="", batch_size=8):
+        if query is None:
+            raise Exception("query is None")
+        if docs is None or not docs:
+            raise Exception("documents is None or empty")
+        if not isinstance(docs, list):
+            docs = [docs]
+        if top_n is None or top_n <= 0:
+            top_n = len(docs)
+
+        pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
+        print(f"Number of documents: {len(docs)}")
+        for i, p in enumerate(pairs):
+            print(f"Pair {i} token length: {len(self.tokenizer.encode(p))}")
+        scores = []
+        with torch.no_grad():
+            input_batches = self.process_inputs(pairs, batch_size=batch_size)
+            for inputs in input_batches:
+                batch_scores = self.compute_logits(inputs)
+                scores.extend(batch_scores)
+                # free cached GPU memory
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        data = [
+            {'index': i, 'relevance_score': s}
+            for i, s in enumerate(scores)
+        ]
+        data = sorted(data, key=lambda x: x["relevance_score"], reverse=True)
+        if len(data) > top_n:
+            data = data[:top_n]
+
+        return {
+            "data": data,
+            "object": "rerank.result",
+            "model": self.model_name,
+            "usage": {
+                "prompt_tokens": sum(len(self.tokenizer.encode(p)) for p in pairs),
+                "total_tokens": sum(len(self.tokenizer.encode(p)) for p in pairs)
+            }
+        }
+
+
+llm_register('Qwen3-Reranker', Qwen3Reranker)
\ No newline at end of file
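
A minimal usage sketch of the rewritten Qwen3Reranker (the model path is the one from the commented-out line in test/reranker/start.sh; the query and documents are illustrative only):

    from llmengine.qwen3_reranker import Qwen3Reranker

    reranker = Qwen3Reranker('/d/ymq/models/Qwen/Qwen3-Reranker-0___6B')
    result = reranker.rerank(
        query='What is the capital of China?',
        docs=[
            'The capital of China is Beijing.',
            'The giant panda is a bear species endemic to China.',
        ],
        top_n=1,
        task='Given a web search query, retrieve relevant passages that answer the query',
        batch_size=8,
    )
    # result['data'] holds at most top_n entries sorted by relevance_score,
    # each of the form {'index': ..., 'relevance_score': ...}
    print(result['data'])
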
diff --git a/test/reranker/start.sh b/test/reranker/start.sh
index 4570dc0..3ad87b5 100755
--- a/test/reranker/start.sh
+++ b/test/reranker/start.sh
@@ -1,4 +1,3 @@
 #!/bin/bash
-# CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.rerank -p 9997 /d/ymq/models/Qwen/Qwen3-Reranker-0___6B
 CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.rerank -p 9997 /share/models/BAAI/bge-reranker-v2-m3
 