first commit
This commit is contained in: commit 8cf8b975f7

BIN Milvus/milvus.db Normal file
Binary file not shown.

94 README.md Normal file
@@ -0,0 +1,94 @@
# Knowledge Base Server

This system gives each customer a self-managed knowledge base and provides knowledge retrieval on top of it.

The system is exposed as an API that provides knowledge-service support to registered servers; it is not aimed at end users.

## Dependencies

Depends on [these modules](requirements.txt).

## Installation and Deployment

1. Create a `rag` user.
2. Log in as the `rag` user.
3. Run the following commands, which check the project out under the user's home directory:

```
git clone git@git.kaiyuancloud.cn:yumoqing/rag
cd rag/script
./install.sh
```

## Features

Manages the customer knowledge bases of client systems and provides knowledge queries.

Each customer can create one or more independent knowledge bases to serve different business scenarios.

Data in different knowledge bases is fully isolated; knowledge bases do not interfere with each other.

## HTTP API

### add

Add a document to a knowledge base.

#### path
/api/add
#### method
POST
#### Input
- name: authentication
  value: Bearer ${apikey}
  scope: headers

- name: file_name
  value: path of uploaded file
  scope: data

- name: userid
  value: userid of client system
  scope: data

- name: kdbname
  value: rag kdb name
  scope: data

#### Output

### query

Query a knowledge base.

#### path
/api/query
#### method
POST

#### Input
- name: authentication
  value: Bearer ${apikey}
  scope: headers

- name: prompt
  value: ${prompt}
  scope: data

- name: userid
  value: ${userid}
  scope: data

- name: kdbname
  value: ${kdbname}
  scope: data

#### Output
```
{
  total: number of records returned,
  rows: the returned records
}
Each item in rows has the following fields:
  content: text content
  distances: distance
  source: document path
```
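For illustration, a hypothetical client call against the two endpoints could look like the sketch below. The base URL, apikey and the exact encoding of `file_name` are assumptions (the handler code is not shown in this commit); the 9190 port is the one configured in conf/config.json.

```
# Hypothetical client for /api/add and /api/query; values are placeholders.
import requests

BASE = "http://localhost:9190"                     # port taken from conf/config.json
headers = {"authentication": "Bearer my-apikey"}   # header name/value as documented above

# /api/add: register a document with a knowledge base
resp = requests.post(f"{BASE}/api/add", headers=headers, data={
    "file_name": "/path/of/uploaded/file.pdf",     # path of the uploaded file
    "userid": "user-001",                          # userid of the client system
    "kdbname": "productdocs",                      # rag kdb name
})
print(resp.status_code)

# /api/query: retrieve knowledge for a prompt
resp = requests.post(f"{BASE}/api/query", headers=headers, data={
    "prompt": "How do I deploy the service?",
    "userid": "user-001",
    "kdbname": "productdocs",
})
result = resp.json()   # expected shape: {"total": ..., "rows": [...]}
for row in result.get("rows", []):
    print(row.get("distances"), row.get("source"), row.get("content", "")[:80])
```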
BIN app/.query.py.swp Normal file
Binary file not shown.

BIN app/__pycache__/rag.cpython-310.pyc Normal file
Binary file not shown.
57 app/embed.py Normal file
@@ -0,0 +1,57 @@
import os
from datetime import datetime
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
from langchain_community.document_loaders import UnstructuredExcelLoader
from langchain_community.document_loaders import UnstructuredPowerPointLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug
from appPublic.uniqueID import getID

from get_vector_db import get_vector_db

TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')

# Function to check if the uploaded file has a supported extension
def allowed_file(filename):
    allowed_file_suffixes = ['pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt']
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_file_suffixes

# Function to load the file with a format-specific loader and split it into chunks
def load_and_split_data(file_path):
    if file_path.lower().endswith('.pdf'):
        loader = UnstructuredPDFLoader(file_path=file_path)
    elif file_path.lower().endswith('.docx') or file_path.lower().endswith('.doc'):
        loader = UnstructuredWordDocumentLoader(file_path=file_path)
    elif file_path.lower().endswith('.pptx') or file_path.lower().endswith('.ppt'):
        loader = UnstructuredPowerPointLoader(file_path=file_path)
    elif file_path.lower().endswith('.xlsx') or file_path.lower().endswith('.xls'):
        loader = UnstructuredExcelLoader(file_path=file_path)
    elif file_path.lower().endswith('.csv'):
        loader = CSVLoader(file_path=file_path)
    else:
        loader = TextLoader(file_path=file_path)
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
    chunks = text_splitter.split_documents(data)

    return chunks

# Main function to handle the embedding process
def embed(file_path, userid, kdbname):
    if allowed_file(file_path):
        chunks = load_and_split_data(file_path)
        debug(f'{chunks=}')
        db = get_vector_db(userid, kdbname)
        db.add(
            documents=[c.page_content for c in chunks],
            metadatas=[c.metadata for c in chunks],
            ids=[getID() for c in chunks]
        )
        return True
    return False
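A minimal driver sketch for app/embed.py, assuming it is invoked from the app/ directory so the module-level `from get_vector_db import get_vector_db` import resolves (that module is not included in this commit); the file path, userid and kdbname are placeholders. The `db.add(documents=..., metadatas=..., ids=...)` call matches a Chroma-style collection API, which would fit the `"kdbpath": "$[workdir]$/chroma"` entry in conf/config.json.

```
# Hypothetical driver; the path and names are placeholders.
from embed import embed

if embed('/tmp/manual.pdf', 'user-001', 'productdocs'):
    print('document embedded')
else:
    print('file type not supported')
```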
0 app/init.py Normal file

22 app/ragapp.py Normal file
@@ -0,0 +1,22 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from filemgr.init import load_filemgr
from rbac.init import load_rbac
from appbase.init import load_appbase
from rag.init import load_rag

def get_module_dbname(name):
    return 'sage'

def init():
    load_rag()
    load_appbase()
    load_filemgr()
    env = ServerEnv()
    env.get_module_dbname = get_module_dbname

if __name__ == '__main__':
    webapp(init)
2 build/lib/rag/__init__.py Normal file
@@ -0,0 +1,2 @@
from .version import __version__
138 build/lib/rag/deletefile.py Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from pymilvus import connections, Collection, utility
|
||||||
|
from vector import initialize_milvus_connection
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def delete_document(db_type: str, userid: str, filename: str) -> bool:
|
||||||
|
"""
|
||||||
|
根据 db_type、userid 和 filename 删除用户的指定文件数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
db_type (str): 数据库类型(如 'textdb', 'pptdb')
|
||||||
|
userid (str): 用户 ID
|
||||||
|
filename (str): 文件名(如 'test.docx')
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 删除是否成功
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 参数无效
|
||||||
|
RuntimeError: 数据库操作失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not db_type or "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能为空且不能包含下划线")
|
||||||
|
if not userid or "_" in userid:
|
||||||
|
raise ValueError("userid 不能为空且不能包含下划线")
|
||||||
|
if not filename:
|
||||||
|
raise ValueError("filename 不能为空")
|
||||||
|
if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255:
|
||||||
|
raise ValueError("db_type、userid 或 filename 的长度超出限制")
|
||||||
|
|
||||||
|
# 初始化 Milvus 连接
|
||||||
|
initialize_milvus_connection()
|
||||||
|
logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}")
|
||||||
|
|
||||||
|
# 检查集合是否存在
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 加载集合
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载集合: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 查询匹配的 document_id
|
||||||
|
expr = f"userid == '{userid}' and filename == '{filename}'"
|
||||||
|
logger.debug(f"查询表达式: {expr}")
|
||||||
|
try:
|
||||||
|
results = collection.query(
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["document_id"],
|
||||||
|
limit=1000
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录")
|
||||||
|
return False
|
||||||
|
document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
|
||||||
|
logger.debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询 document_id 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"查询失败: {str(e)}")
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
total_deleted = 0
|
||||||
|
for doc_id in document_ids:
|
||||||
|
try:
|
||||||
|
delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'"
|
||||||
|
logger.debug(f"删除表达式: {delete_expr}")
|
||||||
|
delete_result = collection.delete(delete_expr)
|
||||||
|
deleted_count = delete_result.delete_count
|
||||||
|
total_deleted += deleted_count
|
||||||
|
logger.info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除 document_id={doc_id} 失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if total_deleted == 0:
|
||||||
|
logger.warning(f"没有删除任何记录,userid={userid}, filename={filename}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"总计删除 {total_deleted} 条记录,userid={userid}, filename={filename}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ValueError as ve:
|
||||||
|
logger.error(f"参数验证失败: {str(ve)}")
|
||||||
|
return False
|
||||||
|
except RuntimeError as re:
|
||||||
|
logger.error(f"数据库操作失败: {str(re)}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除文件失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
connections.disconnect("default")
|
||||||
|
logger.debug("已断开 Milvus 连接")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试用例
|
||||||
|
db_type = "textdb"
|
||||||
|
userid = "testuser4"
|
||||||
|
filename = "聚类结果1.xlsx"
|
||||||
|
|
||||||
|
logger.info(f"测试:删除 userid={userid}, filename={filename} 的文件")
|
||||||
|
result = delete_document(db_type, userid, filename)
|
||||||
|
print(f"删除结果: {result}")
|
||||||
178 build/lib/rag/embed.py Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from pymilvus import connections
|
||||||
|
from .vector import get_vector_db
|
||||||
|
from filetxt.loader import fileloader
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def load_and_split_data(file_path: str, userid: str) -> List[Document]:
|
||||||
|
"""
|
||||||
|
加载文件,分片并生成带有元数据的 Document 对象。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): 文件路径
|
||||||
|
userid (str): 用户ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Document]: 分片后的文档列表
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 文件或参数无效
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证文件
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise ValueError(f"文件 {file_path} 不存在")
|
||||||
|
if os.path.getsize(file_path) == 0:
|
||||||
|
raise ValueError(f"文件 {file_path} 为空")
|
||||||
|
logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
|
||||||
|
ext = file_path.rsplit('.', 1)[1].lower()
|
||||||
|
logger.debug(f"文件扩展名: {ext}")
|
||||||
|
|
||||||
|
# 使用 fileloader 加载文件内容
|
||||||
|
logger.debug("开始加载文件")
|
||||||
|
text = fileloader(file_path)
|
||||||
|
if not text or not text.strip():
|
||||||
|
raise ValueError(f"文件 {file_path} 加载为空")
|
||||||
|
|
||||||
|
# 创建单个 Document 对象
|
||||||
|
document = Document(page_content=text)
|
||||||
|
logger.debug(f"加载完成,生成 1 个文档")
|
||||||
|
|
||||||
|
# 分割文本
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=2000,
|
||||||
|
chunk_overlap=200,
|
||||||
|
length_function=len,
|
||||||
|
)
|
||||||
|
chunks = text_splitter.split_documents([document])
|
||||||
|
logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")
|
||||||
|
|
||||||
|
# 为整个文件生成唯一的 document_id
|
||||||
|
document_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 添加元数据,确保包含所有必需字段
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
upload_time = datetime.now().isoformat()
|
||||||
|
documents = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk.metadata.update({
|
||||||
|
'userid': userid,
|
||||||
|
'document_id': document_id,
|
||||||
|
'filename': filename,
|
||||||
|
'file_path': file_path,
|
||||||
|
'upload_time': upload_time,
|
||||||
|
'file_type': ext,
|
||||||
|
'chunk_index': i, # 可选,追踪分片顺序
|
||||||
|
'source': file_path, # 可选,追踪文件来源
|
||||||
|
})
|
||||||
|
# 验证元数据完整性
|
||||||
|
required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
|
||||||
|
if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
|
||||||
|
raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
|
||||||
|
documents.append(chunk)
|
||||||
|
logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")
|
||||||
|
|
||||||
|
logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块,document_id: {document_id}")
|
||||||
|
return documents
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
raise ValueError(f"加载或分割文件失败: {str(e)}")
|
||||||
|
|
||||||
|
def embed(file_path: str, userid: str, db_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
嵌入文件到 Milvus 向量数据库。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): 文件路径
|
||||||
|
userid (str): 用户ID
|
||||||
|
db_type (str): 数据库类型
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 嵌入是否成功
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 参数无效
|
||||||
|
RuntimeError: 数据库操作失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证输入
|
||||||
|
if not userid or not db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能为空")
|
||||||
|
if "_" in userid:
|
||||||
|
raise ValueError("userid 不能包含下划线")
|
||||||
|
if "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能包含下划线")
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise ValueError(f"文件 {file_path} 不存在")
|
||||||
|
|
||||||
|
supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
|
||||||
|
ext = file_path.rsplit('.', 1)[1].lower()
|
||||||
|
if ext not in supported_formats:
|
||||||
|
logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
|
||||||
|
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
||||||
|
|
||||||
|
# 加载并分割文件
|
||||||
|
logger.info(f"开始处理文件 {file_path},userid: {userid},db_type: {db_type}")
|
||||||
|
chunks = load_and_split_data(file_path, userid)
|
||||||
|
if not chunks:
|
||||||
|
logger.error(f"文件 {file_path} 未生成任何文档块")
|
||||||
|
raise ValueError("未生成任何文档块")
|
||||||
|
|
||||||
|
logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
|
||||||
|
logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")
|
||||||
|
|
||||||
|
# 插入到 Milvus
|
||||||
|
db = get_vector_db(userid, db_type, documents=chunks)
|
||||||
|
if not db:
|
||||||
|
logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
|
||||||
|
raise RuntimeError(f"数据库操作失败")
|
||||||
|
|
||||||
|
logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ValueError as ve:
|
||||||
|
logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}")
|
||||||
|
return False
|
||||||
|
except RuntimeError as re:
|
||||||
|
logger.error(f"嵌入文件 {file_path} 失败: {str(re)}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_file = "/share/wangmeihua/rag/data/test.txt"
|
||||||
|
userid = "testuser4"
|
||||||
|
db_type = "textdb"
|
||||||
|
result = embed(test_file, userid, db_type)
|
||||||
|
print(f"嵌入结果: {result}")
|
||||||
15 build/lib/rag/init.py Normal file
@@ -0,0 +1,15 @@
from appPublic.worker import awaitify
from ahserver.serverenv import ServerEnv
from .kdb import add_kdb, add_dir, add_doc, get_all_docs
from .query import search_query
from .embed import embed

def load_rag():
    env = ServerEnv()
    env.add_kdb = add_kdb
    env.query = awaitify(search_query)
    env.embed = awaitify(embed)
    env.add_dir = add_dir
    env.add_doc = add_doc
    env.get_all_docs = get_all_docs
81 build/lib/rag/kdb.py Normal file
@@ -0,0 +1,81 @@
from traceback import format_exc
from appPublic.log import exception  # assumed source of exception(), matching "from appPublic.log import debug" in app/embed.py
from appPublic.uniqueID import getID
from appPublic.timeUtils import curDateString
from appPublic.dictObject import DictObject
from sqlor.dbpools import DBPools
from ahserver.serverenv import get_serverenv
from ahserver.filestorage import FileStorage

async def add_kdb(kdb: dict) -> None:
    """
    Add a knowledge base.
    """
    kdb = DictObject(**kdb)
    kdb.parentid = None
    if kdb.id is None:
        kdb.id = getID()
    kdb.entity_type = '0'
    kdb.create_date = curDateString()
    if kdb.orgid is None:
        e = Exception(f'Can not add none orgid kdb')
        exception(f'{e}\n{format_exc()}')
        raise e

    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        # insert through the sqlor context; the original called a bare C(),
        # while get_all_docs below reads via sor.R()
        await sor.C('kdb', kdb.copy())

async def add_dir(kdb: dict) -> None:
    """
    Add a sub-directory inside a knowledge base.
    """
    kdb = DictObject(**kdb)
    if kdb.parentid is None:
        e = Exception(f'Can not add root folder')
        exception(f'{e}\n{format_exc()}')
        raise e
    if kdb.id is None:
        kdb.id = getID()
    kdb.entity_type = '1'
    kdb.create_date = curDateString()
    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        await sor.C('kdb', kdb.copy())

async def add_doc(doc: dict) -> None:
    """
    Add a document.
    """
    doc = DictObject(**doc)
    if doc.parentid is None:
        e = Exception(f'Can not add root document')
        exception(f'{e}\n{format_exc()}')
        raise e
    if doc.id is None:
        doc.id = getID()
    fs = FileStorage()
    doc.realpath = fs.realPath(doc.webpath)
    doc.create_date = curDateString()
    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        await sor.C('doc', doc.copy())

async def get_all_docs(sor, kdbid):
    """
    Get all documents under kdbid, including those in sub-directories.
    """
    docs = await sor.R('doc', {'parentid': kdbid})
    kdbs = await sor.R('kdb', {'parentid': kdbid})
    for kdb in kdbs:
        docs1 = await get_all_docs(sor, kdb.id)
        docs += docs1
    return docs
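A rough sketch of how add_kdb() could be exercised. It assumes the server environment is already initialized (DBPools configured and get_module_dbname registered, as ragapp.py's init() does); the 'name' field is an assumed column, since the kdb table schema is not part of this commit.

```
# Hypothetical call; assumes DBPools and get_module_dbname are already set up
# (e.g. inside ragapp.py's init()). Field names other than orgid/id are assumed.
import asyncio
from rag.kdb import add_kdb

asyncio.run(add_kdb({
    'orgid': 'org-001',        # required: add_kdb raises when orgid is missing
    'name': 'product-docs',    # assumed column of the kdb table
}))
```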
180 build/lib/rag/query.py Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
from pymilvus import connections, Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from .vector import get_vector_db, initialize_milvus_connection, cleanup_milvus_connection
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def search_query(query: str, userid: str, db_type: str, limit: int = 10, offset: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
根据用户输入的查询文本,在指定 db_type 的知识库中搜索与 userid 相关的文档。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 用户输入的查询文本
|
||||||
|
userid (str): 用户ID,用于过滤
|
||||||
|
db_type (str): 数据库类型(例如 'textdb')
|
||||||
|
limit (int): 返回的最大结果数,默认为 10
|
||||||
|
offset (int): 偏移量,用于分页,默认为 0
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict]: 搜索结果,每个元素为包含 text 和 metadata 的字典
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 参数无效
|
||||||
|
RuntimeError: 模型加载或 Milvus 操作失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not query:
|
||||||
|
raise ValueError("查询文本不能为空")
|
||||||
|
if not userid or not db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 和 db_type 的长度应小于 100")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset 不能为负数")
|
||||||
|
if limit + offset > 16384:
|
||||||
|
raise ValueError("limit + offset 不能超过 16384")
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
model_path = TEXT_EMBEDDING_MODEL
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise ValueError(f"模型路径 {model_path} 不存在")
|
||||||
|
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=model_path,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
# 将查询转换为向量
|
||||||
|
query_vector = embedding.embed_query(query)
|
||||||
|
logger.debug(f"查询向量维度: {len(query_vector)}")
|
||||||
|
|
||||||
|
# 连接到 Milvus
|
||||||
|
initialize_milvus_connection()
|
||||||
|
|
||||||
|
# 检查集合是否存在
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 加载集合
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载集合: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 构造搜索参数
|
||||||
|
search_params = {
|
||||||
|
"metric_type": "COSINE", # 与 vector.py 一致
|
||||||
|
"params": {"nprobe": 10} # 优化搜索性能
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构造过滤表达式
|
||||||
|
expr = f"userid == '{userid}'"
|
||||||
|
logger.debug(f"搜索参数: {search_params}, 表达式: {expr}, limit: {limit}, offset: {offset}")
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
try:
|
||||||
|
results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=limit,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"搜索失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"搜索失败: {str(e)}")
|
||||||
|
|
||||||
|
# 处理搜索结果
|
||||||
|
search_results = []
|
||||||
|
for hits in results:
|
||||||
|
for hit in hits:
|
||||||
|
metadata = {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
search_results.append(result)
|
||||||
|
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
|
||||||
|
|
||||||
|
logger.debug(f"搜索完成,返回 {len(search_results)} 条结果")
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
query = "知识图谱的知识融合是什么?"
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
limit = 2
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = search_query(query, userid, db_type, limit, offset)
|
||||||
|
print(f"搜索结果 ({len(results)} 条):")
|
||||||
|
for idx, result in enumerate(results, 1):
|
||||||
|
print(f"结果 {idx}:")
|
||||||
|
print(f"内容: {result['text'][:200]}...")
|
||||||
|
print(f"距离: {result['distance']}")
|
||||||
|
print(f"元数据: {result['metadata']}")
|
||||||
|
print("-" * 50)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索失败: {str(e)}")
|
||||||
53 build/lib/rag/rag.bak.py Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.configuredServer import ConfiguredServer
|
||||||
|
from ahserver.webapp import webapp
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from query import search_query
|
||||||
|
from embed import embed
|
||||||
|
from deletefile import delete_document
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_dbname(name):
|
||||||
|
"""
|
||||||
|
获取默认数据库名称,优先使用环境变量 RAG_DB_TYPE,未设置时返回 'textdb'。
|
||||||
|
"""
|
||||||
|
return os.getenv('RAG_DB_TYPE', 'textdb')
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
"""
|
||||||
|
初始化 RAG 系统,绑定核心函数到 ServerEnv。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info("初始化 RAG 系统")
|
||||||
|
g = ServerEnv()
|
||||||
|
|
||||||
|
# 绑定核心函数为异步版本
|
||||||
|
g.embed = awaitify(embed)
|
||||||
|
g.query = awaitify(search_query)
|
||||||
|
g.delete = awaitify(delete_document)
|
||||||
|
g.get_module_dbname = get_module_dbname
|
||||||
|
|
||||||
|
logger.info("RAG 系统初始化完成")
|
||||||
|
return g
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logger.info("启动 RAG Web 服务器")
|
||||||
|
webapp(init)
|
||||||
53 build/lib/rag/rag.py Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.configuredServer import ConfiguredServer
|
||||||
|
from ahserver.webapp import webapp
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from query import search_query
|
||||||
|
from embed import embed
|
||||||
|
from deletefile import delete_document
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_dbname(name):
|
||||||
|
"""
|
||||||
|
获取默认数据库名称,优先使用环境变量 RAG_DB_TYPE,未设置时返回 'textdb'。
|
||||||
|
"""
|
||||||
|
return os.getenv('RAG_DB_TYPE', 'textdb')
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
"""
|
||||||
|
初始化 RAG 系统,绑定核心函数到 ServerEnv。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info("初始化 RAG 系统")
|
||||||
|
g = ServerEnv()
|
||||||
|
|
||||||
|
# 绑定核心函数为异步版本
|
||||||
|
g.embed = awaitify(embed)
|
||||||
|
g.query = awaitify(search_query)
|
||||||
|
g.delete = awaitify(delete_document)
|
||||||
|
g.get_module_dbname = get_module_dbname
|
||||||
|
|
||||||
|
logger.info("RAG 系统初始化完成")
|
||||||
|
return g
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logger.info("启动 RAG Web 服务器")
|
||||||
|
webapp(init)
|
||||||
539 build/lib/rag/vector.py Normal file
@ -0,0 +1,539 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
|
||||||
|
from langchain_milvus import Milvus
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def ensure_milvus_directory() -> None:
|
||||||
|
"""确保 Milvus 数据库目录存在"""
|
||||||
|
db_dir = os.path.dirname(MILVUS_DB_PATH)
|
||||||
|
if not os.path.exists(db_dir):
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
logger.debug(f"创建 Milvus 目录: {db_dir}")
|
||||||
|
if not os.access(db_dir, os.W_OK):
|
||||||
|
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
|
||||||
|
|
||||||
|
def initialize_milvus_connection() -> None:
|
||||||
|
"""初始化 Milvus 连接,确保单一连接"""
|
||||||
|
try:
|
||||||
|
if not connections.has_connection("default"):
|
||||||
|
connections.connect("default", uri=MILVUS_DB_PATH)
|
||||||
|
logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}")
|
||||||
|
else:
|
||||||
|
logger.debug("已存在 Milvus 连接,跳过重复连接")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"连接 Milvus 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
|
||||||
|
|
||||||
|
def cleanup_milvus_connection() -> None:
|
||||||
|
"""清理 Milvus 连接,确保资源释放"""
|
||||||
|
try:
|
||||||
|
if connections.has_connection("default"):
|
||||||
|
connections.disconnect("default")
|
||||||
|
logger.debug("已断开 Milvus 连接")
|
||||||
|
time.sleep(3)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
|
||||||
|
|
||||||
|
def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
|
||||||
|
"""
|
||||||
|
初始化或访问 Milvus Lite 向量数据库集合,按 db_type 组织,利用 userid 区分用户,document_id 区分文档,并插入文档。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not userid or not db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 和 db_type 的长度应小于 100")
|
||||||
|
if not documents or not all(isinstance(doc, Document) for doc in documents):
|
||||||
|
raise ValueError("documents 不能为空且必须是 Document 对象列表")
|
||||||
|
required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
|
||||||
|
for doc in documents:
|
||||||
|
if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
|
||||||
|
raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}")
|
||||||
|
if doc.metadata["userid"] != userid:
|
||||||
|
raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致")
|
||||||
|
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
model_path = TEXT_EMBEDDING_MODEL
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise ValueError(f"模型路径 {model_path} 不存在")
|
||||||
|
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=model_path,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载模型失败: {str(e)}")
|
||||||
|
|
||||||
|
# 集合名称
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if len(collection_name) > 255:
|
||||||
|
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
|
||||||
|
logger.debug(f"集合名称: {collection_name}")
|
||||||
|
|
||||||
|
# 定义 schema,包含所有固定字段
|
||||||
|
fields = [
|
||||||
|
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
|
||||||
|
FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
|
||||||
|
FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
|
||||||
|
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||||
|
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
|
||||||
|
FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
|
||||||
|
FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
|
||||||
|
FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
]
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=fields,
|
||||||
|
description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段",
|
||||||
|
auto_id=True,
|
||||||
|
primary_field="pk",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查集合是否存在
|
||||||
|
if utility.has_collection(collection_name):
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
existing_schema = collection.schema
|
||||||
|
expected_fields = {f.name for f in fields}
|
||||||
|
actual_fields = {f.name for f in existing_schema.fields}
|
||||||
|
vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None)
|
||||||
|
|
||||||
|
schema_compatible = False
|
||||||
|
if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR:
|
||||||
|
dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None
|
||||||
|
schema_compatible = dim == 1024
|
||||||
|
logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, "
|
||||||
|
f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, "
|
||||||
|
f"dim={dim if dim is not None else '未定义'}")
|
||||||
|
if not schema_compatible:
|
||||||
|
logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: "
|
||||||
|
f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, "
|
||||||
|
f"vector_field: {vector_field is not None}, "
|
||||||
|
f"dtype: {vector_field.dtype if vector_field else '无'}, "
|
||||||
|
f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}")
|
||||||
|
utility.drop_collection(collection_name)
|
||||||
|
else:
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"集合 {collection_name} 已存在并加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 创建新集合
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name, schema)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="userid",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="document_id",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="filename",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="file_path",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="upload_time",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="file_type",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"成功创建并加载集合: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"创建集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 初始化 Milvus 向量存储
|
||||||
|
try:
|
||||||
|
vector_store = Milvus(
|
||||||
|
embedding_function=embedding,
|
||||||
|
collection_name=collection_name,
|
||||||
|
connection_args={"uri": MILVUS_DB_PATH},
|
||||||
|
drop_old=False,
|
||||||
|
auto_id=True,
|
||||||
|
primary_field="pk",
|
||||||
|
)
|
||||||
|
logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"初始化向量存储失败: {str(e)}")
|
||||||
|
|
||||||
|
# 插入文档
|
||||||
|
try:
|
||||||
|
logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}")
|
||||||
|
for doc in documents:
|
||||||
|
logger.debug(f"插入文档元数据: {doc.metadata}")
|
||||||
|
vector_store.add_documents(documents=documents)
|
||||||
|
logger.debug(f"成功插入 {len(documents)} 个文档")
|
||||||
|
|
||||||
|
# 立即查询验证
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
for result in results:
|
||||||
|
logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, "
|
||||||
|
f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
|
||||||
|
f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入文档失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"插入文档失败: {str(e)}")
|
||||||
|
|
||||||
|
return vector_store
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
获取指定 userid 和 db_type 下的 document_id 与元数据的映射。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not userid or "_" in userid:
|
||||||
|
raise ValueError("userid 不能为空且不能包含下划线")
|
||||||
|
if not db_type or "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能为空且不能包含下划线")
|
||||||
|
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=100
|
||||||
|
)
|
||||||
|
mapping = {}
|
||||||
|
for result in results:
|
||||||
|
doc_id = result.get("document_id")
|
||||||
|
if doc_id:
|
||||||
|
mapping[doc_id] = {
|
||||||
|
"userid": result.get("userid", ""),
|
||||||
|
"filename": result.get("filename", ""),
|
||||||
|
"file_path": result.get("file_path", ""),
|
||||||
|
"upload_time": result.get("upload_time", ""),
|
||||||
|
"file_type": result.get("file_type", "")
|
||||||
|
}
|
||||||
|
logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")
|
||||||
|
|
||||||
|
logger.debug(f"找到 {len(mapping)} 个文档的映射")
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文档映射失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"获取文档映射失败: {str(e)}")
|
||||||
|
|
||||||
|
def list_user_collections() -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
列出所有数据库类型(db_type)及其包含的用户(userid)与对应的文档(document_id)映射。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collections = utility.list_collections()
|
||||||
|
|
||||||
|
db_types_with_data = {}
|
||||||
|
for col in collections:
|
||||||
|
if col.startswith("ragdb_"):
|
||||||
|
db_type = col[len("ragdb_"):]
|
||||||
|
logger.debug(f"处理集合: {col} (db_type: {db_type})")
|
||||||
|
|
||||||
|
collection = Collection(col)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
batch_size = 1000
|
||||||
|
offset = 0
|
||||||
|
user_document_map = {}
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
results = collection.query(
|
||||||
|
expr="",
|
||||||
|
output_fields=["userid", "document_id"],
|
||||||
|
limit=batch_size,
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
break
|
||||||
|
for result in results:
|
||||||
|
userid = result.get("userid")
|
||||||
|
doc_id = result.get("document_id")
|
||||||
|
if userid and doc_id:
|
||||||
|
if userid not in user_document_map:
|
||||||
|
user_document_map[userid] = set()
|
||||||
|
user_document_map[userid].add(doc_id)
|
||||||
|
offset += batch_size
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 转换为列表以便序列化
|
||||||
|
user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
|
||||||
|
logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}")
|
||||||
|
|
||||||
|
db_types_with_data[db_type] = {
|
||||||
|
"userids": user_document_map
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"可用 db_types 和数据: {db_types_with_data}")
|
||||||
|
return db_types_with_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出集合和用户失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def view_collection_details(userid: str) -> None:
|
||||||
|
"""
|
||||||
|
查看特定 userid 在所有集合中的内容和容量,包含 document_id 和元数据。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not userid or "_" in userid:
|
||||||
|
raise ValueError("userid 不能为空且不能包含下划线")
|
||||||
|
|
||||||
|
logger.debug(f"正在查看 userid {userid} 的集合")
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collections = utility.list_collections()
|
||||||
|
db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]
|
||||||
|
|
||||||
|
if not db_types:
|
||||||
|
logger.debug(f"未找到任何集合")
|
||||||
|
return
|
||||||
|
|
||||||
|
for db_type in db_types:
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
continue
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=10000)
|
||||||
|
num_entities = len(all_pks)
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["userid","text", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}")
|
||||||
|
|
||||||
|
if num_entities == 0:
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:")
|
||||||
|
for idx, doc in enumerate(results, 1):
|
||||||
|
metadata = {
|
||||||
|
"userid": doc.get("userid", ""),
|
||||||
|
"filename": doc.get("filename", ""),
|
||||||
|
"file_path": doc.get("file_path", ""),
|
||||||
|
"upload_time": doc.get("upload_time", ""),
|
||||||
|
"file_type": doc.get("file_type", "")
|
||||||
|
}
|
||||||
|
logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
查看指定 db_type 中的向量数据,可选按 userid 和 document_id 过滤,包含完整元数据和向量。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not db_type or "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能为空且不能包含下划线")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if userid and "_" in userid:
|
||||||
|
raise ValueError("userid 不能包含下划线")
|
||||||
|
if document_id and "_" in document_id:
|
||||||
|
raise ValueError("document_id 不能包含下划线")
|
||||||
|
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载集合: {collection_name}")
|
||||||
|
|
||||||
|
expr = []
|
||||||
|
if userid:
|
||||||
|
expr.append(f"userid == '{userid}'")
|
||||||
|
if document_id:
|
||||||
|
expr.append(f"document_id == '{document_id}'")
|
||||||
|
expr = " && ".join(expr) if expr else ""
|
||||||
|
|
||||||
|
results = collection.query(
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_data = {}
|
||||||
|
for doc in results:
|
||||||
|
pk = doc.get("pk", str(uuid.uuid4()))
|
||||||
|
text = doc.get("text", "")
|
||||||
|
doc_id = doc.get("document_id", "")
|
||||||
|
vector = doc.get("vector", [])
|
||||||
|
metadata = {
|
||||||
|
"filename": doc.get("filename", ""),
|
||||||
|
"file_path": doc.get("file_path", ""),
|
||||||
|
"upload_time": doc.get("upload_time", ""),
|
||||||
|
"file_type": doc.get("file_type", "")
|
||||||
|
}
|
||||||
|
vector_data[pk] = {
|
||||||
|
"text": text,
|
||||||
|
"vector": vector,
|
||||||
|
"document_id": doc_id,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")
|
||||||
|
|
||||||
|
logger.debug(f"共找到 {len(vector_data)} 条向量数据")
|
||||||
|
return vector_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查看向量数据失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def main():
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
|
||||||
|
# logger.info("\n测试 1:带文档初始化")
|
||||||
|
# documents = [
|
||||||
|
# Document(
|
||||||
|
# page_content="深度学习是基于深层神经网络的机器学习子集。",
|
||||||
|
# metadata={
|
||||||
|
# "userid": userid,
|
||||||
|
# "document_id": str(uuid.uuid4()),
|
||||||
|
# "filename": "test_doc1.txt",
|
||||||
|
# "file_path": "/path/to/test_doc1.txt",
|
||||||
|
# "upload_time": datetime.now().isoformat(),
|
||||||
|
# "file_type": "txt"
|
||||||
|
# }
|
||||||
|
# ),
|
||||||
|
# Document(
|
||||||
|
# page_content="知识图谱是一个结构化的语义知识库。",
|
||||||
|
# metadata={
|
||||||
|
# "userid": userid,
|
||||||
|
# "document_id": str(uuid.uuid4()),
|
||||||
|
# "filename": "test_doc2.txt",
|
||||||
|
# "file_path": "/path/to/test_doc2.txt",
|
||||||
|
# "upload_time": datetime.now().isoformat(),
|
||||||
|
# "file_type": "txt"
|
||||||
|
# }
|
||||||
|
# ),
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# vector_store = get_vector_db(userid, db_type, documents=documents)
|
||||||
|
# logger.info(f"集合: ragdb_{db_type}")
|
||||||
|
# logger.info(f"成功为 userid {userid} 在 {db_type} 中插入文档")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info("\n测试 2:列出所有 db_types 和文档映射")
|
||||||
|
try:
|
||||||
|
db_types = list_user_collections()
|
||||||
|
logger.info(f"可用 db_types 和文档: {db_types}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"\n测试 3:查看 userid {userid} 的所有集合")
|
||||||
|
try:
|
||||||
|
view_collection_details(userid)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
# logger.info(f"\n测试 4:查看向量数据")
|
||||||
|
# try:
|
||||||
|
# vector_data = view_vector_data(db_type, userid=userid)
|
||||||
|
# logger.info(f"向量数据: {vector_data}")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"\n测试 5:获取 userid {userid} 在{db_type}数据库的文档映射")
|
||||||
|
try:
|
||||||
|
mapping = get_document_mapping(userid, db_type)
|
||||||
|
logger.info(f"文档映射: {mapping}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1 build/lib/rag/version.py Normal file
@@ -0,0 +1 @@
__version__ = '0.0.1'

BIN conf/Milvus/milvus.db Normal file
Binary file not shown.

60 conf/config.json Executable file
@@ -0,0 +1,60 @@
{
	"password_key":"!@#$%^&*(*&^%$QWERTYUIqwertyui234567",
	"logger":{
		"name":"rag",
		"levelname":"clientinfo",
		"logfile":"$[workdir]$/logs/rag.log"
	},
	"kdbpath":"$[workdir]$/chroma",
	"filesroot":"$[workdir]$/files",
	"website":{
		"paths":[
			["$[workdir]$/wwwroot",""]
		],
		"client_max_size":20000,
		"host":"0.0.0.0",
		"port":9190,
		"coding":"utf-8",
		"indexes":[
			"index.html",
			"index.tmpl",
			"index.ui",
			"index.dspy",
			"index.md"
		],
		"startswiths":[
			{
				"leading":"/idfile",
				"registerfunction":"idFileDownload"
			}
		],
		"processors":[
			[".ws","ws"],
			[".xterm","xterm"],
			[".proxy","proxy"],
			[".llm", "llm"],
			[".llms", "llms"],
			[".llma", "llma"],
			[".xlsxds","xlsxds"],
			[".sqlds","sqlds"],
			[".tmpl.js","tmpl"],
			[".tmpl.css","tmpl"],
			[".html.tmpl","tmpl"],
			[".bcrud", "bricks_crud"],
			[".tmpl","tmpl"],
			[".app","app"],
			[".bui","bui"],
			[".ui","bui"],
			[".dspy","dspy"],
			[".md","md"]
		],
		"session_max_time":3000,
		"session_issue_time":2500
	},
	"langMapping":{
		"zh-Hans-CN":"zh-cn",
		"zh-CN":"zh-cn",
		"en-us":"en",
		"en-US":"en"
	}
}
17766
conf/logs/milvus.log
Normal file
17766
conf/logs/milvus.log
Normal file
File diff suppressed because it is too large
Load Diff
8
conf/milvusconfig.yaml
Normal file
8
conf/milvusconfig.yaml
Normal file
@ -0,0 +1,8 @@
database:
  milvus_db_path: /share/wangmeihua/rag/conf/Milvus/milvus.db
models:
  text_embedding_model: /share/models/BAAI/bge-m3
logging:
  name: rag
  level: DEBUG
  file: /share/wangmeihua/rag/conf/logs/milvus.log
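The Python modules added in this commit (for example rag/embed.py and rag/deletefile.py) all read this YAML file through the CONFIG_PATH environment variable. A minimal sketch of that loading pattern, mirroring their code and assuming the hard-coded /share/wangmeihua paths are adjusted for the target machine:

```
import os
import yaml

# Same pattern as rag/embed.py and rag/deletefile.py: CONFIG_PATH overrides the default location.
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

milvus_db_path = config['database']['milvus_db_path']        # Milvus Lite database file
embedding_model = config['models']['text_embedding_model']   # local BAAI/bge-m3 checkpoint
log_file = config['logging']['file']                         # shared log file for all modules
```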
BIN
data/jishu.pdf
Normal file
BIN
data/jishu.pdf
Normal file
Binary file not shown.
599
data/kg_introduction.txt
Normal file
599
data/kg_introduction.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
data/qianru.pdf
Normal file
BIN
data/qianru.pdf
Normal file
Binary file not shown.
BIN
data/test.docx
Normal file
BIN
data/test.docx
Normal file
Binary file not shown.
1
data/test.txt
Normal file
1
data/test.txt
Normal file
@ -0,0 +1 @@
开元云(北京)科技有限公司,是一家注册于2020年的高科技企业,在上海、南京、深圳、济南等地设有分支机构,创始团队核心成员来自一流的云计算公司及电信运营商,拥有云计算、超算、智算和网络运营专业经验,在企业市场均拥有超过十年以上行业经验,服务客户超过2万家。公司以自主研发的业务操作支撑系统(KBoss)为底座,打造开放算力应用服务平台(open-computing),将云计算、算力资源和算力应用进行整合,为高校、科研、大模型、AI等政企客户提供专业算力云服务,形成“云+网+算+应用”的一体化解决方案。在2021年,我们荣幸地成为阿里云计算的合作伙伴,致力于提供算力应用、算力网络、算网一体的产品和服务,同时为芯片、教育科研等企业提供优质的算力服务。2022年,我们与国家超级计算济南中心以及中信网络有限公司签署了战略合作协议,并成功推出了“Kboss”算网平台。在2023年,我们的平台进一步发展,成功引入火山引擎、百度智能云。目前,我们已成为阿里云、江苏未来网络集团的战略合作伙伴。同时,我们深耕“算力+教育”赛道,持续推进高校算力平台项目,积极建设学校算力网络节点,目前已经成功开拓了27所高校。公司提供新一代算力云应用服务模式,通过自主研发的开元算力云应用服务平台,整合算力资源和算法应用,利用创新算力调度化和确定性网络技术,针对现代社会对智能化和数字化需求,形成包括算力云服务、算力网络和算力应用的全场景解决方案。旨在为政府和企业提供"技术+资源+场景+运营”的产业互联网算力云应用服务平台,实现以算力云服务推动数字经济的发展。开元云科技自成立以来得到了包括工信部、教育部、全国高校学会、国家超算中心以及南京未来网络研究院等政府机构、科研机构的大力支持,合作领域包括“东数西算、大科学计算、存算分离、芯算一体及国产工业软件SaaS化”,覆盖人工智能、芯片仿真、生物制药、工业仿真、材料研发、精尖制造、海洋勘探以及气象监测等高科技领域。
BIN
data/zongshu.pdf
Normal file
BIN
data/zongshu.pdf
Normal file
Binary file not shown.
BIN
data/提示学习-王美华.pptx
Normal file
BIN
data/提示学习-王美华.pptx
Normal file
Binary file not shown.
BIN
data/知识图谱构建技术综述(刘峤).pdf
Normal file
BIN
data/知识图谱构建技术综述(刘峤).pdf
Normal file
Binary file not shown.
BIN
data/聚类结果1.xlsx
Normal file
BIN
data/聚类结果1.xlsx
Normal file
Binary file not shown.
BIN
files/0/50/75/76/开题汇报2.pptx
Normal file
BIN
files/0/50/75/76/开题汇报2.pptx
Normal file
Binary file not shown.
BIN
files/102/126/176/41/开题汇报2.pptx
Normal file
BIN
files/102/126/176/41/开题汇报2.pptx
Normal file
Binary file not shown.
BIN
files/115/113/120/34/开题汇报2.pptx
Normal file
BIN
files/115/113/120/34/开题汇报2.pptx
Normal file
Binary file not shown.
BIN
files/23/46/0/52/开题汇报2.pptx
Normal file
BIN
files/23/46/0/52/开题汇报2.pptx
Normal file
Binary file not shown.
BIN
files/42/22/94/46/核心思想.docx
Normal file
BIN
files/42/22/94/46/核心思想.docx
Normal file
Binary file not shown.
BIN
files/8/61/122/40/开题汇报2.pptx
Normal file
BIN
files/8/61/122/40/开题汇报2.pptx
Normal file
Binary file not shown.
BIN
files/87/97/65/20/开题汇报2.pptx
Normal file
BIN
files/87/97/65/20/开题汇报2.pptx
Normal file
Binary file not shown.
21
json/kdb.json
Normal file
21
json/kdb.json
Normal file
@ -0,0 +1,21 @@
{
    "tblname": "permission",
    "uitype":"tree",
    "title":"知识库",
    "params":{
        "idField":"id",
        "textField":"name",
        "sortby":"name",
        "editable":true,
        "browserfields":{
            "alters":{}
        },
        "edit_exclouded_fields":[],
        "parentField":"parentid",
        "toolbar":{
        },
        "binds":[
        ]
    }

}
32781
logs/milvus.log
Normal file
32781
logs/milvus.log
Normal file
File diff suppressed because it is too large
Load Diff
14191
logs/rag.log
Normal file
14191
logs/rag.log
Normal file
File diff suppressed because it is too large
Load Diff
0
logs/stderr.log
Normal file
0
logs/stderr.log
Normal file
BIN
models/doc.xlsx
Normal file
BIN
models/doc.xlsx
Normal file
Binary file not shown.
BIN
models/kdb.xlsx
Normal file
BIN
models/kdb.xlsx
Normal file
Binary file not shown.
4
pyproject.toml
Normal file
4
pyproject.toml
Normal file
@ -0,0 +1,4 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

129
rag.egg-info/PKG-INFO
Normal file
129
rag.egg-info/PKG-INFO
Normal file
@ -0,0 +1,129 @@
Metadata-Version: 2.4
Name: rag
Version: 0.0.1
Summary: rag
Home-page: https://github.com/yumoqing/rag
Author: yumoqing
Author-email: yumoqing@gmail.com
Platform: any
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Description-Content-Type: text/markdown
Requires-Dist: chromadb
Requires-Dist: langchain
Requires-Dist: langchain_community
Requires-Dist: unstructured
Requires-Dist: langchain-text-splitters
Requires-Dist: unstructured[all-docs]
Requires-Dist: langchain_milvus
Requires-Dist: langchain_huggingface
Requires-Dist: transformers
Requires-Dist: openai
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: pymilvus
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: platform
Dynamic: requires-dist
Dynamic: summary

# 知识库服务器
|
||||||
|
本系统为不同的客户提供自我管理的知识库,并在知识库基础上提供知识检索
|
||||||
|
|
||||||
|
本系统提供API形式,为注册的服务器提供知识服支持,不面向最终客户
|
||||||
|
|
||||||
|
## 依赖
|
||||||
|
依赖[这些模块](requirements.txt)
|
||||||
|
|
||||||
|
## 安装部署
|
||||||
|
1. 创建rag用户
|
||||||
|
2. 登录rag用户
|
||||||
|
3. 执行以下命令
|
||||||
|
```
|
||||||
|
git clone git@git.kaiyuancloud.cn:yumoqing/rag
|
||||||
|
cd rag/script
|
||||||
|
./install.sh
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
将项目在用户根目录checkout
|
||||||
|
3.
|
||||||
|
## 功能
|
||||||
|
管理client系统的客户知识库,并提供知识查询
|
||||||
|
|
||||||
|
每个客户可以创建一到多个独立的知识库,为不同的业务场景提供知识库知识
|
||||||
|
|
||||||
|
知识库之间数据相互独立,互不干扰。
|
||||||
|
|
||||||
|
## http API
|
||||||
|
|
||||||
|
### add
|
||||||
|
增加知识库文档
|
||||||
|
|
||||||
|
#### path
|
||||||
|
/api/add
|
||||||
|
#### method
|
||||||
|
POST
|
||||||
|
#### 输入
|
||||||
|
name: authentication
|
||||||
|
value: Bears ${apikey}
|
||||||
|
score: headers
|
||||||
|
|
||||||
|
name: file_name
|
||||||
|
value: path of uploaded file
|
||||||
|
score: data
|
||||||
|
|
||||||
|
name: userid
|
||||||
|
value: userid of client system
|
||||||
|
score: data
|
||||||
|
|
||||||
|
name: kdbname
|
||||||
|
value: rag kdb name
|
||||||
|
score: data
|
||||||
|
|
||||||
|
#### 输出
|
||||||
|
|
||||||
|
### query
|
||||||
|
查询知识库
|
||||||
|
|
||||||
|
#### path
|
||||||
|
/api/query
|
||||||
|
#### method
|
||||||
|
POST
|
||||||
|
|
||||||
|
#### 输入
|
||||||
|
name: authentication
|
||||||
|
value: Bears ${apikey}
|
||||||
|
score: headers
|
||||||
|
|
||||||
|
name: prompt
|
||||||
|
value: ${prompt}
|
||||||
|
score: data
|
||||||
|
|
||||||
|
name: userid
|
||||||
|
value: ${userid}
|
||||||
|
score: data
|
||||||
|
|
||||||
|
name: kdbname
|
||||||
|
value: ${kdbname}
|
||||||
|
score: data
|
||||||
|
|
||||||
|
|
||||||
|
#### 输出
|
||||||
|
```
|
||||||
|
{
|
||||||
|
total:返回记录条数,
|
||||||
|
rows:返回记录内容
|
||||||
|
}
|
||||||
|
rows有以下属性
|
||||||
|
content:文本内容
|
||||||
|
distances:距离
|
||||||
|
source:文档path
|
||||||
|
```
|
||||||
|
|
||||||
16
rag.egg-info/SOURCES.txt
Normal file
16
rag.egg-info/SOURCES.txt
Normal file
@ -0,0 +1,16 @@
README.md
setup.py
rag/__init__.py
rag/deletefile.py
rag/embed.py
rag/init.py
rag/kdb.py
rag/query.py
rag/rag.bak.py
rag/vector.py
rag/version.py
rag.egg-info/PKG-INFO
rag.egg-info/SOURCES.txt
rag.egg-info/dependency_links.txt
rag.egg-info/requires.txt
rag.egg-info/top_level.txt
1
rag.egg-info/dependency_links.txt
Normal file
1
rag.egg-info/dependency_links.txt
Normal file
@ -0,0 +1 @@
13
rag.egg-info/requires.txt
Normal file
13
rag.egg-info/requires.txt
Normal file
@ -0,0 +1,13 @@
chromadb
langchain
langchain_community
unstructured
langchain-text-splitters
unstructured[all-docs]
langchain_milvus
langchain_huggingface
transformers
openai
torch
torchvision
pymilvus
1
rag.egg-info/top_level.txt
Normal file
1
rag.egg-info/top_level.txt
Normal file
@ -0,0 +1 @@
rag
2
rag/__init__.py
Normal file
2
rag/__init__.py
Normal file
@ -0,0 +1,2 @@
from .version import __version__

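As a quick sanity check after running script/install.sh, the package only needs to be importable; the version re-exported here comes from rag/version.py. A hypothetical interactive check, not part of the commit:

```
import rag
print(rag.__version__)  # prints 0.0.1 for this commit
```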
BIN
rag/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
rag/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/deletefile.cpython-310.pyc
Normal file
BIN
rag/__pycache__/deletefile.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/embed.cpython-310.pyc
Normal file
BIN
rag/__pycache__/embed.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/extract.cpython-310.pyc
Normal file
BIN
rag/__pycache__/extract.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/kgc.cpython-310.pyc
Normal file
BIN
rag/__pycache__/kgc.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/query.cpython-310.pyc
Normal file
BIN
rag/__pycache__/query.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/rag.cpython-310.pyc
Normal file
BIN
rag/__pycache__/rag.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/rerank.cpython-310.pyc
Normal file
BIN
rag/__pycache__/rerank.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/searchquery.cpython-310.pyc
Normal file
BIN
rag/__pycache__/searchquery.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/vector.cpython-310.pyc
Normal file
BIN
rag/__pycache__/vector.cpython-310.pyc
Normal file
Binary file not shown.
BIN
rag/__pycache__/version.cpython-310.pyc
Normal file
BIN
rag/__pycache__/version.cpython-310.pyc
Normal file
Binary file not shown.
290
rag/allfusedsearch.py
Normal file
290
rag/allfusedsearch.py
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pymilvus import Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from vector import initialize_milvus_connection
|
||||||
|
from searchquery import extract_entities, match_triplets
|
||||||
|
from rerank import rerank_results
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.propagate = False
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(config['logging']['file'], 'a', encoding='utf-8') as f:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}")
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=TEXT_EMBEDDING_MODEL,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
# 缓存三元组
|
||||||
|
TRIPLET_CACHE = {}
|
||||||
|
|
||||||
|
|
||||||
|
def load_triplets_to_cache(userid: str, document_id: str) -> List[Dict]:
|
||||||
|
"""加载三元组到缓存"""
|
||||||
|
cache_key = f"{document_id}_{userid}"
|
||||||
|
if cache_key in TRIPLET_CACHE:
|
||||||
|
logger.debug(f"从缓存加载三元组: {cache_key}")
|
||||||
|
return TRIPLET_CACHE[cache_key]
|
||||||
|
|
||||||
|
triplet_file = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
|
||||||
|
triplets = []
|
||||||
|
try:
|
||||||
|
with open(triplet_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split('\t')
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
head, type_, tail = parts[:3]
|
||||||
|
triplets.append({'head': head, 'type': type_, 'tail': tail})
|
||||||
|
TRIPLET_CACHE[cache_key] = triplets
|
||||||
|
logger.debug(f"加载三元组文件: {triplet_file}, 数量: {len(triplets)}")
|
||||||
|
return triplets
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载三元组失败: {triplet_file}, 错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def fused_search(
|
||||||
|
query: str,
|
||||||
|
userid: str,
|
||||||
|
db_type: str,
|
||||||
|
file_paths: List[str],
|
||||||
|
limit: int = 5,
|
||||||
|
offset: int = 0,
|
||||||
|
use_rerank: bool = True
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
融合 RAG 和三元组召回文本块:
|
||||||
|
- 收集所有输入文件的三元组,拼接为融合文本,向量化后在所有文件中搜索。
|
||||||
|
- 结果去重并按 rerank_score 或 distance 排序,重排序使用融合文本。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询文本
|
||||||
|
userid (str): 用户 ID
|
||||||
|
db_type (str): 数据库类型 (e.g., 'textdb')
|
||||||
|
file_paths (List[str]): 文件路径列表
|
||||||
|
limit (int): 返回结果数量
|
||||||
|
offset (int): 偏移量
|
||||||
|
use_rerank (bool): 是否使用重排序
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict[str, Any]]: 召回结果,包含 text、distance、source、metadata、rerank_score
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 参数验证
|
||||||
|
if not query or not userid or not db_type or not file_paths:
|
||||||
|
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
|
||||||
|
# 初始化 Milvus 连接
|
||||||
|
connections = initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载 Milvus 集合: {collection_name}")
|
||||||
|
|
||||||
|
# 提取实体
|
||||||
|
entity_start = time.time()
|
||||||
|
query_entities = extract_entities(query)
|
||||||
|
logger.debug(f"提取实体: {query_entities}, 耗时: {time.time() - entity_start:.3f}s")
|
||||||
|
|
||||||
|
# 收集所有文件的 document_id 和三元组
|
||||||
|
doc_id_map = {}
|
||||||
|
filenames = []
|
||||||
|
all_triplets = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
filenames.append(filename)
|
||||||
|
logger.debug(f"处理文件: {filename}")
|
||||||
|
|
||||||
|
# 获取 document_id
|
||||||
|
results_query = collection.query(
|
||||||
|
expr=f"userid == '{userid}' and filename == '{filename}'",
|
||||||
|
output_fields=["document_id"],
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
if not results_query:
|
||||||
|
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
|
||||||
|
continue
|
||||||
|
document_id = results_query[0]["document_id"]
|
||||||
|
doc_id_map[filename] = document_id
|
||||||
|
load_triplets_to_cache(userid, document_id)
|
||||||
|
|
||||||
|
# 获取匹配的三元组
|
||||||
|
triplet_start = time.time()
|
||||||
|
matched_triplets = match_triplets(query, query_entities, userid, document_id)
|
||||||
|
logger.debug(
|
||||||
|
f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条, 耗时: {time.time() - triplet_start:.3f}s")
|
||||||
|
all_triplets.extend(matched_triplets)
|
||||||
|
|
||||||
|
if not doc_id_map:
|
||||||
|
logger.warning("未找到任何有效文档")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 拼接融合文本
|
||||||
|
triplet_texts = []
|
||||||
|
for triplet in all_triplets:
|
||||||
|
head = triplet['head']
|
||||||
|
type_ = triplet['type']
|
||||||
|
tail = triplet['tail']
|
||||||
|
if not head or not type_ or not tail:
|
||||||
|
logger.debug(f"无效三元组: {triplet}")
|
||||||
|
continue
|
||||||
|
triplet_texts.append(f"{head} {type_} {tail}")
|
||||||
|
|
||||||
|
# 定义融合文本
|
||||||
|
fused_text = query if not triplet_texts else f"{query} {' '.join(triplet_texts)}"
|
||||||
|
logger.debug(f"融合文本: {fused_text}, 三元组数量: {len(triplet_texts)}")
|
||||||
|
|
||||||
|
# 向量化
|
||||||
|
embed_start = time.time()
|
||||||
|
query_vector = embedding.embed_query(fused_text)
|
||||||
|
query_vector = np.array(query_vector) / np.linalg.norm(query_vector)
|
||||||
|
logger.debug(f"生成融合向量,维度: {len(query_vector)}, 耗时: {time.time() - embed_start:.3f}s")
|
||||||
|
|
||||||
|
# Milvus 搜索
|
||||||
|
expr = f"userid == '{userid}' and filename in {filenames}"
|
||||||
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
milvus_start = time.time()
|
||||||
|
milvus_results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=100,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
logger.debug(f"Milvus 搜索耗时: {time.time() - milvus_start:.3f}s")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for hits in milvus_results:
|
||||||
|
for hit in hits:
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"source": "fused_query" if not triplet_texts else f"fused_triplets_{len(triplet_texts)}",
|
||||||
|
"metadata": {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
logger.debug(
|
||||||
|
f"召回: text={result['text'][:100]}..., distance={result['distance']}, filename={result['metadata']['filename']}")
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_results = []
|
||||||
|
seen_texts = set()
|
||||||
|
for result in results:
|
||||||
|
text = result['text']
|
||||||
|
if not text:
|
||||||
|
logger.warning(f"发现空文本结果: {result['metadata']}")
|
||||||
|
continue
|
||||||
|
if text in seen_texts:
|
||||||
|
logger.debug(f"移除重复文本: text={text[:100]}..., filename={result['metadata']['filename']}")
|
||||||
|
continue
|
||||||
|
seen_texts.add(text)
|
||||||
|
unique_results.append(result)
|
||||||
|
logger.info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(results)})")
|
||||||
|
|
||||||
|
# 可选:重排序
|
||||||
|
if use_rerank and unique_results:
|
||||||
|
logger.debug("开始重排序")
|
||||||
|
logger.debug(f"重排序查询: {fused_text}")
|
||||||
|
rerank_start = time.time()
|
||||||
|
reranked_results = rerank_results(fused_text, unique_results)
|
||||||
|
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||||
|
logger.debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
|
||||||
|
logger.debug(f"重排序耗时: {time.time() - rerank_start:.3f}s")
|
||||||
|
for i, result in enumerate(reranked_results):
|
||||||
|
logger.debug(
|
||||||
|
f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result.get('rerank_score', 'N/A')}")
|
||||||
|
logger.info(f"总耗时: {time.time() - start_time:.3f}s")
|
||||||
|
return reranked_results[:limit]
|
||||||
|
|
||||||
|
# 按 distance 降序排序
|
||||||
|
sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True)
|
||||||
|
for i, result in enumerate(sorted_results):
|
||||||
|
logger.debug(f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}")
|
||||||
|
logger.info(f"总耗时: {time.time() - start_time:.3f}s")
|
||||||
|
return sorted_results[:limit]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"融合搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
query = "什么是知识抽取?"
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
file_paths = [
|
||||||
|
"/share/wangmeihua/rag/data/test.docx",
|
||||||
|
"/share/wangmeihua/rag/data/zongshu.pdf",
|
||||||
|
"/share/wangmeihua/rag/data/qianru.pdf",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f"Result {i + 1}:")
|
||||||
|
print(f"Text: {result['text'][:200]}...")
|
||||||
|
print(f"Distance: {result['distance']:.3f}")
|
||||||
|
print(
|
||||||
|
f"Rerank Score: {result.get('rerank_score', 'N/A') if isinstance(result.get('rerank_score'), (int, float)) else 'N/A':.3f}")
|
||||||
|
print(f"Source: {result['source']}")
|
||||||
|
print(f"Metadata: {result['metadata']}\n")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索失败: {str(e)}")
|
||||||
190
rag/combinedsearch.py
Normal file
190
rag/combinedsearch.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
from pymilvus import connections, Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from query import search_query
|
||||||
|
from searchquery import searchquery
|
||||||
|
from rerank import rerank_results
|
||||||
|
from vector import initialize_milvus_connection, cleanup_milvus_connection
|
||||||
|
import torch
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear() # 清除现有处理器
|
||||||
|
logger.propagate = False # 禁用传播
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# 初始化嵌入模型(缓存)
|
||||||
|
@lru_cache(maxsize=1000)
|
||||||
|
def get_embedding(text: str) -> List[float]:
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=TEXT_EMBEDDING_MODEL,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
vector = embedding.embed_query(text)
|
||||||
|
if len(vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(vector)} 不匹配预期 1024")
|
||||||
|
return vector
|
||||||
|
|
||||||
|
def combined_search(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 10, offset: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
结合 RAG 和三元组检索,召回相关文本块,使用 BGE Reranker 重排序。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询文本
|
||||||
|
userid (str): 用户ID
|
||||||
|
db_type (str): 数据库类型
|
||||||
|
file_paths (List[str]): 文档路径列表
|
||||||
|
limit (int): 返回的最大结果数,默认为 10
|
||||||
|
offset (int): 偏移量,默认为 0
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict]: 包含 text、distance、source、metadata 和 rerank_score 的结果列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not query or not userid or not db_type or not file_paths:
|
||||||
|
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 或 db_type 的长度超出限制")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset 不能为负数")
|
||||||
|
if limit + offset > 16384:
|
||||||
|
raise ValueError("limit + offset 不能超过 16384")
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
if not isinstance(file_path, str):
|
||||||
|
raise ValueError(f"file_path 必须是字符串: {file_path}")
|
||||||
|
if len(os.path.basename(file_path)) > 255:
|
||||||
|
raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
|
||||||
|
|
||||||
|
# 初始化 Milvus 连接
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# RAG 检索,使用默认 limit=3
|
||||||
|
rag_results = search_query(query, userid, db_type, file_paths, offset=offset)
|
||||||
|
for result in rag_results:
|
||||||
|
result['source'] = 'rag'
|
||||||
|
logger.info(f"RAG 检索返回 {len(rag_results)} 条结果")
|
||||||
|
|
||||||
|
# 三元组检索,使用默认 limit=3
|
||||||
|
triplet_results = searchquery(query, userid, db_type, file_paths, offset=offset)
|
||||||
|
for result in triplet_results:
|
||||||
|
result['source'] = 'triplet'
|
||||||
|
logger.info(f"三元组检索返回 {len(triplet_results)} 条结果")
|
||||||
|
|
||||||
|
# 记录三元组检索结果详情
|
||||||
|
for idx, result in enumerate(triplet_results, 1):
|
||||||
|
logger.debug(f"三元组结果 {idx}: text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}")
|
||||||
|
|
||||||
|
# 合并结果
|
||||||
|
all_results = rag_results + triplet_results
|
||||||
|
if not all_results:
|
||||||
|
logger.warning("RAG 和三元组检索均无结果")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 记录合并前的结果
|
||||||
|
logger.debug("合并前结果:")
|
||||||
|
for idx, result in enumerate(all_results, 1):
|
||||||
|
logger.debug(f"结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}")
|
||||||
|
|
||||||
|
# 使用 BGE Reranker 重排序
|
||||||
|
reranked_results = rerank_results(query, all_results, top_k=len(all_results))
|
||||||
|
|
||||||
|
# 按 rerank_score 排序(不去重)
|
||||||
|
sorted_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True)
|
||||||
|
|
||||||
|
# 记录排序后的结果
|
||||||
|
logger.debug("重排序后结果:")
|
||||||
|
for idx, result in enumerate(sorted_results, 1):
|
||||||
|
logger.debug(f"排序结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}")
|
||||||
|
|
||||||
|
# 去重(基于 text,保留 rerank_score 最大的记录)
|
||||||
|
unique_results = []
|
||||||
|
text_to_result = {}
|
||||||
|
for result in sorted_results:
|
||||||
|
text = result['text']
|
||||||
|
if text not in text_to_result or result['rerank_score'] > text_to_result[text]['rerank_score']:
|
||||||
|
text_to_result[text] = result
|
||||||
|
unique_results = list(text_to_result.values())
|
||||||
|
|
||||||
|
# 记录去重后的结果
|
||||||
|
logger.debug("去重后结果:")
|
||||||
|
for idx, result in enumerate(unique_results, 1):
|
||||||
|
logger.debug(f"去重结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}")
|
||||||
|
|
||||||
|
# 限制结果数量
|
||||||
|
final_results = unique_results[:limit]
|
||||||
|
logger.info(f"合并后返回 {len(final_results)} 条唯一结果")
|
||||||
|
|
||||||
|
# 移除 weighted_score 字段(若存在),保留 rerank_score 和 source
|
||||||
|
for result in final_results:
|
||||||
|
result.pop('weighted_score', None)
|
||||||
|
|
||||||
|
return final_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"合并搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
query = "知识图谱构建需要什么技术?"
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
file_paths = [
|
||||||
|
"/share/wangmeihua/rag/data/test.docx",
|
||||||
|
"/share/wangmeihua/rag/data/zongshu.pdf",
|
||||||
|
"/share/wangmeihua/rag/data/qianru.pdf"
|
||||||
|
]
|
||||||
|
limit = 10
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = combined_search(query, userid, db_type, file_paths, limit, offset)
|
||||||
|
print(f"搜索结果 ({len(results)} 条):")
|
||||||
|
for idx, result in enumerate(results, 1):
|
||||||
|
print(f"结果 {idx}:")
|
||||||
|
print(f"内容: {result['text'][:200]}...")
|
||||||
|
print(f"距离: {result['distance']}")
|
||||||
|
print(f"来源: {result['source']}")
|
||||||
|
print(f"重排序分数: {result['rerank_score']}")
|
||||||
|
print(f"元数据: {result['metadata']}")
|
||||||
|
print("-" * 50)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索失败: {str(e)}")
|
||||||
138
rag/deletefile.py
Normal file
138
rag/deletefile.py
Normal file
@ -0,0 +1,138 @@
import logging
import yaml
import os
from pymilvus import connections, Collection, utility
from vector import initialize_milvus_connection

# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
    TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
    print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
    raise RuntimeError(f"无法加载配置文件: {str(e)}")

# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def delete_document(db_type: str, userid: str, filename: str) -> bool:
    """
    根据 db_type、userid 和 filename 删除用户的指定文件数据。

    参数:
        db_type (str): 数据库类型(如 'textdb', 'pptdb')
        userid (str): 用户 ID
        filename (str): 文件名(如 'test.docx')

    返回:
        bool: 删除是否成功

    异常:
        ValueError: 参数无效
        RuntimeError: 数据库操作失败
    """
    try:
        # 参数验证
        if not db_type or "_" in db_type:
            raise ValueError("db_type 不能为空且不能包含下划线")
        if not userid or "_" in userid:
            raise ValueError("userid 不能为空且不能包含下划线")
        if not filename:
            raise ValueError("filename 不能为空")
        if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255:
            raise ValueError("db_type、userid 或 filename 的长度超出限制")

        # 初始化 Milvus 连接
        initialize_milvus_connection()
        logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}")

        # 检查集合是否存在
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"集合 {collection_name} 不存在")
            return False

        # 加载集合
        try:
            collection = Collection(collection_name)
            collection.load()
            logger.debug(f"加载集合: {collection_name}")
        except Exception as e:
            logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
            raise RuntimeError(f"加载集合失败: {str(e)}")

        # 查询匹配的 document_id
        expr = f"userid == '{userid}' and filename == '{filename}'"
        logger.debug(f"查询表达式: {expr}")
        try:
            results = collection.query(
                expr=expr,
                output_fields=["document_id"],
                limit=1000
            )
            if not results:
                logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录")
                return False
            document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
            logger.debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
        except Exception as e:
            logger.error(f"查询 document_id 失败: {str(e)}")
            raise RuntimeError(f"查询失败: {str(e)}")

        # 执行删除
        total_deleted = 0
        for doc_id in document_ids:
            try:
                delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'"
                logger.debug(f"删除表达式: {delete_expr}")
                delete_result = collection.delete(delete_expr)
                deleted_count = delete_result.delete_count
                total_deleted += deleted_count
                logger.info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条记录")
            except Exception as e:
                logger.error(f"删除 document_id={doc_id} 失败: {str(e)}")
                continue

        if total_deleted == 0:
            logger.warning(f"没有删除任何记录,userid={userid}, filename={filename}")
            return False

        logger.info(f"总计删除 {total_deleted} 条记录,userid={userid}, filename={filename}")
        return True

    except ValueError as ve:
        logger.error(f"参数验证失败: {str(ve)}")
        return False
    except RuntimeError as re:
        logger.error(f"数据库操作失败: {str(re)}")
        return False
    except Exception as e:
        logger.error(f"删除文件失败: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False
    finally:
        try:
            connections.disconnect("default")
            logger.debug("已断开 Milvus 连接")
        except Exception as e:
            logger.warning(f"断开 Milvus 连接失败: {str(e)}")

if __name__ == "__main__":
    # 测试用例
    db_type = "textdb"
    userid = "testuser2"
    filename = "test.docx"

    logger.info(f"测试:删除 userid={userid}, filename={filename} 的文件")
    result = delete_document(db_type, userid, filename)
    print(f"删除结果: {result}")
1
rag/dict/cel.txt
Normal file
1
rag/dict/cel.txt
Normal file
@ -0,0 +1 @@
DB2RDF c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
33
rag/dict/concept.txt
Normal file
33
rag/dict/concept.txt
Normal file
@ -0,0 +1,33 @@
上下位关系 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
串联 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
共指消解 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
关系抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
分类研究 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
外部知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体分类体系 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体识别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
属性 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
属性抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
总结 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
拼图碎片 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
搜索引擎 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层的融合 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据挖掘 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
服务器日志 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
模式匹配 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
模式层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
歧义 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
病毒 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
症状 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识库的更新 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
算法 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
自动化本体构建过程 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
自顶向下 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
资源描述框架 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里巴巴 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
非结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
1
rag/dict/date.txt
Normal file
1
rag/dict/date.txt
Normal file
@ -0,0 +1 @@
结构化的历史数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
1
rag/dict/eve.txt
Normal file
1
rag/dict/eve.txt
Normal file
@ -0,0 +1 @@
(Sri) c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
1
rag/dict/loc.txt
Normal file
1
rag/dict/loc.txt
Normal file
@ -0,0 +1 @@
城市 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
20
rag/dict/media.txt
Normal file
20
rag/dict/media.txt
Normal file
@ -0,0 +1,20 @@
5 信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
万维网 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
关系数据库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
埃博拉病毒的症状有哪些 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体消歧 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
推理策略的一环 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据驱动的自动化方式 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
智能语义搜索 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
概念层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据源 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
统计机器学习 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
语料 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里巴巴 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
面向开放域的实体识别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
5
rag/dict/misc.txt
Normal file
5
rag/dict/misc.txt
Normal file
@ -0,0 +1,5 @@
112种实体类别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
三元组 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息检索 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
7
rag/dict/org.txt
Normal file
7
rag/dict/org.txt
Normal file
@ -0,0 +1,7 @@
XML c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
微软 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
搜索引擎 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
本体构建本体 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
腾讯 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
谷歌 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
1
rag/dict/per.txt
Normal file
1
rag/dict/per.txt
Normal file
@ -0,0 +1 @@
比尔盖茨 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
1
rag/dict/time.txt
Normal file
1
rag/dict/time.txt
Normal file
@ -0,0 +1 @@
网的 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
3
rag/dict/unk.txt
Normal file
3
rag/dict/unk.txt
Normal file
@ -0,0 +1,3 @@
Relation Extraction c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
Web c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
的知识 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
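Each file under rag/dict/ holds one extracted entity per line in the form `<term> <document_id>`, grouped by entity type (per.txt, org.txt, loc.txt and so on). A small illustrative loader for these dictionaries is sketched below; the function name and the way the entities are later consumed (presumably by entity matching in rag/searchquery.py, which is not shown in this commit view) are assumptions, not code from the commit:

```
import os
from collections import defaultdict

def load_entity_dicts(dict_dir="rag/dict"):
    """Illustrative only: map each dictionary term to the document_ids it was extracted from."""
    term_to_docs = defaultdict(set)
    for name in os.listdir(dict_dir):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(dict_dir, name), encoding="utf-8") as f:
            for line in f:
                parts = line.rsplit(None, 1)  # the term itself may contain spaces
                if len(parts) == 2:
                    term, doc_id = parts
                    term_to_docs[term].add(doc_id)
    return term_to_docs
```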
183
rag/embed.py
Normal file
183
rag/embed.py
Normal file
@ -0,0 +1,183 @@
import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from vector import get_vector_db
from filetxt.loader import fileloader
from extract import extract_and_save_triplets
from kgc import KnowledgeGraph

# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
    # logger 此时尚未配置,与其他模块保持一致改用 print
    print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
    raise RuntimeError(f"无法加载配置文件: {str(e)}")

# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

def generate_document_id() -> str:
    """为文件生成唯一的 document_id"""
    return str(uuid.uuid4())

def load_and_split_data(file_path: str, userid: str, document_id: str) -> List[Document]:
    """
    加载文件,分片并生成带有元数据的 Document 对象。
    """
    try:
        if not os.path.exists(file_path):
            raise ValueError(f"文件 {file_path} 不存在")
        if os.path.getsize(file_path) == 0:
            raise ValueError(f"文件 {file_path} 为空")
        logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
        ext = file_path.rsplit('.', 1)[1].lower()
        logger.debug(f"文件扩展名: {ext}")

        logger.debug("开始加载文件")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"文件 {file_path} 加载为空")

        document = Document(page_content=text)
        logger.debug(f"加载完成,生成 1 个文档")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents([document])
        logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")

        filename = os.path.basename(file_path)
        upload_time = datetime.now().isoformat()
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'document_id': document_id,
                'filename': filename,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
                'chunk_index': i,
                'source': file_path,
            })
            required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
            if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
                raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
            documents.append(chunk)
            logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")

        logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块,document_id: {document_id}")
        return documents

    except Exception as e:
        logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        raise ValueError(f"加载或分割文件失败: {str(e)}")

def embed(file_path: str, userid: str, db_type: str) -> bool:
    """
    嵌入文件到 Milvus 向量数据库,抽取三元组保存到指定路径,并将三元组存储到 Neo4j。
    """
    try:
        if not userid or not db_type:
            raise ValueError("userid 和 db_type 不能为空")
        if "_" in userid:
            raise ValueError("userid 不能包含下划线")
        if "_" in db_type:
            raise ValueError("db_type 不能包含下划线")
        if not os.path.exists(file_path):
            raise ValueError(f"文件 {file_path} 不存在")

        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        ext = file_path.rsplit('.', 1)[1].lower()
        if ext not in supported_formats:
            logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")

        document_id = generate_document_id()
        logger.info(f"生成 document_id: {document_id} for file: {file_path}")

        logger.info(f"开始处理文件 {file_path},userid: {userid},db_type: {db_type}")
        chunks = load_and_split_data(file_path, userid, document_id)
        if not chunks:
            logger.error(f"文件 {file_path} 未生成任何文档块")
            raise ValueError("未生成任何文档块")

        logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
        logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")

        db = get_vector_db(userid, db_type, documents=chunks)
        if not db:
            logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
            raise RuntimeError(f"数据库操作失败")

        try:
            full_text = fileloader(file_path)
            if full_text and full_text.strip():
                success = extract_and_save_triplets(full_text, document_id, userid)
                triplet_file_path = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
                if success and os.path.exists(triplet_file_path):
                    logger.info(f"文件 {file_path} 三元组保存到: {triplet_file_path}")
                    try:
                        kg = KnowledgeGraph(data_path=triplet_file_path, document_id=document_id)
                        logger.info(f"Step 1: 导入图谱节点到 Neo4j,document_id: {document_id}")
                        kg.create_graphnodes()
                        logger.info(f"Step 2: 导入图谱边到 Neo4j,document_id: {document_id}")
                        kg.create_graphrels()
                        logger.info(f"Step 3: 导出 Neo4j 节点数据,document_id: {document_id}")
                        kg.export_data()
                        logger.info(f"文件 {file_path} 三元组成功插入 Neo4j")
                    except Exception as e:
                        logger.warning(f"将三元组插入 Neo4j 失败: {str(e)},但不影响 Milvus 嵌入")
                else:
                    logger.warning(f"文件 {file_path} 的三元组抽取失败或文件不存在: {triplet_file_path}")
            else:
                logger.warning(f"文件 {file_path} 内容为空,无法抽取三元组")
        except Exception as e:
            logger.error(f"文件 {file_path} 三元组抽取失败: {str(e)},但不影响向量化")

        logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
        return True

    except ValueError as ve:
        logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}")
        return False
    except RuntimeError as re:
        logger.error(f"嵌入文件 {file_path} 失败: {str(re)}")
        return False
    except Exception as e:
        logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False

if __name__ == "__main__":
    test_file = "/share/wangmeihua/rag/data/test.docx"
    userid = "testuser1"
    db_type = "textdb"
    result = embed(test_file, userid, db_type)
    print(f"嵌入结果: {result}")
225
rag/extract.py
Normal file
225
rag/extract.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import re
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
# 三元组保存路径
|
||||||
|
TRIPLES_OUTPUT_DIR = "/share/wangmeihua/rag/triples"
|
||||||
|
os.makedirs(TRIPLES_OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载 mREBEL 模型和分词器
|
||||||
|
local_path = "/share/models/Babelscape/mrebel-large"
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(local_path, src_lang="zh_CN", tgt_lang="tp_XX")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(local_path)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
triplet_id = tokenizer.convert_tokens_to_ids("<triplet>")
|
||||||
|
logger.debug(f"成功加载 mREBEL 模型,分词器 triplet_id: {triplet_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载 mREBEL 模型失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载 mREBEL 模型失败: {str(e)}")
|
||||||
|
|
||||||
|
# 优化生成参数
|
||||||
|
gen_kwargs = {
|
||||||
|
"max_length": 512,
|
||||||
|
"min_length": 10,
|
||||||
|
"length_penalty": 0.5,
|
||||||
|
"num_beams": 3,
|
||||||
|
"num_return_sequences": 1,
|
||||||
|
"no_repeat_ngram_size": 2,
|
||||||
|
"early_stopping": True,
|
||||||
|
"decoder_start_token_id": triplet_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def split_document(text: str, max_chunk_size: int = 150) -> list:
|
||||||
|
"""分割文档为语义完整的块"""
|
||||||
|
sentences = re.split(r'(?<=[。!?;\n])', text)
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
if len(current_chunk) + len(sentence) <= max_chunk_size:
|
||||||
|
current_chunk += sentence
|
||||||
|
else:
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
current_chunk = sentence
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_triplets_typed(text: str) -> list:
|
||||||
|
"""解析 mREBEL 生成文本,匹配 <triplet> <entity1> <type1> <entity2> <type2> <relation> 格式"""
|
||||||
|
triplets = []
|
||||||
|
logger.debug(f"原始生成文本: {text}")
|
||||||
|
|
||||||
|
# 分割标记
|
||||||
|
tokens = []
|
||||||
|
in_tag = False
|
||||||
|
buffer = ""
|
||||||
|
for char in text:
|
||||||
|
if char == '<':
|
||||||
|
in_tag = True
|
||||||
|
if buffer:
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
buffer = ""
|
||||||
|
buffer += char
|
||||||
|
elif char == '>':
|
||||||
|
in_tag = False
|
||||||
|
buffer += char
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
buffer = ""
|
||||||
|
else:
|
||||||
|
buffer += char
|
||||||
|
if buffer:
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
|
||||||
|
# 过滤特殊标记
|
||||||
|
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
||||||
|
tokens = [t for t in tokens if t not in special_tokens and t]
|
||||||
|
|
||||||
|
logger.debug(f"处理后标记: {tokens}")
|
||||||
|
|
||||||
|
# 解析三元组
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
if tokens[i] == "<triplet>" and i + 5 < len(tokens):
|
||||||
|
entity1 = tokens[i + 1]
|
||||||
|
type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
|
||||||
|
entity2 = tokens[i + 3]
|
||||||
|
type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
|
||||||
|
relation = tokens[i + 5]
|
||||||
|
|
||||||
|
if entity1 and type1 and entity2 and type2 and relation:
|
||||||
|
triplets.append({
|
||||||
|
'head': entity1.strip(),
|
||||||
|
'head_type': type1,
|
||||||
|
'type': relation.strip(),
|
||||||
|
'tail': entity2.strip(),
|
||||||
|
'tail_type': type2
|
||||||
|
})
|
||||||
|
logger.debug(f"添加三元组: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||||
|
i += 6
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return triplets
|
||||||
|
|
||||||
|
def extract_and_save_triplets(text: str, document_id: str, userid: str) -> bool:
|
||||||
|
"""
|
||||||
|
从文本中抽取三元组并保存到指定路径。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
text (str): 输入文本
|
||||||
|
document_id (str): 文档ID
|
||||||
|
userid (str): 用户ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 三元组抽取和保存是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not text or not document_id or not userid:
|
||||||
|
raise ValueError("text、document_id 和 userid 不能为空")
|
||||||
|
if "_" in document_id or "_" in userid:
|
||||||
|
raise ValueError("document_id 和 userid 不能包含下划线")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
logger.info(f"开始抽取文档 {document_id} 的三元组,userid: {userid}")
|
||||||
|
|
||||||
|
# 分割文本为语义块
|
||||||
|
text_chunks = split_document(text, max_chunk_size=150)
|
||||||
|
logger.debug(f"分割为 {len(text_chunks)} 个文本块")
|
||||||
|
|
||||||
|
# 处理所有文本块
|
||||||
|
all_triplets = []
|
||||||
|
for i, chunk in enumerate(text_chunks):
|
||||||
|
logger.debug(f"处理块 {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
model_inputs = tokenizer(
|
||||||
|
chunk,
|
||||||
|
max_length=256,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# 生成
|
||||||
|
try:
|
||||||
|
generated_tokens = model.generate(
|
||||||
|
model_inputs["input_ids"],
|
||||||
|
attention_mask=model_inputs["attention_mask"],
|
||||||
|
**gen_kwargs,
|
||||||
|
)
|
||||||
|
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
||||||
|
for idx, sentence in enumerate(decoded_preds):
|
||||||
|
logger.debug(f"块 {i + 1} 生成文本: {sentence}")
|
||||||
|
triplets = extract_triplets_typed(sentence)
|
||||||
|
if triplets:
|
||||||
|
logger.debug(f"块 {i + 1} 提取到 {len(triplets)} 个三元组")
|
||||||
|
all_triplets.extend(triplets)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理块 {i + 1} 时出错: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_triplets = []
|
||||||
|
seen = set()
|
||||||
|
for t in all_triplets:
|
||||||
|
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
|
||||||
|
if identifier not in seen:
|
||||||
|
seen.add(identifier)
|
||||||
|
unique_triplets.append(t)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
output_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
|
||||||
|
try:
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
for t in unique_triplets:
|
||||||
|
f.write(f"{t['head']}\t{t['type']}\t{t['tail']}\t{t['head_type']}\t{t['tail_type']}\n")
|
||||||
|
logger.info(f"文档 {document_id} 的 {len(unique_triplets)} 个三元组已保存到: {output_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存文档 {document_id} 的三元组失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"文档 {document_id} 三元组抽取完成,耗时: {end_time - start_time:.2f} 秒")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"抽取或保存三元组失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试用例
|
||||||
|
test_text = "知识图谱是一个结构化的语义知识库。深度学习是基于深层神经网络的机器学习子集。"
|
||||||
|
document_id = "testdoc123"
|
||||||
|
userid = "testuser1"
|
||||||
|
result = extract_and_save_triplets(test_text, document_id, userid)
|
||||||
|
print(f"抽取结果: {result}")
|
||||||
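The generator output parsed above follows a linearized, typed-triplet layout: `<triplet> head <head_type> tail <tail_type> relation`. Below is a minimal, self-contained sketch of that tokenization and parse on a hand-written sample string (the entities and the `概念` type tag are illustration data, not real model output), so the expected format can be checked without loading the model.

```python
# Minimal, self-contained parse of the linearized layout the code above expects:
#   <triplet> head <head_type> tail <tail_type> relation
# The sample string and its type tags are illustration data.
sample = "<s><triplet>知识图谱<概念>语义知识库<概念>实例属于</s>"

tokens, buf = [], ""
for ch in sample:
    if ch == "<":                      # a tag starts: flush any buffered text
        if buf:
            tokens.append(buf.strip())
        buf = ch
    elif ch == ">":                    # a tag ends: emit it as one token
        tokens.append((buf + ch).strip())
        buf = ""
    else:
        buf += ch
if buf:
    tokens.append(buf.strip())

tokens = [t for t in tokens if t and t not in ("<s>", "</s>", "<pad>")]
# tokens == ['<triplet>', '知识图谱', '<概念>', '语义知识库', '<概念>', '实例属于']

if tokens and tokens[0] == "<triplet>" and len(tokens) >= 6:
    triplet = {
        "head": tokens[1],
        "head_type": tokens[2][1:-1],
        "tail": tokens[3],
        "tail_type": tokens[4][1:-1],
        "type": tokens[5],
    }
    print(triplet)
```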
290
rag/fusedsearch.py
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pymilvus import Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from vector import initialize_milvus_connection
|
||||||
|
from searchquery import extract_entities, match_triplets
|
||||||
|
from rerank import rerank_results
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.propagate = False
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(config['logging']['file'], 'a', encoding='utf-8') as f:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}")
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=TEXT_EMBEDDING_MODEL,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
def fused_search(
|
||||||
|
query: str,
|
||||||
|
userid: str,
|
||||||
|
db_type: str,
|
||||||
|
file_paths: List[str],
|
||||||
|
limit: int = 10,
|
||||||
|
offset: int = 0,
|
||||||
|
use_rerank: bool = True
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
融合 RAG 和三元组召回文本块:
|
||||||
|
- 调用 searchquery.py 的 extract_entities 和 match_triplets 获取三元组。
|
||||||
|
- 将所有匹配三元组拼接为融合文本,向量化后在 Milvus 中搜索。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询文本
|
||||||
|
userid (str): 用户 ID
|
||||||
|
db_type (str): 数据库类型 (e.g., 'textdb')
|
||||||
|
file_paths (List[str]): 文件路径列表
|
||||||
|
limit (int): 返回结果数量
|
||||||
|
offset (int): 偏移量
|
||||||
|
use_rerank (bool): 是否使用重排序
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict[str, Any]]: 召回结果,包含 text、distance、metadata
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}")
|
||||||
|
|
||||||
|
# 参数验证
|
||||||
|
if not query or not userid or not db_type or not file_paths:
|
||||||
|
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
|
||||||
|
# 初始化 Milvus 连接
|
||||||
|
connections = initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载 Milvus 集合: {collection_name}")
|
||||||
|
|
||||||
|
# 提取实体
|
||||||
|
query_entities = extract_entities(query)
|
||||||
|
logger.debug(f"提取实体: {query_entities}")
|
||||||
|
|
||||||
|
# 收集所有结果
|
||||||
|
results = []
|
||||||
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
logger.debug(f"处理文件: {filename}")
|
||||||
|
|
||||||
|
# 获取 document_id
|
||||||
|
results_query = collection.query(
|
||||||
|
expr=f"userid == '{userid}' and filename == '{filename}'",
|
||||||
|
output_fields=["document_id"],
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
if not results_query:
|
||||||
|
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
|
||||||
|
continue
|
||||||
|
document_id = results_query[0]["document_id"]
|
||||||
|
logger.debug(f"找到 document_id: {document_id}")
|
||||||
|
|
||||||
|
# 获取匹配的三元组
|
||||||
|
matched_triplets = match_triplets(query, query_entities, userid, document_id)
|
||||||
|
logger.debug(f"匹配三元组: {matched_triplets}")
|
||||||
|
|
||||||
|
# 若无三元组,使用原查询向量化
|
||||||
|
if not matched_triplets:
|
||||||
|
logger.debug(f"无匹配三元组,使用原查询: {query}")
|
||||||
|
query_vector = embedding.embed_query(query)
|
||||||
|
expr = f"userid == '{userid}' and filename == '{filename}'"
|
||||||
|
milvus_results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=limit,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
for hits in milvus_results:
|
||||||
|
for hit in hits:
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"source": "fused_query",
|
||||||
|
"metadata": {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 拼接所有三元组
|
||||||
|
triplet_texts = []
|
||||||
|
for triplet in matched_triplets:
|
||||||
|
head = triplet['head']
|
||||||
|
type = triplet['type']
|
||||||
|
tail = triplet['tail']
|
||||||
|
if not head or not type or not tail:
|
||||||
|
logger.debug(f"无效三元组: {triplet}")
|
||||||
|
continue
|
||||||
|
triplet_texts.append(f"{head} {type} {tail}")
|
||||||
|
if not triplet_texts:
|
||||||
|
logger.debug(f"无有效三元组,使用原查询: {query}")
|
||||||
|
query_vector = embedding.embed_query(query)
|
||||||
|
expr = f"userid == '{userid}' and filename == '{filename}'"
|
||||||
|
milvus_results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=5,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
for hits in milvus_results:
|
||||||
|
for hit in hits:
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"source": "fused_query",
|
||||||
|
"metadata": {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成融合文本
|
||||||
|
fused_text = f"{query} {' '.join(triplet_texts)}"
|
||||||
|
logger.debug(f"融合文本: {fused_text}")
|
||||||
|
|
||||||
|
# 向量化
|
||||||
|
fused_vector = embedding.embed_query(fused_text)
|
||||||
|
fused_vector = np.array(fused_vector) / np.linalg.norm(fused_vector)
|
||||||
|
logger.debug(f"生成融合向量,维度: {len(fused_vector)}")
|
||||||
|
|
||||||
|
# Milvus 搜索
|
||||||
|
expr = f"userid == '{userid}' and filename == '{filename}'"
|
||||||
|
milvus_results = collection.search(
|
||||||
|
data=[fused_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=5,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
|
||||||
|
for hits in milvus_results:
|
||||||
|
for hit in hits:
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"source": f"fused_triplets_{len(triplet_texts)}",
|
||||||
|
"metadata": {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_results = []
|
||||||
|
seen_texts = set()
|
||||||
|
for result in results:
|
||||||
|
text = result['text']
|
||||||
|
if text not in seen_texts:
|
||||||
|
seen_texts.add(text)
|
||||||
|
unique_results.append(result)
|
||||||
|
logger.debug(f"去重后结果数量: {len(unique_results)}")
|
||||||
|
|
||||||
|
# 可选:重排序
|
||||||
|
if use_rerank and unique_results:
|
||||||
|
logger.debug("开始重排序")
|
||||||
|
reranked_results = rerank_results(query, unique_results)
|
||||||
|
# 按 rerank_score 降序排序
|
||||||
|
reranked_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True)
|
||||||
|
for i, result in enumerate(reranked_results):
|
||||||
|
logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result['rerank_score']}")
|
||||||
|
return reranked_results[:limit]
|
||||||
|
|
||||||
|
# 按 distance 降序排序
|
||||||
|
sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True)
|
||||||
|
for i, result in enumerate(sorted_results):
|
||||||
|
logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}")
|
||||||
|
return sorted_results[:limit]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"融合搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
query = "知识图谱构建需要什么技术?"
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
file_paths = [
|
||||||
|
"/share/wangmeihua/rag/data/test.docx",
|
||||||
|
"/share/wangmeihua/rag/data/zongshu.pdf",
|
||||||
|
"/share/wangmeihua/rag/data/qianru.pdf"
|
||||||
|
]
|
||||||
|
results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f"Result {i+1}:")
|
||||||
|
print(f"Text: {result['text'][:200]}...")
|
||||||
|
print(f"Distance: {result['distance']}")
|
||||||
|
print(f"Source: {result['source']}")
|
||||||
|
print(f"Metadata: {result['metadata']}\n")
|
||||||
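To make the fusion step above easier to follow, here is a standalone sketch of how matched triplets are folded into the query before vectorization. The triplets and the 1024-dim vector are dummy stand-ins, so the snippet runs without Milvus or the embedding model.

```python
# Dummy triplets and a dummy 1024-dim vector stand in for match_triplets() and
# the HuggingFace embedding, so the fusion step itself is easy to inspect.
import numpy as np

query = "知识图谱构建需要什么技术?"
matched_triplets = [                      # illustration data
    {"head": "知识图谱", "type": "依赖", "tail": "知识抽取"},
    {"head": "知识抽取", "type": "包含", "tail": "实体识别"},
]

triplet_texts = [f"{t['head']} {t['type']} {t['tail']}" for t in matched_triplets]
fused_text = f"{query} {' '.join(triplet_texts)}"
print(fused_text)

fake_embedding = np.random.rand(1024)     # stand-in for embedding.embed_query(fused_text)
fused_vector = fake_embedding / np.linalg.norm(fake_embedding)
print(round(float(np.linalg.norm(fused_vector)), 6))  # ~1.0 after normalization
```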
15
rag/init.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from appPublic.worker import awaitify
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from .kdb import add_kdb, add_dir, add_doc, get_all_docs
|
||||||
|
from .query import search_query
|
||||||
|
from .embed import embed
|
||||||
|
def load_rag():
|
||||||
|
env = ServerEnv()
|
||||||
|
env.add_kdb = add_kdb
|
||||||
|
env.query = awaitify(search_query)
|
||||||
|
env.embed = awaitify(embed)
|
||||||
|
env.add_dir = add_dir
|
||||||
|
env.add_doc = add_doc
|
||||||
|
env.get_all_docs = get_all_docs
|
||||||
|
|
||||||
|
|
||||||
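load_rag() relies on appPublic's awaitify to expose synchronous functions to async handlers. The sketch below only illustrates the general idea (run a blocking callable in an executor so it can be awaited); it is not the appPublic implementation, and slow_query is a made-up stand-in.

```python
# NOT the appPublic implementation - only the general idea behind awaitify():
# run a blocking callable in an executor so async handlers can await it.
import asyncio
from functools import partial

def awaitify_sketch(sync_func):
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(sync_func, *args, **kwargs))
    return wrapper

def slow_query(prompt: str) -> str:       # made-up stand-in for search_query / embed
    return f"result for {prompt}"

async def main():
    query = awaitify_sketch(slow_query)
    print(await query("知识图谱"))

asyncio.run(main())
```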
81
rag/kdb.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
|
||||||
|
from traceback import format_exc
from appPublic.log import exception  # needed by the error branches below
|
||||||
|
from appPublic.uniqueID import getID
|
||||||
|
from appPublic.timeUtils import curDateString
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
from sqlor.dbpools import DBPools
|
||||||
|
from ahserver.serverenv import get_serverenv
|
||||||
|
from ahserver.filestorage import FileStorage
|
||||||
|
|
||||||
|
async def add_kdb(kdb:dict) -> None:
|
||||||
|
"""
|
||||||
|
添加知识库
|
||||||
|
"""
|
||||||
|
kdb = DictObject(**kdb)
|
||||||
|
kdb.parentid=None
|
||||||
|
if kdb.id is None:
|
||||||
|
kdb.id = getID()
|
||||||
|
kdb.entity_type = '0'
|
||||||
|
kdb.create_date = curDateString()
|
||||||
|
if kdb.orgid is None:
|
||||||
|
e = Exception(f'Can not add none orgid kdb')
|
||||||
|
exception(f'{e}\n{format_exc()}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
f = get_serverenv('get_module_dbname')
|
||||||
|
dbname = f('rag')
|
||||||
|
db = DBPools()
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
await sor.C('kdb', kdb.copy())
|
||||||
|
|
||||||
|
async def add_dir(kdb:dict) -> None:
|
||||||
|
"""
|
||||||
|
添加子目录
|
||||||
|
"""
|
||||||
|
kdb = DictObject(**kdb)
|
||||||
|
if kdb.parentid is None:
|
||||||
|
e = Exception(f'Can not add root folder')
|
||||||
|
exception(f'{e}\n{format_exc()}')
|
||||||
|
raise e
|
||||||
|
if kdb.id is None:
|
||||||
|
kdb.id = getID()
|
||||||
|
kdb.entity_type = '1'
|
||||||
|
kdb.create_date = curDateString()
|
||||||
|
f = get_serverenv('get_module_dbname')
|
||||||
|
dbname = f('rag')
|
||||||
|
db = DBPools()
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
await sor.C('kdb', kdb.copy())
|
||||||
|
|
||||||
|
async def add_doc(doc:dict) -> None:
|
||||||
|
"""
|
||||||
|
添加文档
|
||||||
|
"""
|
||||||
|
doc = DictObject(**doc)
|
||||||
|
if doc.parentid is None:
|
||||||
|
e = Exception(f'Can not add root document')
|
||||||
|
exception(f'{e}\n{format_exc()}')
|
||||||
|
raise e
|
||||||
|
if doc.id is None:
|
||||||
|
doc.id = getID()
|
||||||
|
fs = FileStorage()
|
||||||
|
doc.realpath = fs.realPath(doc.webpath)
|
||||||
|
doc.create_date = curDateString()
|
||||||
|
f = get_serverenv('get_module_dbname')
|
||||||
|
dbname = f('rag')
|
||||||
|
db = DBPools()
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
await sor.C('doc', doc.copy())
|
||||||
|
|
||||||
|
async def get_all_docs(sor, kdbid):
|
||||||
|
"""
|
||||||
|
获取所有kdbid下的文档,含子目录的
|
||||||
|
"""
|
||||||
|
docs = await sor.R('doc', {'parentid':kdbid})
|
||||||
|
kdbs = await sor.R('kdb', {'parentid':kdbid})
|
||||||
|
for kdb in kdbs:
|
||||||
|
docs1 = await get_all_docs(sor, kdb.id)
|
||||||
|
docs += docs1
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
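For reference, these are example payloads for add_kdb / add_dir / add_doc as implied by the fields the functions read and fill in; the values are illustrative and callers may carry additional business fields.

```python
# Field names come from the functions above; the values are illustrative.
kdb_payload = {
    "id": None,                  # add_kdb fills this via getID() when omitted
    "name": "产品手册库",
    "orgid": "org-001",          # required: add_kdb raises without an orgid
}
dir_payload = {
    "name": "2024版",
    "parentid": "some-kdb-id",   # required: add_dir refuses a root folder
}
doc_payload = {
    "name": "说明书.pdf",
    "parentid": "some-dir-id",
    "webpath": "/files/xxx/说明书.pdf",  # add_doc resolves realpath via FileStorage
}
print(kdb_payload, dir_payload, doc_payload, sep="\n")
```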
194
rag/kgc.py
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from py2neo import Graph, Node, Relationship
|
||||||
|
from typing import Set, List, Dict, Tuple
|
||||||
|
|
||||||
|
from ufw.common import share_dir
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class KnowledgeGraph:
|
||||||
|
def __init__(self, data_path: str, document_id: str = None):
|
||||||
|
self.data_path = data_path
|
||||||
|
self.document_id = document_id or os.path.basename(data_path).split('_')[0]
|
||||||
|
self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh'))
|
||||||
|
logger.info(f"开始构建知识图谱,data_path: {self.data_path}, document_id: {self.document_id}")
|
||||||
|
# 验证 data_path 是否有效
|
||||||
|
if not os.path.exists(self.data_path):
|
||||||
|
logger.error(f"数据路径 {self.data_path} 不存在")
|
||||||
|
raise ValueError(f"数据路径 {self.data_path} 不存在")
|
||||||
|
|
||||||
|
def _normalize_label(self, entity_type: str) -> str:
|
||||||
|
"""规范化实体类型为 Neo4j 标签"""
|
||||||
|
if not entity_type or not entity_type.strip():
|
||||||
|
return 'Entity'
|
||||||
|
entity_type = re.sub(r'[^\w\s]', '', entity_type.strip())
|
||||||
|
words = entity_type.split()
|
||||||
|
label = '_'.join(word.capitalize() for word in words if word)
|
||||||
|
return label or 'Entity'
|
||||||
|
|
||||||
|
def _clean_relation(self, relation: str) -> Tuple[str, str]:
|
||||||
|
"""清洗关系,返回 (rel_type, rel_name)"""
|
||||||
|
relation = relation.strip()
|
||||||
|
if not relation:
|
||||||
|
return 'RELATED_TO', '相关'
|
||||||
|
if relation.startswith('<') and relation.endswith('>'):
|
||||||
|
cleaned_relation = relation[1:-1]
|
||||||
|
rel_name = cleaned_relation
|
||||||
|
rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper()
|
||||||
|
else:
|
||||||
|
rel_name = relation
|
||||||
|
rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper()
|
||||||
|
if 'instance of' in relation.lower():
|
||||||
|
rel_type = 'INSTANCE_OF'
|
||||||
|
rel_name = '实例'
|
||||||
|
elif 'subclass of' in relation.lower():
|
||||||
|
rel_type = 'SUBCLASS_OF'
|
||||||
|
rel_name = '子类'
|
||||||
|
elif 'part of' in relation.lower():
|
||||||
|
rel_type = 'PART_OF'
|
||||||
|
rel_name = '部分'
|
||||||
|
logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})")
|
||||||
|
return rel_type, rel_name
|
||||||
|
|
||||||
|
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
|
||||||
|
"""读取三元组数据,返回节点和关系"""
|
||||||
|
nodes_by_label = {}
|
||||||
|
relations_by_type = {}
|
||||||
|
triples = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug(f"尝试读取文件: {self.data_path}")
|
||||||
|
with open(self.data_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith('#'):
|
||||||
|
continue
|
||||||
|
parts = line.split('\t')
|
||||||
|
if len(parts) != 5:
|
||||||
|
logger.warning(f"无效行: {line}")
|
||||||
|
continue
|
||||||
|
head, relation, tail, head_type, tail_type = parts
|
||||||
|
head_label = self._normalize_label(head_type)
|
||||||
|
tail_label = self._normalize_label(tail_type)
|
||||||
|
logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
|
||||||
|
|
||||||
|
if head_label not in nodes_by_label:
|
||||||
|
nodes_by_label[head_label] = set()
|
||||||
|
if tail_label not in nodes_by_label:
|
||||||
|
nodes_by_label[tail_label] = set()
|
||||||
|
nodes_by_label[head_label].add(head)
|
||||||
|
nodes_by_label[tail_label].add(tail)
|
||||||
|
|
||||||
|
rel_type, rel_name = self._clean_relation(relation)
|
||||||
|
if rel_type not in relations_by_type:
|
||||||
|
relations_by_type[rel_type] = []
|
||||||
|
relations_by_type[rel_type].append({
|
||||||
|
'head': head,
|
||||||
|
'tail': tail,
|
||||||
|
'head_label': head_label,
|
||||||
|
'tail_label': tail_label,
|
||||||
|
'rel_name': rel_name
|
||||||
|
})
|
||||||
|
|
||||||
|
triples.append({
|
||||||
|
'head': head,
|
||||||
|
'relation': relation,
|
||||||
|
'tail': tail,
|
||||||
|
'head_type': head_type,
|
||||||
|
'tail_type': tail_type
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个")
|
||||||
|
logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条")
|
||||||
|
return nodes_by_label, relations_by_type, triples
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取数据失败: {str(e)},data_path: {self.data_path}")
|
||||||
|
raise RuntimeError(f"读取数据失败: {str(e)}")
|
||||||
|
|
||||||
|
def create_node(self, label: str, nodes: Set[str]):
|
||||||
|
"""创建节点,包含 document_id 属性"""
|
||||||
|
count = 0
|
||||||
|
for node_name in nodes:
|
||||||
|
query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n"
|
||||||
|
try:
|
||||||
|
if self.g.run(query).data():
|
||||||
|
continue
|
||||||
|
node = Node(label, name=node_name, document_id=self.document_id)
|
||||||
|
self.g.create(node)
|
||||||
|
count += 1
|
||||||
|
logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
|
||||||
|
logger.info(f"创建 {label} 节点: {count}/{len(nodes)} 个")
|
||||||
|
return count
|
||||||
|
|
||||||
|
def create_relationship(self, rel_type: str, relations: List[Dict]):
|
||||||
|
"""创建关系"""
|
||||||
|
count = 0
|
||||||
|
total = len(relations)
|
||||||
|
seen_edges = set()
|
||||||
|
for rel in relations:
|
||||||
|
head, tail, head_label, tail_label, rel_name = (
|
||||||
|
rel['head'], rel['tail'], rel['head_label'], rel['tail_label'], rel['rel_name']
|
||||||
|
)
|
||||||
|
edge_key = f"{head_label}:{head}###{tail_label}:{tail}###{rel_type}"
|
||||||
|
if edge_key in seen_edges:
|
||||||
|
continue
|
||||||
|
seen_edges.add(edge_key)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), "
|
||||||
|
f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) "
|
||||||
|
f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.g.run(query)
|
||||||
|
count += 1
|
||||||
|
logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建关系失败: {query}, 错误: {str(e)}")
|
||||||
|
logger.info(f"创建 {rel_type} 关系: {count}/{total} 条")
|
||||||
|
return count
|
||||||
|
|
||||||
|
def create_graphnodes(self):
|
||||||
|
"""创建所有节点"""
|
||||||
|
nodes_by_label, _, _ = self.read_nodes()
|
||||||
|
total = 0
|
||||||
|
for label, nodes in nodes_by_label.items():
|
||||||
|
total += self.create_node(label, nodes)
|
||||||
|
logger.info(f"总计创建节点: {total} 个")
|
||||||
|
return total
|
||||||
|
|
||||||
|
def create_graphrels(self):
|
||||||
|
"""创建所有关系"""
|
||||||
|
_, relations_by_type, _ = self.read_nodes()
|
||||||
|
total = 0
|
||||||
|
for rel_type, relations in relations_by_type.items():
|
||||||
|
total += self.create_relationship(rel_type, relations)
|
||||||
|
logger.info(f"总计创建关系: {total} 条")
|
||||||
|
return total
|
||||||
|
|
||||||
|
def export_data(self):
|
||||||
|
"""导出节点到文件,包含 document_id"""
|
||||||
|
nodes_by_label, _, _ = self.read_nodes()
|
||||||
|
os.makedirs('dict', exist_ok=True)
|
||||||
|
for label, nodes in nodes_by_label.items():
|
||||||
|
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
|
||||||
|
f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes)))
|
||||||
|
logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个")
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
data_path = '/share/wangmeihua/rag/triples/26911c68-9107-4bb4-8f31-ff776991a119_testuser2.txt'
|
||||||
|
handler = KnowledgeGraph(data_path)
|
||||||
|
logger.info("Step 1: 导入图谱节点中")
|
||||||
|
handler.create_graphnodes()
|
||||||
|
logger.info("Step 2: 导入图谱边中")
|
||||||
|
handler.create_graphrels()
|
||||||
|
logger.info("Step 3: 导出数据")
|
||||||
|
handler.export_data()
|
||||||
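The triplet file consumed above is one tab-separated record per line: head, relation, tail, head_type, tail_type. The standalone sketch below mirrors the label/relation normalization so you can predict which Neo4j labels and relationship types a line will produce (simplified: the real _clean_relation also strips angle brackets and maps "instance of" / "subclass of" / "part of" onto Chinese display names).

```python
# One tab-separated record per line: head, relation, tail, head_type, tail_type.
import re

def normalize_label(entity_type: str) -> str:
    if not entity_type or not entity_type.strip():
        return "Entity"
    entity_type = re.sub(r"[^\w\s]", "", entity_type.strip())
    return "_".join(w.capitalize() for w in entity_type.split()) or "Entity"

def clean_relation(relation: str):
    relation = relation.strip()
    if not relation:
        return "RELATED_TO", "相关"
    rel_type = re.sub(r"[^\w\s]", "", relation).replace(" ", "_").upper()
    return rel_type, relation

line = "知识图谱\tinstance of\t语义知识库\tconcept\tknowledge base"   # illustration data
head, relation, tail, head_type, tail_type = line.split("\t")
print(normalize_label(head_type))   # -> Concept
print(normalize_label(tail_type))   # -> Knowledge_Base
print(clean_relation(relation))     # -> ('INSTANCE_OF', 'instance of')
```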
201
rag/query.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
from pymilvus import connections, Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from vector import initialize_milvus_connection, cleanup_milvus_connection
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear() # 清除现有处理器,避免重复
|
||||||
|
logger.propagate = False # 禁用传播到父级
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
def search_query(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
根据用户输入的查询文本,在指定 db_type 的知识库中搜索与 userid 相关的指定文档。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 用户输入的查询文本
|
||||||
|
userid (str): 用户ID,用于过滤
|
||||||
|
db_type (str): 数据库类型(例如 'textdb')
|
||||||
|
file_paths (List[str]): 文档路径列表(支持1到多个文件)
|
||||||
|
limit (int): 返回的最大结果数,默认为 5
|
||||||
|
offset (int): 偏移量,用于分页,默认为 0
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict]: 搜索结果,每个元素为包含 text、distance 和 metadata 的字典
|
||||||
|
|
||||||
|
异常:
|
||||||
|
ValueError: 参数无效
|
||||||
|
RuntimeError: 模型加载或 Milvus 操作失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not query:
|
||||||
|
raise ValueError("查询文本不能为空")
|
||||||
|
if not userid or not db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 或 db_type 的长度超出限制")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset 不能为负数")
|
||||||
|
if limit + offset > 16384:
|
||||||
|
raise ValueError("limit + offset 不能超过 16384")
|
||||||
|
if not file_paths:
|
||||||
|
raise ValueError("file_paths 不能为空")
|
||||||
|
for file_path in file_paths:
|
||||||
|
if not isinstance(file_path, str):
|
||||||
|
raise ValueError(f"file_path 必须是字符串: {file_path}")
|
||||||
|
if len(os.path.basename(file_path)) > 255:
|
||||||
|
raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
|
||||||
|
if "_" in os.path.basename(file_path):
|
||||||
|
raise ValueError(f"文件名 {file_path} 不能包含下划线")
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
model_path = TEXT_EMBEDDING_MODEL
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise ValueError(f"模型路径 {model_path} 不存在")
|
||||||
|
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=model_path,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
# 将查询转换为向量
|
||||||
|
query_vector = embedding.embed_query(query)
|
||||||
|
logger.debug(f"查询向量维度: {len(query_vector)}")
|
||||||
|
|
||||||
|
# 连接到 Milvus
|
||||||
|
initialize_milvus_connection()
|
||||||
|
|
||||||
|
# 检查集合是否存在
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 加载集合
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载集合: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 构造搜索参数
|
||||||
|
search_params = {
|
||||||
|
"metric_type": "COSINE", # 与 vector.py 一致
|
||||||
|
"params": {"nprobe": 10} # 优化搜索性能
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构造过滤表达式,限制在指定文件
|
||||||
|
filenames = [os.path.basename(file_path) for file_path in file_paths]
|
||||||
|
filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
|
||||||
|
expr = f"userid == '{userid}' and ({filename_expr})"
|
||||||
|
logger.debug(f"搜索参数: {search_params}, 表达式: {expr}, limit: {limit}, offset: {offset}")
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
try:
|
||||||
|
results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=limit,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"搜索失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"搜索失败: {str(e)}")
|
||||||
|
|
||||||
|
# 处理搜索结果
|
||||||
|
search_results = []
|
||||||
|
for hits in results:
|
||||||
|
for hit in hits:
|
||||||
|
metadata = {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
search_results.append(result)
|
||||||
|
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
|
||||||
|
|
||||||
|
logger.debug(f"搜索完成,返回 {len(search_results)} 条结果")
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
query = "知识图谱的知识融合是什么?"
|
||||||
|
userid = "testuser2"
|
||||||
|
db_type = "textdb"
|
||||||
|
file_paths = [
|
||||||
|
"/share/wangmeihua/rag/data/test.docx",
|
||||||
|
"/share/wangmeihua/rag/data/test.txt"
|
||||||
|
]
|
||||||
|
limit = 5
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = search_query(query, userid, db_type, file_paths, limit, offset)
|
||||||
|
print(f"搜索结果 ({len(results)} 条):")
|
||||||
|
for idx, result in enumerate(results, 1):
|
||||||
|
print(f"结果 {idx}:")
|
||||||
|
print(f"内容: {result['text'][:200]}...")
|
||||||
|
print(f"距离: {result['distance']}")
|
||||||
|
print(f"元数据: {result['metadata']}")
|
||||||
|
print("-" * 50)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索失败: {str(e)}")
|
||||||
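The filter passed to Milvus is assembled from the userid plus the basenames of the requested files. A small standalone sketch of that expression building (paths are illustrative):

```python
import os

userid = "alice"                                        # illustrative values
file_paths = ["/data/kb/manual.pdf", "/data/kb/faq.docx"]

filenames = [os.path.basename(p) for p in file_paths]
filename_expr = " or ".join(f"filename == '{name}'" for name in filenames)
expr = f"userid == '{userid}' and ({filename_expr})"
print(expr)
# userid == 'alice' and (filename == 'manual.pdf' or filename == 'faq.docx')
```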
53
rag/rag.bak.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.configuredServer import ConfiguredServer
|
||||||
|
from ahserver.webapp import webapp
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from query import search_query
|
||||||
|
from embed import embed
|
||||||
|
from deletefile import delete_document
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_dbname(name):
|
||||||
|
"""
|
||||||
|
获取默认数据库名称,优先使用环境变量 RAG_DB_TYPE,未设置时返回 'textdb'。
|
||||||
|
"""
|
||||||
|
return os.getenv('RAG_DB_TYPE', 'textdb')
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
"""
|
||||||
|
初始化 RAG 系统,绑定核心函数到 ServerEnv。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info("初始化 RAG 系统")
|
||||||
|
g = ServerEnv()
|
||||||
|
|
||||||
|
# 绑定核心函数为异步版本
|
||||||
|
g.embed = awaitify(embed)
|
||||||
|
g.query = awaitify(search_query)
|
||||||
|
g.delete = awaitify(delete_document)
|
||||||
|
g.get_module_dbname = get_module_dbname
|
||||||
|
|
||||||
|
logger.info("RAG 系统初始化完成")
|
||||||
|
return g
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logger.info("启动 RAG Web 服务器")
|
||||||
|
webapp(init)
|
||||||
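get_module_dbname() above falls back to 'textdb' unless RAG_DB_TYPE is set; a quick standalone check of that behaviour:

```python
import os

os.environ.pop("RAG_DB_TYPE", None)
print(os.getenv("RAG_DB_TYPE", "textdb"))   # -> textdb (default)

os.environ["RAG_DB_TYPE"] = "lawdb"         # illustrative override
print(os.getenv("RAG_DB_TYPE", "textdb"))   # -> lawdb
```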
80
rag/rerank.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
from pymilvus.model.reranker import BGERerankFunction
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear() # 清除现有处理器
|
||||||
|
logger.propagate = False # 禁用传播
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
def rerank_results(query: str, results: List[Dict], top_k: int = 10) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
使用 BGE Reranker 模型对查询和文本块进行重排序。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query (str): 查询文本
|
||||||
|
results (List[Dict]): 包含 text、distance、source 和 metadata 的结果列表
|
||||||
|
top_k (int): 返回的最大结果数,默认为 10
|
||||||
|
|
||||||
|
返回:
|
||||||
|
List[Dict]: 重排序后的结果列表,包含 text、distance、source、metadata 和 rerank_score
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 初始化 BGE Reranker
|
||||||
|
bge_rf = BGERerankFunction(
|
||||||
|
model_name="/share/models/BAAI/bge-reranker-v2-m3",
|
||||||
|
device="cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
logger.debug(f"BGE Reranker 初始化成功,模型路径: /share/models/BAAI/bge-reranker-v2-m3, 设备: {'cuda:0' if torch.cuda.is_available() else 'cpu'}")
|
||||||
|
|
||||||
|
# 提取文本块
|
||||||
|
documents = [result['text'] for result in results]
|
||||||
|
if not documents:
|
||||||
|
logger.warning("无文本块可重排序")
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 重排序
|
||||||
|
rerank_hits = bge_rf(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
top_k=min(top_k, len(documents))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建重排序结果
|
||||||
|
reranked = []
|
||||||
|
for result in rerank_hits:
|
||||||
|
original_result = results[result.index].copy()
|
||||||
|
original_result['rerank_score'] = result.score
|
||||||
|
reranked.append(original_result)
|
||||||
|
logger.debug(f"重排序结果: text={result.text[:200]}..., rerank_score={result.score:.6f}, source={original_result['source']}")
|
||||||
|
|
||||||
|
logger.info(f"重排序返回 {len(reranked)} 条结果")
|
||||||
|
return reranked
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重排序失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
# 回退到原始结果
|
||||||
|
return results
|
||||||
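rerank_results() attaches a rerank_score to each retrieved chunk and callers sort on it. The standalone sketch below shows that re-ordering with made-up scores, so the effect can be seen without downloading the bge-reranker model.

```python
# Scores are made up; the real code obtains them from BGERerankFunction.
results = [
    {"text": "知识融合是指将多源异构数据……", "distance": 0.61},
    {"text": "知识抽取包括实体识别和关系抽取……", "distance": 0.72},
    {"text": "版权声明与致谢……", "distance": 0.70},
]
fake_scores = [0.91, 0.88, 0.12]   # one score per chunk, higher = more relevant

for r, s in zip(results, fake_scores):
    r["rerank_score"] = s
reranked = sorted(results, key=lambda r: r["rerank_score"], reverse=True)
for r in reranked:
    print(f"{r['rerank_score']:.2f}  {r['text'][:15]}")
```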
363
rag/searchquery.py
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
from pymilvus import connections, Collection, utility
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
import numpy as np
|
||||||
|
from scipy.spatial.distance import cosine
|
||||||
|
from ltp import LTP
|
||||||
|
from vector import initialize_milvus_connection, cleanup_milvus_connection
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载配置文件失败: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
logger.handlers.clear() # 清理现有处理器,避免重复
|
||||||
|
logger.propagate = False # 禁用传播到父级
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# 三元组保存路径
|
||||||
|
TRIPLES_OUTPUT_DIR = '/share/wangmeihua/rag/triples'
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=TEXT_EMBEDDING_MODEL,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug("嵌入模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
# 初始化 LTP 模型
|
||||||
|
try:
|
||||||
|
model_path = "/share/models/LTP/small"
|
||||||
|
if not os.path.isdir(model_path):
|
||||||
|
logger.warning(f"本地模型路径 {model_path} 不存在,尝试使用 Hugging Face 模型 'hit-scir/ltp-small'")
|
||||||
|
model_path = "hit-scir/ltp-small"
|
||||||
|
ltp = LTP(pretrained_model_name_or_path=model_path)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
ltp.to("cuda")
|
||||||
|
logger.debug("LTP 模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载 LTP 模型失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载 LTP 模型失败: {str(e)}")
|
||||||
|
|
||||||
|
def extract_entities(query: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
从查询文本中抽取实体,包括:
|
||||||
|
- LTP NER 识别的实体(所有类型)。
|
||||||
|
- LTP POS 标注为名词('n')的词。
|
||||||
|
- LTP POS 标注为动词('v')的词。
|
||||||
|
- 连续名词合并(如 '苹果 公司' -> '苹果公司'),移除子词。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not query:
|
||||||
|
raise ValueError("查询文本不能为空")
|
||||||
|
|
||||||
|
# 使用 LTP pipeline 获取分词、词性、NER 结果
|
||||||
|
result = ltp.pipeline([query], tasks=["cws", "pos", "ner"])
|
||||||
|
words = result.cws[0]
|
||||||
|
pos_list = result.pos[0]
|
||||||
|
ner = result.ner[0]
|
||||||
|
|
||||||
|
entities = []
|
||||||
|
subword_set = set() # 记录连续名词的子词
|
||||||
|
|
||||||
|
# 提取 1:NER 实体(所有类型)
|
||||||
|
logger.debug(f"NER 结果: {ner}")
|
||||||
|
for entity_type, entity, start, end in ner:
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
# 提取 2:合并连续名词
|
||||||
|
combined = ""
|
||||||
|
combined_words = [] # 记录当前连续名词的单词
|
||||||
|
for i in range(len(words)):
|
||||||
|
if pos_list[i] == 'n':
|
||||||
|
combined += words[i]
|
||||||
|
combined_words.append(words[i])
|
||||||
|
if i + 1 < len(words) and pos_list[i + 1] == 'n':
|
||||||
|
continue
|
||||||
|
if combined:
|
||||||
|
entities.append(combined)
|
||||||
|
subword_set.update(combined_words)
|
||||||
|
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
|
||||||
|
combined = ""
|
||||||
|
combined_words = []
|
||||||
|
else:
|
||||||
|
combined = ""
|
||||||
|
combined_words = []
|
||||||
|
logger.debug(f"连续名词子词集合: {subword_set}")
|
||||||
|
|
||||||
|
# 提取 3:POS 名词('n'),排除子词
|
||||||
|
for word, pos in zip(words, pos_list):
|
||||||
|
if pos == 'n' and word not in subword_set:
|
||||||
|
entities.append(word)
|
||||||
|
|
||||||
|
# 提取 4:POS 动词('v')
|
||||||
|
for word, pos in zip(words, pos_list):
|
||||||
|
if pos == 'v':
|
||||||
|
entities.append(word)
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_entities = list(dict.fromkeys(entities))
|
||||||
|
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
|
||||||
|
return unique_entities
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"实体抽取失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_triplets_from_file(triplet_file: str) -> List[Dict]:
|
||||||
|
"""从三元组文件中加载"""
|
||||||
|
triplets = []
|
||||||
|
try:
|
||||||
|
if not os.path.exists(triplet_file):
|
||||||
|
logger.warning(f"三元组文件 {triplet_file} 不存在")
|
||||||
|
return []
|
||||||
|
|
||||||
|
with open(triplet_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
parts = line.strip().split('\t')
|
||||||
|
if len(parts) >= 5:
|
||||||
|
head, relation, tail, head_type, tail_type = parts[:5]
|
||||||
|
triplets.append({
|
||||||
|
'head': head,
|
||||||
|
'head_type': head_type,
|
||||||
|
'type': relation,
|
||||||
|
'tail': tail,
|
||||||
|
'tail_type': tail_type
|
||||||
|
})
|
||||||
|
logger.debug(f"从 {triplet_file} 加载 {len(triplets)} 个三元组")
|
||||||
|
return triplets
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载三元组文件 {triplet_file} 失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def match_triplets(query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
匹配查询实体与文档三元组,使用语义嵌入:
|
||||||
|
- 初始匹配:实体与 head 或 tail 相似度 ≥ 0.8。
|
||||||
|
- 返回匹配的三元组。
|
||||||
|
"""
|
||||||
|
matched_triplets = []
|
||||||
|
ENTITY_SIMILARITY_THRESHOLD = 0.8 # 实体与 head/tail 相似度阈值
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载三元组
|
||||||
|
triplet_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
|
||||||
|
doc_triplets = load_triplets_from_file(triplet_file)
|
||||||
|
if not doc_triplets:
|
||||||
|
logger.debug(f"文档 document_id={document_id} 无三元组")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 缓存查询实体嵌入
|
||||||
|
entity_vectors = {entity: embedding.embed_query(entity) for entity in query_entities}
|
||||||
|
|
||||||
|
# 初始匹配
|
||||||
|
for entity in query_entities:
|
||||||
|
entity_vec = entity_vectors[entity]
|
||||||
|
for d_triplet in doc_triplets:
|
||||||
|
d_head_vec = embedding.embed_query(d_triplet['head'])
|
||||||
|
d_tail_vec = embedding.embed_query(d_triplet['tail'])
|
||||||
|
head_similarity = 1 - cosine(entity_vec, d_head_vec)
|
||||||
|
tail_similarity = 1 - cosine(entity_vec, d_tail_vec)
|
||||||
|
|
||||||
|
if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD:
|
||||||
|
matched_triplets.append(d_triplet)
|
||||||
|
logger.debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
|
||||||
|
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_matched = []
|
||||||
|
seen = set()
|
||||||
|
for t in matched_triplets:
|
||||||
|
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
|
||||||
|
if identifier not in seen:
|
||||||
|
seen.add(identifier)
|
||||||
|
unique_matched.append(t)
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(unique_matched)} 个匹配的三元组")
|
||||||
|
return unique_matched
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"匹配三元组失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def searchquery(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
根据查询抽取实体,匹配指定文档的三元组,并在 Milvus 中搜索相关文档片段。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not query or not userid or not db_type or not file_paths:
|
||||||
|
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 或 db_type 的长度超出限制")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset 不能为负数")
|
||||||
|
if limit + offset > 16384:
|
||||||
|
raise ValueError("limit + offset 不能超过 16384")
|
||||||
|
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return []
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}' and filename == '{filename}'",
|
||||||
|
output_fields=["document_id", "filename"],
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
|
||||||
|
continue
|
||||||
|
documents.append(results[0])
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
logger.warning("没有找到任何有效文档")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")
|
||||||
|
|
||||||
|
query_entities = extract_entities(query)
|
||||||
|
if not query_entities:
|
||||||
|
logger.warning("未从查询中提取到实体")
|
||||||
|
return []
|
||||||
|
|
||||||
|
search_results = []
|
||||||
|
for doc in documents:
|
||||||
|
document_id = doc["document_id"]
|
||||||
|
filename = doc["filename"]
|
||||||
|
logger.debug(f"处理文档: document_id={document_id}, filename={filename}")
|
||||||
|
|
||||||
|
matched_triplets = match_triplets(query, query_entities, userid, document_id)
|
||||||
|
if not matched_triplets:
|
||||||
|
logger.debug(f"文档 document_id={document_id} 未找到匹配的三元组")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for triplet in matched_triplets:
|
||||||
|
head = triplet['head']
|
||||||
|
type = triplet['type']
|
||||||
|
tail = triplet['tail']
|
||||||
|
if not head or not type or not tail:
|
||||||
|
logger.debug(f"无效三元组: head={head}, type={type}, tail={tail}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
triplet_text = f"{head} {type} {tail}"
|
||||||
|
logger.debug(f"搜索三元组: {triplet_text} (文档: {filename})")
|
||||||
|
try:
|
||||||
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
query_vector = embedding.embed_query(triplet_text)
|
||||||
|
expr = f"userid == '{userid}' and filename == '{filename}' and text like '%{head}%{tail}%'"
|
||||||
|
logger.debug(f"搜索表达式: {expr}")
|
||||||
|
|
||||||
|
results = collection.search(
|
||||||
|
data=[query_vector],
|
||||||
|
anns_field="vector",
|
||||||
|
param=search_params,
|
||||||
|
limit=limit,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
|
||||||
|
for hits in results:
|
||||||
|
for hit in hits:
|
||||||
|
metadata = {
|
||||||
|
"userid": hit.entity.get("userid"),
|
||||||
|
"document_id": hit.entity.get("document_id"),
|
||||||
|
"filename": hit.entity.get("filename"),
|
||||||
|
"file_path": hit.entity.get("file_path"),
|
||||||
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
|
"file_type": hit.entity.get("file_type")
|
||||||
|
}
|
||||||
|
result = {
|
||||||
|
"text": hit.entity.get("text"),
|
||||||
|
"distance": hit.distance,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
search_results.append(result)
|
||||||
|
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"三元组 {triplet_text} 在文档 {filename} 搜索失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
unique_results = []
|
||||||
|
seen_texts = set()
|
||||||
|
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
||||||
|
if result['text'] not in seen_texts:
|
||||||
|
unique_results.append(result)
|
||||||
|
seen_texts.add(result['text'])
|
||||||
|
if len(unique_results) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"返回 {len(unique_results)} 条唯一结果")
|
||||||
|
return unique_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"搜索失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
query = "什么是知识图谱的知识抽取?"
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
file_paths = [
|
||||||
|
"/share/wangmeihua/rag/data/test.docx",
|
||||||
|
"/share/wangmeihua/rag/data/zongshu.pdf",
|
||||||
|
"/share/wangmeihua/rag/data/qianru.pdf"
|
||||||
|
]
|
||||||
|
limit = 5
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = searchquery(query, userid, db_type, file_paths, limit, offset)
|
||||||
|
print(f"搜索结果 ({len(results)} 条):")
|
||||||
|
for idx, result in enumerate(results, 1):
|
||||||
|
print(f"结果 {idx}:")
|
||||||
|
print(f"内容: {result['text'][:200]}...")
|
||||||
|
print(f"距离: {result['distance']}")
|
||||||
|
print(f"元数据: {result['metadata']}")
|
||||||
|
print("-" * 50)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索失败: {str(e)}")
|
||||||
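match_triplets() accepts a triplet when the query entity's embedding is within the 0.8 cosine-similarity threshold of the head or tail embedding. The sketch below reproduces just that gate with dummy vectors in place of the real embeddings.

```python
# Dummy vectors stand in for embedding.embed_query(); the threshold is the one
# used above (ENTITY_SIMILARITY_THRESHOLD = 0.8).
import numpy as np
from scipy.spatial.distance import cosine

ENTITY_SIMILARITY_THRESHOLD = 0.8
rng = np.random.default_rng(0)

entity_vec = rng.random(1024)
head_vec = entity_vec + rng.normal(0, 0.05, 1024)   # near-duplicate -> high similarity
tail_vec = rng.random(1024)                          # unrelated vector

head_sim = 1 - cosine(entity_vec, head_vec)
tail_sim = 1 - cosine(entity_vec, tail_vec)
matched = head_sim >= ENTITY_SIMILARITY_THRESHOLD or tail_sim >= ENTITY_SIMILARITY_THRESHOLD
print(f"head_sim={head_sim:.2f}, tail_sim={tail_sim:.2f}, matched={matched}")
```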
9
rag/test.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from py2neo import Graph,Node,Relationship,NodeMatcher
|
||||||
|
|
||||||
|
username = 'neo4j'
|
||||||
|
password = '261229..wmh'
|
||||||
|
auth = (username, password)
|
||||||
|
graph=Graph("bolt://10.18.34.18:7687", auth = auth)
|
||||||
|
|
||||||
|
book_node=Node('经名',name='十三经')
|
||||||
|
graph.create(book_node)
|
||||||
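test.py imports NodeMatcher but never uses it; a possible read-back of the node it creates looks like this (same connection details as above; requires a reachable Neo4j instance).

```python
from py2neo import Graph, NodeMatcher

graph = Graph("bolt://10.18.34.18:7687", auth=("neo4j", "261229..wmh"))
matcher = NodeMatcher(graph)
node = matcher.match("经名", name="十三经").first()
print(node)   # None if the node has not been created yet
```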
539
rag/vector.py
Normal file
@ -0,0 +1,539 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
|
||||||
|
from langchain_milvus import Milvus
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
MILVUS_DB_PATH = config['database']['milvus_db_path']
|
||||||
|
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"无法加载配置文件: {str(e)}")
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger(config['logging']['name'])
|
||||||
|
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
|
||||||
|
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def ensure_milvus_directory() -> None:
|
||||||
|
"""确保 Milvus 数据库目录存在"""
|
||||||
|
db_dir = os.path.dirname(MILVUS_DB_PATH)
|
||||||
|
if not os.path.exists(db_dir):
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
logger.debug(f"创建 Milvus 目录: {db_dir}")
|
||||||
|
if not os.access(db_dir, os.W_OK):
|
||||||
|
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
|
||||||
|
|
||||||
|
def initialize_milvus_connection() -> None:
|
||||||
|
"""初始化 Milvus 连接,确保单一连接"""
|
||||||
|
try:
|
||||||
|
if not connections.has_connection("default"):
|
||||||
|
connections.connect("default", uri=MILVUS_DB_PATH)
|
||||||
|
logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}")
|
||||||
|
else:
|
||||||
|
logger.debug("已存在 Milvus 连接,跳过重复连接")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"连接 Milvus 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
|
||||||
|
|
||||||
|
def cleanup_milvus_connection() -> None:
|
||||||
|
"""清理 Milvus 连接,确保资源释放"""
|
||||||
|
try:
|
||||||
|
if connections.has_connection("default"):
|
||||||
|
connections.disconnect("default")
|
||||||
|
logger.debug("已断开 Milvus 连接")
|
||||||
|
time.sleep(3)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
|
||||||
|
|
||||||
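A minimal sketch of using the connect/disconnect pair above against a throwaway Milvus Lite file (the path is illustrative; requires a pymilvus build with Milvus Lite support):

```python
from pymilvus import connections

DB_PATH = "/tmp/demo_milvus.db"                 # stand-in for MILVUS_DB_PATH
if not connections.has_connection("default"):
    connections.connect("default", uri=DB_PATH)
print(connections.has_connection("default"))    # True
connections.disconnect("default")
print(connections.has_connection("default"))    # False
```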
|
def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
|
||||||
|
"""
|
||||||
|
初始化或访问 Milvus Lite 向量数据库集合,按 db_type 组织,利用 userid 区分用户,document_id 区分文档,并插入文档。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 参数验证
|
||||||
|
if not userid or not db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能为空")
|
||||||
|
if "_" in userid or "_" in db_type:
|
||||||
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
|
if len(userid) > 100 or len(db_type) > 100:
|
||||||
|
raise ValueError("userid 和 db_type 的长度应小于 100")
|
||||||
|
if not documents or not all(isinstance(doc, Document) for doc in documents):
|
||||||
|
raise ValueError("documents 不能为空且必须是 Document 对象列表")
|
||||||
|
required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
|
||||||
|
for doc in documents:
|
||||||
|
if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
|
||||||
|
raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}")
|
||||||
|
if doc.metadata["userid"] != userid:
|
||||||
|
raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致")
|
||||||
|
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
|
||||||
|
# 初始化嵌入模型
|
||||||
|
model_path = TEXT_EMBEDDING_MODEL
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise ValueError(f"模型路径 {model_path} 不存在")
|
||||||
|
|
||||||
|
embedding = HuggingFaceEmbeddings(
|
||||||
|
model_name=model_path,
|
||||||
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
test_vector = embedding.embed_query("test")
|
||||||
|
if len(test_vector) != 1024:
|
||||||
|
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
|
||||||
|
logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载模型失败: {str(e)}")
|
||||||
|
|
||||||
|
# 集合名称
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if len(collection_name) > 255:
|
||||||
|
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
|
||||||
|
logger.debug(f"集合名称: {collection_name}")
|
||||||
|
|
||||||
|
# 定义 schema,包含所有固定字段
|
||||||
|
fields = [
|
||||||
|
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
|
||||||
|
FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
|
||||||
|
FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
|
||||||
|
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||||
|
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
|
||||||
|
FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
|
||||||
|
FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
|
||||||
|
FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
|
||||||
|
]
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=fields,
|
||||||
|
description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段",
|
||||||
|
auto_id=True,
|
||||||
|
primary_field="pk",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查集合是否存在
|
||||||
|
if utility.has_collection(collection_name):
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
existing_schema = collection.schema
|
||||||
|
expected_fields = {f.name for f in fields}
|
||||||
|
actual_fields = {f.name for f in existing_schema.fields}
|
||||||
|
vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None)
|
||||||
|
|
||||||
|
schema_compatible = False
|
||||||
|
if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR:
|
||||||
|
dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None
|
||||||
|
schema_compatible = dim == 1024
|
||||||
|
logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, "
|
||||||
|
f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, "
|
||||||
|
f"dim={dim if dim is not None else '未定义'}")
|
||||||
|
if not schema_compatible:
|
||||||
|
logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: "
|
||||||
|
f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, "
|
||||||
|
f"vector_field: {vector_field is not None}, "
|
||||||
|
f"dtype: {vector_field.dtype if vector_field else '无'}, "
|
||||||
|
f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}")
|
||||||
|
utility.drop_collection(collection_name)
|
||||||
|
else:
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"集合 {collection_name} 已存在并加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 创建新集合
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
try:
|
||||||
|
collection = Collection(collection_name, schema)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="userid",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="document_id",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="filename",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="file_path",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="upload_time",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="file_type",
|
||||||
|
index_params={"index_type": "INVERTED"}
|
||||||
|
)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"成功创建并加载集合: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建集合 {collection_name} 失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"创建集合失败: {str(e)}")
|
||||||
|
|
||||||
|
# 初始化 Milvus 向量存储
|
||||||
|
try:
|
||||||
|
vector_store = Milvus(
|
||||||
|
embedding_function=embedding,
|
||||||
|
collection_name=collection_name,
|
||||||
|
connection_args={"uri": MILVUS_DB_PATH},
|
||||||
|
drop_old=False,
|
||||||
|
auto_id=True,
|
||||||
|
primary_field="pk",
|
||||||
|
)
|
||||||
|
logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"初始化向量存储失败: {str(e)}")
|
||||||
|
|
||||||
|
# 插入文档
|
||||||
|
try:
|
||||||
|
logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}")
|
||||||
|
for doc in documents:
|
||||||
|
logger.debug(f"插入文档元数据: {doc.metadata}")
|
||||||
|
vector_store.add_documents(documents=documents)
|
||||||
|
logger.debug(f"成功插入 {len(documents)} 个文档")
|
||||||
|
|
||||||
|
# 立即查询验证
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
for result in results:
|
||||||
|
logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, "
|
||||||
|
f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
|
||||||
|
f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入文档失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"插入文档失败: {str(e)}")
|
||||||
|
|
||||||
|
return vector_store
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cleanup_milvus_connection()
|
||||||
|
|
||||||
|
def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
获取指定 userid 和 db_type 下的 document_id 与元数据的映射。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not userid or "_" in userid:
|
||||||
|
raise ValueError("userid 不能为空且不能包含下划线")
|
||||||
|
if not db_type or "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能为空且不能包含下划线")
|
||||||
|
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=100
|
||||||
|
)
|
||||||
|
mapping = {}
|
||||||
|
for result in results:
|
||||||
|
doc_id = result.get("document_id")
|
||||||
|
if doc_id:
|
||||||
|
mapping[doc_id] = {
|
||||||
|
"userid": result.get("userid", ""),
|
||||||
|
"filename": result.get("filename", ""),
|
||||||
|
"file_path": result.get("file_path", ""),
|
||||||
|
"upload_time": result.get("upload_time", ""),
|
||||||
|
"file_type": result.get("file_type", "")
|
||||||
|
}
|
||||||
|
logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")
|
||||||
|
|
||||||
|
logger.debug(f"找到 {len(mapping)} 个文档的映射")
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文档映射失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"获取文档映射失败: {str(e)}")
|
||||||
|
|
||||||
|
def list_user_collections() -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
列出所有数据库类型(db_type)及其包含的用户(userid)与对应的文档(document_id)映射。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collections = utility.list_collections()
|
||||||
|
|
||||||
|
db_types_with_data = {}
|
||||||
|
for col in collections:
|
||||||
|
if col.startswith("ragdb_"):
|
||||||
|
db_type = col[len("ragdb_"):]
|
||||||
|
logger.debug(f"处理集合: {col} (db_type: {db_type})")
|
||||||
|
|
||||||
|
collection = Collection(col)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
batch_size = 1000
|
||||||
|
offset = 0
|
||||||
|
user_document_map = {}
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
results = collection.query(
|
||||||
|
expr="",
|
||||||
|
output_fields=["userid", "document_id"],
|
||||||
|
limit=batch_size,
|
||||||
|
offset=offset
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
break
|
||||||
|
for result in results:
|
||||||
|
userid = result.get("userid")
|
||||||
|
doc_id = result.get("document_id")
|
||||||
|
if userid and doc_id:
|
||||||
|
if userid not in user_document_map:
|
||||||
|
user_document_map[userid] = set()
|
||||||
|
user_document_map[userid].add(doc_id)
|
||||||
|
offset += batch_size
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 转换为列表以便序列化
|
||||||
|
user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
|
||||||
|
logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}")
|
||||||
|
|
||||||
|
db_types_with_data[db_type] = {
|
||||||
|
"userids": user_document_map
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"可用 db_types 和数据: {db_types_with_data}")
|
||||||
|
return db_types_with_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出集合和用户失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def view_collection_details(userid: str) -> None:
|
||||||
|
"""
|
||||||
|
查看特定 userid 在所有集合中的内容和容量,包含 document_id 和元数据。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not userid or "_" in userid:
|
||||||
|
raise ValueError("userid 不能为空且不能包含下划线")
|
||||||
|
|
||||||
|
logger.debug(f"正在查看 userid {userid} 的集合")
|
||||||
|
ensure_milvus_directory()
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collections = utility.list_collections()
|
||||||
|
db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]
|
||||||
|
|
||||||
|
if not db_types:
|
||||||
|
logger.debug(f"未找到任何集合")
|
||||||
|
return
|
||||||
|
|
||||||
|
for db_type in db_types:
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
continue
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=10000)
|
||||||
|
num_entities = len(all_pks)
|
||||||
|
results = collection.query(
|
||||||
|
expr=f"userid == '{userid}'",
|
||||||
|
output_fields=["userid","text", "document_id", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}")
|
||||||
|
|
||||||
|
if num_entities == 0:
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:")
|
||||||
|
for idx, doc in enumerate(results, 1):
|
||||||
|
metadata = {
|
||||||
|
"userid": doc.get("userid", ""),
|
||||||
|
"filename": doc.get("filename", ""),
|
||||||
|
"file_path": doc.get("file_path", ""),
|
||||||
|
"upload_time": doc.get("upload_time", ""),
|
||||||
|
"file_type": doc.get("file_type", "")
|
||||||
|
}
|
||||||
|
logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
查看指定 db_type 中的向量数据,可选按 userid 和 document_id 过滤,包含完整元数据和向量。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not db_type or "_" in db_type:
|
||||||
|
raise ValueError("db_type 不能为空且不能包含下划线")
|
||||||
|
if limit <= 0 or limit > 16384:
|
||||||
|
raise ValueError("limit 必须在 1 到 16384 之间")
|
||||||
|
if userid and "_" in userid:
|
||||||
|
raise ValueError("userid 不能包含下划线")
|
||||||
|
if document_id and "_" in document_id:
|
||||||
|
raise ValueError("document_id 不能包含下划线")
|
||||||
|
|
||||||
|
initialize_milvus_connection()
|
||||||
|
collection_name = f"ragdb_{db_type}"
|
||||||
|
if not utility.has_collection(collection_name):
|
||||||
|
logger.warning(f"集合 {collection_name} 不存在")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
collection = Collection(collection_name)
|
||||||
|
collection.load()
|
||||||
|
logger.debug(f"加载集合: {collection_name}")
|
||||||
|
|
||||||
|
expr = []
|
||||||
|
if userid:
|
||||||
|
expr.append(f"userid == '{userid}'")
|
||||||
|
if document_id:
|
||||||
|
expr.append(f"document_id == '{document_id}'")
|
||||||
|
expr = " && ".join(expr) if expr else ""
|
||||||
|
|
||||||
|
results = collection.query(
|
||||||
|
expr=expr,
|
||||||
|
output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_data = {}
|
||||||
|
for doc in results:
|
||||||
|
pk = doc.get("pk", str(uuid.uuid4()))
|
||||||
|
text = doc.get("text", "")
|
||||||
|
doc_id = doc.get("document_id", "")
|
||||||
|
vector = doc.get("vector", [])
|
||||||
|
metadata = {
|
||||||
|
"filename": doc.get("filename", ""),
|
||||||
|
"file_path": doc.get("file_path", ""),
|
||||||
|
"upload_time": doc.get("upload_time", ""),
|
||||||
|
"file_type": doc.get("file_type", "")
|
||||||
|
}
|
||||||
|
vector_data[pk] = {
|
||||||
|
"text": text,
|
||||||
|
"vector": vector,
|
||||||
|
"document_id": doc_id,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")
|
||||||
|
|
||||||
|
logger.debug(f"共找到 {len(vector_data)} 条向量数据")
|
||||||
|
return vector_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查看向量数据失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def main():
|
||||||
|
userid = "testuser1"
|
||||||
|
db_type = "textdb"
|
||||||
|
|
||||||
|
# logger.info("\n测试 1:带文档初始化")
|
||||||
|
# documents = [
|
||||||
|
# Document(
|
||||||
|
# page_content="深度学习是基于深层神经网络的机器学习子集。",
|
||||||
|
# metadata={
|
||||||
|
# "userid": userid,
|
||||||
|
# "document_id": str(uuid.uuid4()),
|
||||||
|
# "filename": "test_doc1.txt",
|
||||||
|
# "file_path": "/path/to/test_doc1.txt",
|
||||||
|
# "upload_time": datetime.now().isoformat(),
|
||||||
|
# "file_type": "txt"
|
||||||
|
# }
|
||||||
|
# ),
|
||||||
|
# Document(
|
||||||
|
# page_content="知识图谱是一个结构化的语义知识库。",
|
||||||
|
# metadata={
|
||||||
|
# "userid": userid,
|
||||||
|
# "document_id": str(uuid.uuid4()),
|
||||||
|
# "filename": "test_doc2.txt",
|
||||||
|
# "file_path": "/path/to/test_doc2.txt",
|
||||||
|
# "upload_time": datetime.now().isoformat(),
|
||||||
|
# "file_type": "txt"
|
||||||
|
# }
|
||||||
|
# ),
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# vector_store = get_vector_db(userid, db_type, documents=documents)
|
||||||
|
# logger.info(f"集合: ragdb_{db_type}")
|
||||||
|
# logger.info(f"成功为 userid {userid} 在 {db_type} 中插入文档")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info("\n测试 2:列出所有 db_types 和文档映射")
|
||||||
|
try:
|
||||||
|
db_types = list_user_collections()
|
||||||
|
logger.info(f"可用 db_types 和文档: {db_types}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"\n测试 3:查看 userid {userid} 的所有集合")
|
||||||
|
try:
|
||||||
|
view_collection_details(userid)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
# logger.info(f"\n测试 4:查看向量数据")
|
||||||
|
# try:
|
||||||
|
# vector_data = view_vector_data(db_type, userid=userid)
|
||||||
|
# logger.info(f"向量数据: {vector_data}")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"\n测试 5:获取 userid {userid} 在{db_type}数据库的文档映射")
|
||||||
|
try:
|
||||||
|
mapping = get_document_mapping(userid, db_type)
|
||||||
|
logger.info(f"文档映射: {mapping}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"失败: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
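For reference, the commented-out test in main() above shows the intended calling pattern. Below is a minimal usage sketch of the helpers defined in rag/vector.py; it assumes the package is importable as rag.vector, that milvusconfig.yaml points at a valid Milvus Lite path and a 1024-dimensional embedding model, and it reuses the testuser1/textdb placeholders from main(). It is an illustration, not part of the committed code.

```python
# Minimal usage sketch for rag/vector.py (assumptions noted above).
import uuid
from datetime import datetime
from langchain_core.documents import Document
from rag.vector import get_vector_db, get_document_mapping, list_user_collections

doc = Document(
    page_content="知识图谱是一个结构化的语义知识库。",
    metadata={
        "userid": "testuser1",               # must not contain underscores
        "document_id": str(uuid.uuid4()),
        "filename": "test_doc.txt",
        "file_path": "/path/to/test_doc.txt",
        "upload_time": datetime.now().isoformat(),
        "file_type": "txt",
    },
)

get_vector_db("testuser1", "textdb", [doc])           # creates ragdb_textdb if missing, then inserts
print(get_document_mapping("testuser1", "textdb"))    # document_id -> metadata for this user
print(list_user_collections())                        # db_type -> userid -> document_ids
```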
1
rag/version.py
Normal file
1
rag/version.py
Normal file
@ -0,0 +1 @@
__version__ = '0.0.1'
4
requirements.txt
Normal file
4
requirements.txt
Normal file
@ -0,0 +1,4 @@
filetxt
apppublic
sqlor
ahserver
52
script/install.sh
Executable file
52
script/install.sh
Executable file
@ -0,0 +1,52 @@
#!/bin/bash

# Check the operating system
OS=$(uname -s)
if [[ "$OS" != "Darwin" && "$OS" != "Linux" ]]; then
    echo "错误:此脚本仅支持 macOS 和 Linux!"
    exit 1
fi

# Check the required files
SERVICE_FILE="rag.service"
NGINX_FILE="rag.nginx"
if [[ ! -f "$SERVICE_FILE" || ! -f "$NGINX_FILE" ]]; then
    echo "错误:缺少 $SERVICE_FILE 或 $NGINX_FILE 文件"
    exit 1
fi

# 1. Configure the service
if [[ "$OS" == "Darwin" ]]; then
    # macOS: use launchd
    mkdir -p ~/Library/LaunchAgents
    cp rag.service ~/Library/LaunchAgents/
    launchctl load ~/Library/LaunchAgents/rag.service
    launchctl start rag.service
elif [[ "$OS" == "Linux" ]]; then
    # Linux: use systemd
    sudo cp rag.service /etc/systemd/system/
    sudo systemctl daemon-reload
    sudo systemctl enable rag.service
    sudo systemctl start rag.service
fi

# 2. Configure Nginx
if ! command -v nginx &> /dev/null; then
    echo "安装 Nginx..."
    if [[ "$OS" == "Darwin" ]]; then
        brew install nginx
    elif [[ "$OS" == "Linux" ]]; then
        sudo apt-get update && sudo apt-get install -y nginx
    fi
fi

# Detect the Nginx configuration directory
NGINX_CONF_DIR="/etc/nginx/sites-enabled"
if [[ "$OS" == "Darwin" ]]; then
    NGINX_CONF_DIR="/usr/local/etc/nginx/sites-enabled"
fi
mkdir -p "$NGINX_CONF_DIR"
cp rag.nginx "$NGINX_CONF_DIR/"
nginx -t && nginx -s reload || echo "错误:Nginx 配置重载失败"

echo "安装完成!"
20
script/killname
Executable file
20
script/killname
Executable file
@ -0,0 +1,20 @@
#!/bin/sh

if [ -z "$1" ]; then
    echo "错误:请提供进程名称"
    exit 1
fi

# Find matching processes and terminate them
PIDS=$(ps -ef | grep "$1" | grep -v grep | awk '{print $2}')
if [ -z "$PIDS" ]; then
    echo "未找到匹配的进程:$1"
    exit 0
fi

for PID in $PIDS; do
    echo "终止进程 $PID"
    kill -9 "$PID"
done

exit 0
31
script/rag.nginx
Normal file
31
script/rag.nginx
Normal file
@ -0,0 +1,31 @@
server {
    listen 80;
    server_name rag.opencomputing.cn;
    autoindex on;
    client_max_body_size 20m;
    proxy_set_header X-Forwarded-Host $host;
    proxy_set_header X-Forwarded-server $host;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Scheme $scheme;
    proxy_set_header X-Forwarded-Port $server_port;
    proxy_set_header X-Forwarded-Url "$scheme://$host:$server_port$request_uri";

    index index.html index.htm;

    location ~^/ip$ {
        return 200 "$remote_addr";
    }
    location / {
        add_header Access-Control-Allow-Origin *;
        proxy_set_header X-Forwarded-Host $host;
        proxy_set_header X-Forwarded-server $host;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Scheme $scheme;
        proxy_set_header X-Forwarded-Port $server_port;
        proxy_set_header X-real-ip $remote_addr;
        proxy_send_timeout 600s;
        proxy_read_timeout 600s;
        proxy_pass http://localhost:10098/;
    }
}
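Since the /ip location is answered directly by Nginx (only / is proxied to the application on localhost:10098), it can serve as a quick smoke test that the site file is loaded. A minimal check, assuming Nginx listens on port 80 of the same host and uses the server_name above, might look like this sketch:

```python
# Hypothetical smoke test for the rag.nginx front end; host and port are assumptions.
import urllib.request

req = urllib.request.Request("http://127.0.0.1/ip",
                             headers={"Host": "rag.opencomputing.cn"})
with urllib.request.urlopen(req, timeout=5) as resp:
    # A 200 with the caller's address confirms Nginx picked up the config,
    # without involving the proxied application.
    print(resp.status, resp.read().decode())
```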
19
script/rag.service
Normal file
19
script/rag.service
Normal file
@ -0,0 +1,19 @@
[Unit]
Description=RAG Service
Documentation=RAG service to control RAG application
After=network.target nginx.service
Requires=nginx.service

[Service]
User=wangmeihua
Group=wangmeihua
# Type=forking
WorkingDirectory=/share/wangmeihua/rag
ExecStart=/bin/bash /share/wangmeihua/rag/script/rag.sh
ExecStop=/bin/bash /share/wangmeihua/rag/script/killname app/ragapp.py
Restart=on-failure
StandardOutput=append:/var/log/rag/rag.log
StandardError=append:/var/log/rag/error.log

[Install]
WantedBy=multi-user.target
18
script/rag.sh
Executable file
18
script/rag.sh
Executable file
@ -0,0 +1,18 @@
#!/bin/bash
User=wangmeihua
Group=wangmeihua
PYTHON=python3
RAG_PY="/d/wangmeihua/rag/app/ragapp.py"
LOG_DIR="/d/wangmeihua/rag/logs"

# Make sure the application file exists
if [[ ! -f "$RAG_PY" ]]; then
    echo "错误:$RAG_PY 不存在"
    exit 1
fi

# Kill any old process
"/d/wangmeihua/rag/script/killname" $RAG_PY

# Start a new process
"$PYTHON" "$RAG_PY" -w "/d/wangmeihua/rag"
46
script/set_env.sh
Normal file
46
script/set_env.sh
Normal file
@ -0,0 +1,46 @@
#!/bin/bash

HOME_DIR="/share/wangmeihua"
RAG_DIR="/share/wangmeihua/rag"
PYTHON_VERSION="python3"

# Check that Python 3 is available
if ! command -v "$PYTHON_VERSION" &> /dev/null; then
    echo "错误:未找到 Python3"
    exit 1
fi

# Check requirements.txt
if [[ ! -f "${RAG_DIR}/requirements.txt" ]]; then
    echo "错误:${RAG_DIR}/requirements.txt 不存在"
    exit 1
fi

# Create the virtual environment
mkdir -p "${HOME_DIR}/bin"
"$PYTHON_VERSION" -m venv "${HOME_DIR}/py3"
source "${HOME_DIR}/py3/bin/activate"

# Back up .bashrc
if [[ -f "${HOME_DIR}/.bashrc" ]]; then
    cp "${HOME_DIR}/.bashrc" "${HOME_DIR}/.bashrc.bak"
fi

# Configure environment variables
cat >> "${HOME_DIR}/.bashrc" << EOF
export PATH="${HOME_DIR}/bin:${HOME_DIR}/py3/bin:\$PATH"
source "${HOME_DIR}/py3/bin/activate"
EOF

# Install dependencies
pip install -r "${RAG_DIR}/requirements.txt"
if [[ $? -ne 0 ]]; then
    echo "错误:依赖安装失败"
    exit 1
fi

# Copy killname and make it executable
cp killname "${HOME_DIR}/bin"
chmod +x "${HOME_DIR}/bin/killname"

echo "环境配置完成!"
16
setup.cfg
Normal file
16
setup.cfg
Normal file
@ -0,0 +1,16 @@
[metadata]
name = rag
version = 0.0.1
description = Your project description
author = yu moqing
author_email = yumoqing@gmail.com
long_description = file: README.md
license = MIT

[options]
packages = find:
python_requires = >=3.8
install_requires =
    filetxt
    apppublic
    sqlor
    ahserver
52
setup.py
Executable file
52
setup.py
Executable file
@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-

from rag.version import __version__
try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

required = []
with open('requirements.txt', 'r') as f:
    ls = f.read()
    required = [line.strip() for line in ls.split('\n') if line.strip()]

with open('rag/version.py', 'r') as f:
    x = f.read()
    y = x[x.index("'")+1:]
    z = y[:y.index("'")]
    version = z

with open("README.md", "r") as fh:
    long_description = fh.read()

name = "rag"
description = "rag"
author = "yumoqing"
email = "yumoqing@gmail.com"

package_data = {}

setup(
    name="rag",
    version=version,
    # uncomment the following lines if you fill them out in release.py
    description=description,
    author=author,
    author_email=email,
    platforms='any',
    install_requires=required,
    packages=[
        "rag"
    ],
    package_data=package_data,
    keywords=[
    ],
    url="https://github.com/yumoqing/rag",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
    ],
)
41
triples/521b2024-f3dc-47af-957d-0b725bf0855d_testuser1.txt
Normal file
41
triples/521b2024-f3dc-47af-957d-0b725bf0855d_testuser1.txt
Normal file
@ -0,0 +1,41 @@
谷歌 industry 搜索引擎 org concept
知识图谱 Web 3.0 万维网 concept media
Web is a list of 网的 unk time
自顶向下 百科类网站 结构化数据源 concept media
结构化数据 <org> 关系数据库 concept media
非结构化数据 subclass of XML concept org
模式层 subclass of 知识图谱 concept media
结构化知识库 subclass of 知识图谱 concept misc
比尔盖茨 employer 微软 per org
5 信息抽取 facet of 数据层 media concept
信息抽取 part of 知识图谱 concept media
实体识别 subclass of 信息抽取 concept media
实体分类体系 part of 112种实体类别 concept misc
分类研究 实体类别 面向开放域的实体识别 concept media
服务器日志 特征建模 搜索引擎 concept org
关系抽取 subclass of Relation Extraction concept unk
模式匹配 实体 语料 concept media
属性抽取 <misc> 统计机器学习 concept media
属性 subclass of 实体 concept misc
数据挖掘 subclass of 结构化数据 concept media
拼图碎片 非结构化 信息抽取 concept media
歧义 used by 实体消歧 concept media
共指消解 自然语言处理 信息检索 concept misc
外部知识库 结构化数据 知识图谱 concept media
数据层的融合 模式层 关系数据库 concept media
资源描述框架 <media> 本体构建本体 concept org
DB2RDF subclass of 结构化的历史数据 cel date
自动化本体构建过程 本体库 数据驱动的自动化方式 concept media
阿里 owned by 阿里巴巴 org media
上下位关系 阿里巴巴 图谱 concept media
腾讯 owned by 阿里巴巴 org concept
知识图谱 location 城市 concept loc
串联 规则 推理策略的一环 concept media
算法 part of 知识库 concept media
知识库的更新 subclass of 概念层 concept media
知识图谱 part of 数据层 concept media
总结 part of 知识图谱 concept media
知识图谱 移动个人助理(Siri 智能语义搜索 concept media
(Sri) subclass of 的知识 eve unk
病毒 知识图谱 埃博拉病毒的症状有哪些 concept media
症状 part of 三元组 concept misc
41
triples/c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac_testuser1.txt
Normal file
41
triples/c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac_testuser1.txt
Normal file
@ -0,0 +1,41 @@
谷歌 industry 搜索引擎 org concept
知识图谱 Web 3.0 万维网 concept media
Web is a list of 网的 unk time
自顶向下 百科类网站 结构化数据源 concept media
结构化数据 <org> 关系数据库 concept media
非结构化数据 subclass of XML concept org
模式层 subclass of 知识图谱 concept media
结构化知识库 subclass of 知识图谱 concept misc
比尔盖茨 employer 微软 per org
5 信息抽取 facet of 数据层 media concept
信息抽取 part of 知识图谱 concept media
实体识别 subclass of 信息抽取 concept media
实体分类体系 part of 112种实体类别 concept misc
分类研究 实体类别 面向开放域的实体识别 concept media
服务器日志 特征建模 搜索引擎 concept org
关系抽取 subclass of Relation Extraction concept unk
模式匹配 实体 语料 concept media
属性抽取 <misc> 统计机器学习 concept media
属性 subclass of 实体 concept misc
数据挖掘 subclass of 结构化数据 concept media
拼图碎片 非结构化 信息抽取 concept media
歧义 used by 实体消歧 concept media
共指消解 自然语言处理 信息检索 concept misc
外部知识库 结构化数据 知识图谱 concept media
数据层的融合 模式层 关系数据库 concept media
资源描述框架 <media> 本体构建本体 concept org
DB2RDF subclass of 结构化的历史数据 cel date
自动化本体构建过程 本体库 数据驱动的自动化方式 concept media
阿里 owned by 阿里巴巴 org media
上下位关系 阿里巴巴 图谱 concept media
腾讯 owned by 阿里巴巴 org concept
知识图谱 location 城市 concept loc
串联 规则 推理策略的一环 concept media
算法 part of 知识库 concept media
知识库的更新 subclass of 概念层 concept media
知识图谱 part of 数据层 concept media
总结 part of 知识图谱 concept media
知识图谱 移动个人助理(Siri 智能语义搜索 concept media
(Sri) subclass of 的知识 eve unk
病毒 知识图谱 埃博拉病毒的症状有哪些 concept media
症状 part of 三元组 concept misc
35
wwwroot/add.ui
Normal file
35
wwwroot/add.ui
Normal file
@ -0,0 +1,35 @@
{
    "widgettype": "Form",
    "options": {
        "height": "70%",
        "title": "向知识库添加文件",
        "description": "支持的文件类型:.txt, .csv, .xlsx, .docx, .pptx, .pdf",
        "method": "POST",
        "submit_url": "{{entire_url('api/add')}}",
        "fields": [
            {
                "name": "file_path",
                "uitype": "file",
                "label": "选择文件",
                "required": true,
                "description": "支持格式:txt, csv, xlsx, docx, pptx, pdf"
            },
            {
                "name": "userid",
                "uitype": "str",
                "label": "用户 ID",
                "value": "user1",
                "required": true,
                "description": "请输入用户 ID(不超过 100 字符)"
            },
            {
                "name": "db_type",
                "uitype": "str",
                "label": "知识库名称",
                "required": true,
                "value": "textdb",
                "description": "请输入知识库名称(如 textdb, pptdb)"
            }
        ]
    }
}
Some files were not shown because too many files have changed in this diff.