rag/rag/extract.py

import os
import torch
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
import yaml
import time
# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")
# Configure logging from the config file (logger name, level, and log file path)
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)
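# The keys read above imply a logging section in milvusconfig.yaml roughly like the
# sketch below (illustrative values only; the real config may contain more sections):
#
#   logging:
#     name: rag
#     level: DEBUG
#     file: /share/wangmeihua/rag/logs/rag.log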
# Directory where extracted triples are saved
TRIPLES_OUTPUT_DIR = "/share/wangmeihua/rag/triples"
os.makedirs(TRIPLES_OUTPUT_DIR, exist_ok=True)
# Load the mREBEL model and tokenizer
local_path = "/share/models/Babelscape/mrebel-large"
try:
    tokenizer = AutoTokenizer.from_pretrained(local_path, src_lang="zh_CN", tgt_lang="tp_XX")
    model = AutoModelForSeq2SeqLM.from_pretrained(local_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    triplet_id = tokenizer.convert_tokens_to_ids("<triplet>")
    logger.debug(f"Loaded mREBEL model, tokenizer triplet_id: {triplet_id}")
except Exception as e:
    logger.error(f"Failed to load mREBEL model: {str(e)}")
    raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
# Tuned generation parameters; decoding starts at the <triplet> marker token
gen_kwargs = {
    "max_length": 512,
    "min_length": 10,
    "length_penalty": 0.5,
    "num_beams": 3,
    "num_return_sequences": 1,
    "no_repeat_ngram_size": 2,
    "early_stopping": True,
    "decoder_start_token_id": triplet_id,
}
def split_document(text: str, max_chunk_size: int = 150) -> list:
    """Split a document into semantically complete chunks."""
    # Split on Chinese sentence-ending punctuation and newlines, keeping the delimiter
    sentences = re.split(r'(?<=[。!?;\n])', text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_chunk_size:
            current_chunk += sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
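# Illustrative usage (not part of the pipeline): with a small limit, each sentence
# becomes its own chunk because 4 + 4 characters exceed max_chunk_size=5.
#
#   >>> split_document("句子一。句子二!", max_chunk_size=5)
#   ['句子一。', '句子二!']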
def extract_triplets_typed(text: str) -> list:
    """Parse mREBEL generated text in the format <triplet> <entity1> <type1> <entity2> <type2> <relation>."""
    triplets = []
    logger.debug(f"Raw generated text: {text}")
    # Split the text into <...> markers and plain-text spans
    tokens = []
    in_tag = False
    buffer = ""
    for char in text:
        if char == '<':
            in_tag = True
            if buffer:
                tokens.append(buffer.strip())
                buffer = ""
            buffer += char
        elif char == '>':
            in_tag = False
            buffer += char
            tokens.append(buffer.strip())
            buffer = ""
        else:
            buffer += char
    if buffer:
        tokens.append(buffer.strip())
    # Filter out special tokens
    special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
    tokens = [t for t in tokens if t not in special_tokens and t]
    logger.debug(f"Tokens after filtering: {tokens}")
    # Parse triplets
    i = 0
    while i < len(tokens):
        if tokens[i] == "<triplet>" and i + 5 < len(tokens):
            entity1 = tokens[i + 1]
            type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
            entity2 = tokens[i + 3]
            type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
            relation = tokens[i + 5]
            if entity1 and type1 and entity2 and type2 and relation:
                triplets.append({
                    'head': entity1.strip(),
                    'head_type': type1,
                    'type': relation.strip(),
                    'tail': entity2.strip(),
                    'tail_type': type2
                })
                logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
            i += 6
        else:
            i += 1
    return triplets
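# Illustrative parse (the decoded string below is a hand-written sketch of the
# <triplet>/<type> layout this parser expects, not guaranteed mREBEL output):
#
#   extract_triplets_typed("<s><triplet> 知识图谱 <concept> 语义知识库 <concept> 实例 </s>")
#   # -> [{'head': '知识图谱', 'head_type': 'concept', 'type': '实例',
#   #      'tail': '语义知识库', 'tail_type': 'concept'}]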
def extract_and_save_triplets(text: str, document_id: str, userid: str) -> bool:
    """
    Extract triplets from the text and save them to the output directory.
    Args:
        text (str): input text
        document_id (str): document ID
        userid (str): user ID
    Returns:
        bool: whether triplet extraction and saving succeeded
    """
    try:
        if not text or not document_id or not userid:
            raise ValueError("text, document_id and userid must not be empty")
        if "_" in document_id or "_" in userid:
            raise ValueError("document_id and userid must not contain underscores")
        start_time = time.time()
        logger.info(f"Start extracting triplets for document {document_id}, userid: {userid}")
        # Split the text into semantic chunks
        text_chunks = split_document(text, max_chunk_size=150)
        logger.debug(f"Split into {len(text_chunks)} chunks")
        # Process every chunk
        all_triplets = []
        for i, chunk in enumerate(text_chunks):
            logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
            # Tokenize
            model_inputs = tokenizer(
                chunk,
                max_length=256,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
            # Generate
            try:
                generated_tokens = model.generate(
                    model_inputs["input_ids"],
                    attention_mask=model_inputs["attention_mask"],
                    **gen_kwargs,
                )
                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
                for idx, sentence in enumerate(decoded_preds):
                    logger.debug(f"Chunk {i + 1} generated text: {sentence}")
                    triplets = extract_triplets_typed(sentence)
                    if triplets:
                        logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
                        all_triplets.extend(triplets)
            except Exception as e:
                logger.warning(f"Error while processing chunk {i + 1}: {str(e)}")
                continue
        # Deduplicate on (head, relation, tail), case-insensitive
        unique_triplets = []
        seen = set()
        for t in all_triplets:
            identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triplets.append(t)
        # Save the results as tab-separated lines: head, relation, tail, head_type, tail_type
        output_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                for t in unique_triplets:
                    f.write(f"{t['head']}\t{t['type']}\t{t['tail']}\t{t['head_type']}\t{t['tail_type']}\n")
            logger.info(f"Document {document_id}: {len(unique_triplets)} triplets saved to: {output_file}")
        except Exception as e:
            logger.error(f"Failed to save triplets for document {document_id}: {str(e)}")
            return False
        end_time = time.time()
        logger.info(f"Triplet extraction for document {document_id} finished in {end_time - start_time:.2f}s")
        return True
    except Exception as e:
        logger.error(f"Triplet extraction or saving failed: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False
if __name__ == "__main__":
    # Test case: "A knowledge graph is a structured semantic knowledge base.
    # Deep learning is a subset of machine learning based on deep neural networks."
    test_text = "知识图谱是一个结构化的语义知识库。深度学习是基于深层神经网络的机器学习子集。"
    document_id = "testdoc123"
    userid = "testuser1"
    result = extract_and_save_triplets(test_text, document_id, userid)
    print(f"Extraction result: {result}")