import os
import re
import time
import logging

import torch
import yaml
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
except Exception as e:
    print(f"Failed to load configuration file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load configuration file: {str(e)}") from e

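# Expected shape of the config file (illustrative values; only the keys below
# are actually read by this module):
#
# logging:
#   name: rag
#   level: DEBUG
#   file: /share/wangmeihua/rag/logs/rag.log
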
# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# Directory where extracted triples are saved
TRIPLES_OUTPUT_DIR = "/share/wangmeihua/rag/triples"
os.makedirs(TRIPLES_OUTPUT_DIR, exist_ok=True)

# Load the mREBEL model and tokenizer
local_path = "/share/models/Babelscape/mrebel-large"
try:
    tokenizer = AutoTokenizer.from_pretrained(local_path, src_lang="zh_CN", tgt_lang="tp_XX")
    model = AutoModelForSeq2SeqLM.from_pretrained(local_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # ID of the <triplet> marker token, used as the decoder start token below
    triplet_id = tokenizer.convert_tokens_to_ids("<triplet>")
    logger.debug(f"Successfully loaded mREBEL model, tokenizer triplet_id: {triplet_id}")
except Exception as e:
    logger.error(f"Failed to load mREBEL model: {str(e)}")
    raise RuntimeError(f"Failed to load mREBEL model: {str(e)}") from e

# Tuned generation parameters
gen_kwargs = {
    "max_length": 512,
    "min_length": 10,
    "length_penalty": 0.5,
    "num_beams": 3,
    "num_return_sequences": 1,
    "no_repeat_ngram_size": 2,
    "early_stopping": True,
    "decoder_start_token_id": triplet_id,
}

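# Note: decoder_start_token_id makes generation start from the <triplet> token,
# so decoded sequences come out in the linearized format parsed by
# extract_triplets_typed() below.
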
def split_document(text: str, max_chunk_size: int = 150) -> list:
    """Split a document into semantically complete chunks."""
    # Split on Chinese sentence-ending punctuation and newlines, keeping the delimiters
    sentences = re.split(r'(?<=[。!?;\n])', text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_chunk_size:
            current_chunk += sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    return chunks

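# Illustrative example: split_document("句子一。句子二。", max_chunk_size=5)
# keeps each sentence's trailing "。" and returns ["句子一。", "句子二。"].
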
def extract_triplets_typed(text: str) -> list:
    """Parse mREBEL generated text, matching the <triplet> <entity1> <type1> <entity2> <type2> <relation> format."""
    triplets = []
    logger.debug(f"Raw generated text: {text}")

    # Split the text into tokens, treating each <...> tag as a single token
    tokens = []
    in_tag = False
    buffer = ""
    for char in text:
        if char == '<':
            in_tag = True
            if buffer:
                tokens.append(buffer.strip())
                buffer = ""
            buffer += char
        elif char == '>':
            in_tag = False
            buffer += char
            tokens.append(buffer.strip())
            buffer = ""
        else:
            buffer += char
    if buffer:
        tokens.append(buffer.strip())

    # Filter out special tokens
    special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
    tokens = [t for t in tokens if t not in special_tokens and t]

    logger.debug(f"Tokens after filtering: {tokens}")

    # Parse triplets
    i = 0
    while i < len(tokens):
        if tokens[i] == "<triplet>" and i + 5 < len(tokens):
            entity1 = tokens[i + 1]
            type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
            entity2 = tokens[i + 3]
            type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
            relation = tokens[i + 5]

            if entity1 and type1 and entity2 and type2 and relation:
                triplets.append({
                    'head': entity1.strip(),
                    'head_type': type1,
                    'type': relation.strip(),
                    'tail': entity2.strip(),
                    'tail_type': type2
                })
                logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
            i += 6
        else:
            i += 1

    return triplets

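# Illustrative example (hypothetical decoded sequence, not taken from a real run):
# parsing "<s>zh_CN<triplet> 知识图谱 <concept> 语义知识库 <concept> subclass of</s>"
# yields [{'head': '知识图谱', 'head_type': 'concept', 'type': 'subclass of',
#          'tail': '语义知识库', 'tail_type': 'concept'}].
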
def extract_and_save_triplets(text: str, document_id: str, userid: str) -> bool:
    """
    Extract triples from the text and save them to the configured path.

    Args:
        text (str): input text
        document_id (str): document ID
        userid (str): user ID

    Returns:
        bool: whether triple extraction and saving succeeded
    """
    try:
        if not text or not document_id or not userid:
            raise ValueError("text, document_id and userid must not be empty")
        if "_" in document_id or "_" in userid:
            raise ValueError("document_id and userid must not contain underscores")

        start_time = time.time()
        logger.info(f"Starting triple extraction for document {document_id}, userid: {userid}")

        # Split the text into semantic chunks
        text_chunks = split_document(text, max_chunk_size=150)
        logger.debug(f"Split into {len(text_chunks)} text chunks")

        # Process all text chunks
        all_triplets = []
        for i, chunk in enumerate(text_chunks):
            logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")

            # Tokenize
            model_inputs = tokenizer(
                chunk,
                max_length=256,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)

            # Generate
            try:
                generated_tokens = model.generate(
                    model_inputs["input_ids"],
                    attention_mask=model_inputs["attention_mask"],
                    **gen_kwargs,
                )
                # Keep special tokens so the <triplet>/<type> markers survive for parsing
                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
                for idx, sentence in enumerate(decoded_preds):
                    logger.debug(f"Chunk {i + 1} generated text: {sentence}")
                    triplets = extract_triplets_typed(sentence)
                    if triplets:
                        logger.debug(f"Chunk {i + 1} yielded {len(triplets)} triplets")
                        all_triplets.extend(triplets)
            except Exception as e:
                logger.warning(f"Error while processing chunk {i + 1}: {str(e)}")
                continue

        # Deduplicate (case-insensitive on head / relation / tail)
        unique_triplets = []
        seen = set()
        for t in all_triplets:
            identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triplets.append(t)

        # Save the results as tab-separated lines: head, relation, tail, head_type, tail_type
        output_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                for t in unique_triplets:
                    f.write(f"{t['head']}\t{t['type']}\t{t['tail']}\t{t['head_type']}\t{t['tail_type']}\n")
            logger.info(f"Saved {len(unique_triplets)} triplets for document {document_id} to: {output_file}")
        except Exception as e:
            logger.error(f"Failed to save triplets for document {document_id}: {str(e)}")
            return False

        end_time = time.time()
        logger.info(f"Finished triple extraction for document {document_id}, elapsed: {end_time - start_time:.2f} s")
        return True

    except Exception as e:
        logger.error(f"Failed to extract or save triplets: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False

if __name__ == "__main__":
    # Test case (Chinese input, since the tokenizer is configured with src_lang="zh_CN")
    test_text = "知识图谱是一个结构化的语义知识库。深度学习是基于深层神经网络的机器学习子集。"
    document_id = "testdoc123"
    userid = "testuser1"
    result = extract_and_save_triplets(test_text, document_id, userid)
    print(f"Extraction result: {result}")