first commit

This commit is contained in:
yumoqing 2025-07-16 15:06:59 +08:00
commit 8cf8b975f7
108 changed files with 70588 additions and 0 deletions

BIN
Milvus/milvus.db Normal file

Binary file not shown.

94
README.md Normal file
View File

@ -0,0 +1,94 @@
# Knowledge Base Server
This system provides self-managed knowledge bases for different clients and offers knowledge retrieval on top of them.
It exposes an API through which registered servers consume knowledge services; it is not aimed at end users.
## Dependencies
Depends on [these modules](requirements.txt).
## Installation and Deployment
1. Create a rag user.
2. Log in as the rag user.
3. Run the following commands:
```
git clone git@git.kaiyuancloud.cn:yumoqing/rag
cd rag/script
./install.sh
```
This checks the project out under the user's home directory.
## Features
Manages the customer knowledge bases of client systems and provides knowledge queries.
Each client can create one or more independent knowledge bases to serve different business scenarios.
Knowledge bases are mutually independent; their data does not interfere with one another.
## HTTP API
### add
Add a document to a knowledge base.
#### path
/api/add
#### method
POST
#### Input
name: authentication
value: Bearer ${apikey}
scope: headers
name: file_name
value: path of the uploaded file
scope: data
name: userid
value: userid of the client system
scope: data
name: kdbname
value: rag kdb name
scope: data
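For illustration only, a minimal client-side sketch of this call using Python requests; the host, port, API key and file path are placeholders (the port follows conf/config.json), and the field names mirror the parameters above:

```python
import requests

apikey = "your-api-key"  # placeholder credential issued to the client system

resp = requests.post(
    "http://localhost:9190/api/add",          # placeholder host/port
    headers={"authentication": f"Bearer {apikey}"},
    data={
        "file_name": "/path/to/uploaded/file.pdf",   # placeholder path
        "userid": "client-user-001",
        "kdbname": "productdocs",
    },
)
resp.raise_for_status()
```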
#### Output
### query
Query a knowledge base.
#### path
/api/query
#### method
POST
#### Input
name: authentication
value: Bearer ${apikey}
scope: headers
name: prompt
value: ${prompt}
scope: data
name: userid
value: ${userid}
scope: data
name: kdbname
value: ${kdbname}
scope: data
#### Output
```
{
total: number of returned records,
rows: the returned records
}
Each row has the following attributes:
content: the text content
distances: the distance score
source: the document path
```
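A matching client-side sketch for the query call, again with placeholder host, key and field values; it assumes the response shape documented above:

```python
import requests

apikey = "your-api-key"  # placeholder credential

resp = requests.post(
    "http://localhost:9190/api/query",        # placeholder host/port
    headers={"authentication": f"Bearer {apikey}"},
    data={
        "prompt": "What is knowledge fusion?",
        "userid": "client-user-001",
        "kdbname": "productdocs",
    },
)
result = resp.json()
print("total:", result["total"])
for row in result["rows"]:
    print(row["content"][:80], row["distances"], row["source"])
```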

BIN
app/.query.py.swp Normal file

Binary file not shown.

Binary file not shown.

57
app/embed.py Normal file
View File

@ -0,0 +1,57 @@
import os
from datetime import datetime
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
from langchain_community.document_loaders import UnstructuredExcelLoader
from langchain_community.document_loaders import UnstructuredPowerPointLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug
from appPublic.uniqueID import getID
from get_vector_db import get_vector_db
TEMP_FOLDER = os.getenv('TEMP_FOLDER', './_temp')
# Function to check whether the uploaded file has a supported extension
def allowed_file(filename):
allowed_file_suffixes = ['pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt']
return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_file_suffixes
# Function to load a document and split it into chunks
def load_and_split_data(file_path):
# Pick a loader based on the file extension, then split the text into chunks
data = None
if file_path.lower().endswith('.pdf'):
loader = UnstructuredPDFLoader(file_path=file_path)
elif file_path.lower().endswith('.docx') or file_path.lower().endswith('.doc'):
loader = UnstructuredWordDocumentLoader(file_path=file_path)
elif file_path.lower().endswith('.pptx') or file_path.lower().endswith('.ppt'):
loader = UnstructuredPowerPointLoader(file_path=file_path)
elif file_path.lower().endswith('.xlsx') or file_path.lower().endswith('.xls'):
loader = UnstructuredExcelLoader(file_path=file_path)
elif file_path.lower().endswith('.csv'):
loader = CSVLoader(file_path=file_path)
else:
loader = TextLoader(file_path=file_path)
data = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
chunks = text_splitter.split_documents(data)
return chunks
# Main function to handle the embedding process
def embed(file_path, userid, kdbname):
if allowed_file(file_path):
chunks = load_and_split_data(file_path)
debug(f'{chunks=}')
db = get_vector_db(userid, kdbname)
db.add(
documents=[c.page_content for c in chunks],
metadatas=[c.metadata for c in chunks],
ids=[getID() for c in chunks]
)
return True
return False
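For illustration, a minimal usage sketch of the embed() function above; the file path, userid and kdbname are placeholders, and it assumes get_vector_db() can reach the configured vector store:

```python
# Hypothetical caller of app/embed.py
from embed import embed

ok = embed('/tmp/manual.pdf', userid='client-user-001', kdbname='productdocs')
print('embedded' if ok else 'rejected: unsupported file type')
```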

0
app/init.py Normal file
View File

22
app/ragapp.py Normal file
View File

@ -0,0 +1,22 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from filemgr.init import load_filemgr
from rbac.init import load_rbac
from appbase.init import load_appbase
from rag.init import load_rag
def get_module_dbname(name):
return 'sage'
def init():
load_rag()
load_appbase()
load_filemgr()
env = ServerEnv()
env.get_module_dbname = get_module_dbname
if __name__ == '__main__':
webapp(init)

2
build/lib/rag/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .version import __version__

138
build/lib/rag/deletefile.py Normal file
View File

@ -0,0 +1,138 @@
import logging
import yaml
import os
from pymilvus import connections, Collection, utility
from .vector import initialize_milvus_connection
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def delete_document(db_type: str, userid: str, filename: str) -> bool:
"""
根据 db_type、userid 和 filename 删除用户的指定文件数据
参数:
db_type (str): 数据库类型 'textdb', 'pptdb'
userid (str): 用户 ID
filename (str): 文件名 'test.docx'
返回:
bool: 删除是否成功
异常:
ValueError: 参数无效
RuntimeError: 数据库操作失败
"""
try:
# 参数验证
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
if not filename:
raise ValueError("filename 不能为空")
if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255:
raise ValueError("db_type、userid 或 filename 的长度超出限制")
# 初始化 Milvus 连接
initialize_milvus_connection()
logger.debug(f"已连接到 Milvus Lite路径: {MILVUS_DB_PATH}")
# 检查集合是否存在
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return False
# 加载集合
try:
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 查询匹配的 document_id
expr = f"userid == '{userid}' and filename == '{filename}'"
logger.debug(f"查询表达式: {expr}")
try:
results = collection.query(
expr=expr,
output_fields=["document_id"],
limit=1000
)
if not results:
logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录")
return False
document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
logger.debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
except Exception as e:
logger.error(f"查询 document_id 失败: {str(e)}")
raise RuntimeError(f"查询失败: {str(e)}")
# 执行删除
total_deleted = 0
for doc_id in document_ids:
try:
delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'"
logger.debug(f"删除表达式: {delete_expr}")
delete_result = collection.delete(delete_expr)
deleted_count = delete_result.delete_count
total_deleted += deleted_count
logger.info(f"成功删除 document_id={doc_id}{deleted_count} 条记录")
except Exception as e:
logger.error(f"删除 document_id={doc_id} 失败: {str(e)}")
continue
if total_deleted == 0:
logger.warning(f"没有删除任何记录userid={userid}, filename={filename}")
return False
logger.info(f"总计删除 {total_deleted} 条记录userid={userid}, filename={filename}")
return True
except ValueError as ve:
logger.error(f"参数验证失败: {str(ve)}")
return False
except RuntimeError as re:
logger.error(f"数据库操作失败: {str(re)}")
return False
except Exception as e:
logger.error(f"删除文件失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
finally:
try:
connections.disconnect("default")
logger.debug("已断开 Milvus 连接")
except Exception as e:
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
if __name__ == "__main__":
# 测试用例
db_type = "textdb"
userid = "testuser4"
filename = "聚类结果1.xlsx"
logger.info(f"测试:删除 userid={userid}, filename={filename} 的文件")
result = delete_document(db_type, userid, filename)
print(f"删除结果: {result}")

178
build/lib/rag/embed.py Normal file
View File

@ -0,0 +1,178 @@
import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from .vector import get_vector_db
from filetxt.loader import fileloader
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def load_and_split_data(file_path: str, userid: str) -> List[Document]:
"""
加载文件分片并生成带有元数据的 Document 对象
参数:
file_path (str): 文件路径
userid (str): 用户ID
返回:
List[Document]: 分片后的文档列表
异常:
ValueError: 文件或参数无效
"""
try:
# 验证文件
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
if os.path.getsize(file_path) == 0:
raise ValueError(f"文件 {file_path} 为空")
logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
ext = file_path.rsplit('.', 1)[1].lower()
logger.debug(f"文件扩展名: {ext}")
# 使用 fileloader 加载文件内容
logger.debug("开始加载文件")
text = fileloader(file_path)
if not text or not text.strip():
raise ValueError(f"文件 {file_path} 加载为空")
# 创建单个 Document 对象
document = Document(page_content=text)
logger.debug(f"加载完成,生成 1 个文档")
# 分割文本
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000,
chunk_overlap=200,
length_function=len,
)
chunks = text_splitter.split_documents([document])
logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")
# 为整个文件生成唯一的 document_id
document_id = str(uuid.uuid4())
# 添加元数据,确保包含所有必需字段
filename = os.path.basename(file_path)
upload_time = datetime.now().isoformat()
documents = []
for i, chunk in enumerate(chunks):
chunk.metadata.update({
'userid': userid,
'document_id': document_id,
'filename': filename,
'file_path': file_path,
'upload_time': upload_time,
'file_type': ext,
'chunk_index': i, # 可选,追踪分片顺序
'source': file_path, # 可选,追踪文件来源
})
# 验证元数据完整性
required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
documents.append(chunk)
logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")
logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块document_id: {document_id}")
return documents
except Exception as e:
logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
raise ValueError(f"加载或分割文件失败: {str(e)}")
def embed(file_path: str, userid: str, db_type: str) -> bool:
"""
嵌入文件到 Milvus 向量数据库
参数:
file_path (str): 文件路径
userid (str): 用户ID
db_type (str): 数据库类型
返回:
bool: 嵌入是否成功
异常:
ValueError: 参数无效
RuntimeError: 数据库操作失败
"""
try:
# 验证输入
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid:
raise ValueError("userid 不能包含下划线")
if "_" in db_type:
raise ValueError("db_type 不能包含下划线")
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
ext = file_path.rsplit('.', 1)[1].lower()
if ext not in supported_formats:
logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
# 加载并分割文件
logger.info(f"开始处理文件 {file_path}userid: {userid}db_type: {db_type}")
chunks = load_and_split_data(file_path, userid)
if not chunks:
logger.error(f"文件 {file_path} 未生成任何文档块")
raise ValueError("未生成任何文档块")
logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")
# 插入到 Milvus
db = get_vector_db(userid, db_type, documents=chunks)
if not db:
logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
raise RuntimeError(f"数据库操作失败")
logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
return True
except ValueError as ve:
logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}")
return False
except RuntimeError as re:
logger.error(f"嵌入文件 {file_path} 失败: {str(re)}")
return False
except Exception as e:
logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
if __name__ == "__main__":
test_file = "/share/wangmeihua/rag/data/test.txt"
userid = "testuser4"
db_type = "textdb"
result = embed(test_file, userid, db_type)
print(f"嵌入结果: {result}")

15
build/lib/rag/init.py Normal file
View File

@ -0,0 +1,15 @@
from appPublic.worker import awaitify
from ahserver.serverenv import ServerEnv
from .kdb import add_kdb, add_dir, add_doc, get_all_docs
from .query import search_query
from .embed import embed
def load_rag():
env = ServerEnv()
env.add_kdb = add_kdb
env.query = awaitify(search_query)
env.embed = awaitify(embed)
env.add_dir = add_dir
env.add_doc = add_doc
env.get_all_docs = get_all_docs
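A sketch of how the functions registered by load_rag() might be called from server-side code; it assumes load_rag() has already run during server initialization, and the arguments are placeholders:

```python
# Hypothetical server-side handler using the awaitify-wrapped entries.
from ahserver.serverenv import ServerEnv

async def handle_upload(file_path, userid, db_type):
    env = ServerEnv()
    ok = await env.embed(file_path, userid, db_type)
    if ok:
        return await env.query('summarize the new document', userid, db_type)
    return []
```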

81
build/lib/rag/kdb.py Normal file
View File

@ -0,0 +1,81 @@
from traceback import format_exc
from appPublic.log import exception
from appPublic.uniqueID import getID
from appPublic.timeUtils import curDateString
from appPublic.dictObject import DictObject
from sqlor.dbpools import DBPools
from ahserver.serverenv import get_serverenv
from ahserver.filestorage import FileStorage
async def add_kdb(kdb:dict) -> None:
"""
添加知识库
"""
kdb = DictObject(**kdb)
kdb.parentid=None
if kdb.id is None:
kdb.id = getID()
kdb.entity_type = '0'
kdb.create_date = curDateString()
if kdb.orgid is None:
e = Exception('Can not add a kdb without an orgid')
exception(f'{e}\n{format_exc()}')
raise e
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('kdb', kdb.copy())
async def add_dir(kdb:dict) -> None:
"""
添加子目录
"""
kdb = DictObject(**kdb)
if kdb.parentid is None:
e = Exception(f'Can not add root folder')
exception(f'{e}\n{format_exc()}')
raise e
if kdb.id is None:
kdb.id = getID()
kdb.entity_type = '1'
kdb.create_date = curDateString()
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('kdb', kdb.copy())
async def add_doc(doc:dict) -> None:
"""
添加文档
"""
doc = DictObject(**doc)
if doc.parentid is None:
e = Exception(f'Can not add root document')
exception(f'{e}\n{format_exc()}')
raise e
if doc.id is None:
doc.id = getID()
fs = FileStorage()
doc.realpath = fs.realPath(doc.webpath)
doc.create_date = curDateString()
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('doc', doc.copy())
async def get_all_docs(sor, kdbid):
"""
获取所有kdbid下的文档含子目录的
"""
docs = await sor.R('doc', {'parentid':kdbid})
kdbs = await sor.R('kdb', {'parentid':kdbid})
for kdb in kdbs:
docs1 = await get_all_docs(sor, kdb.id)
docs += docs1
return docs
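A hedged sketch of creating a knowledge base through add_kdb(); the field values are placeholders and it assumes DBPools and get_module_dbname have been configured by the hosting server:

```python
# Hypothetical caller of rag.kdb.add_kdb
import asyncio
from rag.kdb import add_kdb

async def create_kdb():
    await add_kdb({
        'name': 'productdocs',
        'orgid': 'org-0001',   # required: add_kdb raises if orgid is missing
    })

# asyncio.run(create_kdb())
```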

180
build/lib/rag/query.py Normal file
View File

@ -0,0 +1,180 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus import connections, Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from .vector import get_vector_db, initialize_milvus_connection, cleanup_milvus_connection
import torch
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def search_query(query: str, userid: str, db_type: str, limit: int = 10, offset: int = 0) -> List[Dict]:
"""
根据用户输入的查询文本在指定 db_type 的知识库中搜索与 userid 相关的文档
参数:
query (str): 用户输入的查询文本
userid (str): 用户ID用于过滤
db_type (str): 数据库类型例如 'textdb'
limit (int): 返回的最大结果数默认为 10
offset (int): 偏移量用于分页默认为 0
返回:
List[Dict]: 搜索结果每个元素为包含 text metadata 的字典
异常:
ValueError: 参数无效
RuntimeError: 模型加载或 Milvus 操作失败
"""
try:
# 参数验证
if not query:
raise ValueError("查询文本不能为空")
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 和 db_type 的长度应小于 100")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if offset < 0:
raise ValueError("offset 不能为负数")
if limit + offset > 16384:
raise ValueError("limit + offset 不能超过 16384")
# 初始化嵌入模型
model_path = TEXT_EMBEDDING_MODEL
if not os.path.exists(model_path):
raise ValueError(f"模型路径 {model_path} 不存在")
embedding = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug("嵌入模型加载成功")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
# 将查询转换为向量
query_vector = embedding.embed_query(query)
logger.debug(f"查询向量维度: {len(query_vector)}")
# 连接到 Milvus
initialize_milvus_connection()
# 检查集合是否存在
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
# 加载集合
try:
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 构造搜索参数
search_params = {
"metric_type": "COSINE", # 与 vector.py 一致
"params": {"nprobe": 10} # 优化搜索性能
}
# 构造过滤表达式
expr = f"userid == '{userid}'"
logger.debug(f"搜索参数: {search_params}, 表达式: {expr}, limit: {limit}, offset: {offset}")
# 执行搜索
try:
results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=limit,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
except Exception as e:
logger.error(f"搜索失败: {str(e)}")
raise RuntimeError(f"搜索失败: {str(e)}")
# 处理搜索结果
search_results = []
for hits in results:
for hit in hits:
metadata = {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"metadata": metadata
}
search_results.append(result)
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
logger.debug(f"搜索完成,返回 {len(search_results)} 条结果")
return search_results
except Exception as e:
logger.error(f"搜索失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
raise
finally:
cleanup_milvus_connection()
if __name__ == "__main__":
# 测试代码
query = "知识图谱的知识融合是什么?"
userid = "testuser1"
db_type = "textdb"
limit = 2
offset = 0
try:
results = search_query(query, userid, db_type, limit, offset)
print(f"搜索结果 ({len(results)} 条):")
for idx, result in enumerate(results, 1):
print(f"结果 {idx}:")
print(f"内容: {result['text'][:200]}...")
print(f"距离: {result['distance']}")
print(f"元数据: {result['metadata']}")
print("-" * 50)
except Exception as e:
print(f"搜索失败: {str(e)}")

53
build/lib/rag/rag.bak.py Normal file
View File

@ -0,0 +1,53 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from query import search_query
from embed import embed
from deletefile import delete_document
import logging
import os
# 配置日志
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def get_module_dbname(name):
"""
获取默认数据库名称优先使用环境变量 RAG_DB_TYPE未设置时返回 'textdb'
"""
return os.getenv('RAG_DB_TYPE', 'textdb')
def init():
"""
初始化 RAG 系统绑定核心函数到 ServerEnv
"""
try:
logger.info("初始化 RAG 系统")
g = ServerEnv()
# 绑定核心函数为异步版本
g.embed = awaitify(embed)
g.query = awaitify(search_query)
g.delete = awaitify(delete_document)
g.get_module_dbname = get_module_dbname
logger.info("RAG 系统初始化完成")
return g
except Exception as e:
logger.error(f"初始化失败: {str(e)}")
raise
if __name__ == '__main__':
logger.info("启动 RAG Web 服务器")
webapp(init)

53
build/lib/rag/rag.py Normal file
View File

@ -0,0 +1,53 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from query import search_query
from embed import embed
from deletefile import delete_document
import logging
import os
# 配置日志
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def get_module_dbname(name):
"""
获取默认数据库名称优先使用环境变量 RAG_DB_TYPE未设置时返回 'textdb'
"""
return os.getenv('RAG_DB_TYPE', 'textdb')
def init():
"""
初始化 RAG 系统绑定核心函数到 ServerEnv
"""
try:
logger.info("初始化 RAG 系统")
g = ServerEnv()
# 绑定核心函数为异步版本
g.embed = awaitify(embed)
g.query = awaitify(search_query)
g.delete = awaitify(delete_document)
g.get_module_dbname = get_module_dbname
logger.info("RAG 系统初始化完成")
return g
except Exception as e:
logger.error(f"初始化失败: {str(e)}")
raise
if __name__ == '__main__':
logger.info("启动 RAG Web 服务器")
webapp(init)

539
build/lib/rag/vector.py Normal file
View File

@ -0,0 +1,539 @@
import os
import uuid
import json
import yaml
from datetime import datetime
from typing import List, Dict, Optional
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
import torch
import logging
import time
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def ensure_milvus_directory() -> None:
"""确保 Milvus 数据库目录存在"""
db_dir = os.path.dirname(MILVUS_DB_PATH)
if not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
logger.debug(f"创建 Milvus 目录: {db_dir}")
if not os.access(db_dir, os.W_OK):
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
def initialize_milvus_connection() -> None:
"""初始化 Milvus 连接,确保单一连接"""
try:
if not connections.has_connection("default"):
connections.connect("default", uri=MILVUS_DB_PATH)
logger.debug(f"已连接到 Milvus Lite路径: {MILVUS_DB_PATH}")
else:
logger.debug("已存在 Milvus 连接,跳过重复连接")
except Exception as e:
logger.error(f"连接 Milvus 失败: {str(e)}")
raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
def cleanup_milvus_connection() -> None:
"""清理 Milvus 连接,确保资源释放"""
try:
if connections.has_connection("default"):
connections.disconnect("default")
logger.debug("已断开 Milvus 连接")
time.sleep(3)
except Exception as e:
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
"""
初始化或访问 Milvus Lite 向量数据库集合,按 db_type 组织,利用 userid 区分用户、document_id 区分文档,并插入文档
"""
try:
# 参数验证
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 和 db_type 的长度应小于 100")
if not documents or not all(isinstance(doc, Document) for doc in documents):
raise ValueError("documents 不能为空且必须是 Document 对象列表")
required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
for doc in documents:
if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}")
if doc.metadata["userid"] != userid:
raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致")
ensure_milvus_directory()
initialize_milvus_connection()
# 初始化嵌入模型
model_path = TEXT_EMBEDDING_MODEL
if not os.path.exists(model_path):
raise ValueError(f"模型路径 {model_path} 不存在")
embedding = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"加载模型失败: {str(e)}")
# 集合名称
collection_name = f"ragdb_{db_type}"
if len(collection_name) > 255:
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
logger.debug(f"集合名称: {collection_name}")
# 定义 schema包含所有固定字段
fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
]
schema = CollectionSchema(
fields=fields,
description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段",
auto_id=True,
primary_field="pk",
)
# 检查集合是否存在
if utility.has_collection(collection_name):
try:
collection = Collection(collection_name)
existing_schema = collection.schema
expected_fields = {f.name for f in fields}
actual_fields = {f.name for f in existing_schema.fields}
vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None)
schema_compatible = False
if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR:
dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None
schema_compatible = dim == 1024
logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, "
f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else ''}, "
f"dim={dim if dim is not None else '未定义'}")
if not schema_compatible:
logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: "
f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or ''}, "
f"vector_field: {vector_field is not None}, "
f"dtype: {vector_field.dtype if vector_field else ''}, "
f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}")
utility.drop_collection(collection_name)
else:
collection.load()
logger.debug(f"集合 {collection_name} 已存在并加载成功")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 创建新集合
if not utility.has_collection(collection_name):
try:
collection = Collection(collection_name, schema)
collection.create_index(
field_name="vector",
index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
)
collection.create_index(
field_name="userid",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="document_id",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="filename",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="file_path",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="upload_time",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="file_type",
index_params={"index_type": "INVERTED"}
)
collection.load()
logger.debug(f"成功创建并加载集合: {collection_name}")
except Exception as e:
logger.error(f"创建集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"创建集合失败: {str(e)}")
# 初始化 Milvus 向量存储
try:
vector_store = Milvus(
embedding_function=embedding,
collection_name=collection_name,
connection_args={"uri": MILVUS_DB_PATH},
drop_old=False,
auto_id=True,
primary_field="pk",
)
logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}")
except Exception as e:
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
raise RuntimeError(f"初始化向量存储失败: {str(e)}")
# 插入文档
try:
logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}")
for doc in documents:
logger.debug(f"插入文档元数据: {doc.metadata}")
vector_store.add_documents(documents=documents)
logger.debug(f"成功插入 {len(documents)} 个文档")
# 立即查询验证
collection = Collection(collection_name)
collection.load()
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=10
)
for result in results:
logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, "
f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
except Exception as e:
logger.error(f"插入文档失败: {str(e)}")
raise RuntimeError(f"插入文档失败: {str(e)}")
return vector_store
except Exception as e:
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
raise
finally:
cleanup_milvus_connection()
def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
"""
获取指定 userid db_type 下的 document_id 与元数据的映射
"""
try:
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return {}
collection = Collection(collection_name)
collection.load()
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=100
)
mapping = {}
for result in results:
doc_id = result.get("document_id")
if doc_id:
mapping[doc_id] = {
"userid": result.get("userid", ""),
"filename": result.get("filename", ""),
"file_path": result.get("file_path", ""),
"upload_time": result.get("upload_time", ""),
"file_type": result.get("file_type", "")
}
logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")
logger.debug(f"找到 {len(mapping)} 个文档的映射")
return mapping
except Exception as e:
logger.error(f"获取文档映射失败: {str(e)}")
raise RuntimeError(f"获取文档映射失败: {str(e)}")
def list_user_collections() -> Dict[str, Dict]:
"""
列出所有数据库类型db_type及其包含的用户userid与对应的文档document_id映射
"""
try:
ensure_milvus_directory()
initialize_milvus_connection()
collections = utility.list_collections()
db_types_with_data = {}
for col in collections:
if col.startswith("ragdb_"):
db_type = col[len("ragdb_"):]
logger.debug(f"处理集合: {col} (db_type: {db_type})")
collection = Collection(col)
collection.load()
batch_size = 1000
offset = 0
user_document_map = {}
while True:
try:
results = collection.query(
expr="",
output_fields=["userid", "document_id"],
limit=batch_size,
offset=offset
)
if not results:
break
for result in results:
userid = result.get("userid")
doc_id = result.get("document_id")
if userid and doc_id:
if userid not in user_document_map:
user_document_map[userid] = set()
user_document_map[userid].add(doc_id)
offset += batch_size
except Exception as e:
logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}")
break
# 转换为列表以便序列化
user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}")
db_types_with_data[db_type] = {
"userids": user_document_map
}
logger.debug(f"可用 db_types 和数据: {db_types_with_data}")
return db_types_with_data
except Exception as e:
logger.error(f"列出集合和用户失败: {str(e)}")
raise
def view_collection_details(userid: str) -> None:
"""
查看特定 userid 在所有集合中的内容和容量包含 document_id 和元数据
"""
try:
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
logger.debug(f"正在查看 userid {userid} 的集合")
ensure_milvus_directory()
initialize_milvus_connection()
collections = utility.list_collections()
db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]
if not db_types:
logger.debug(f"未找到任何集合")
return
for db_type in db_types:
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
continue
collection = Collection(collection_name)
collection.load()
try:
all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=10000)
num_entities = len(all_pks)
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["userid","text", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=10
)
logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}")
if num_entities == 0:
logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档")
continue
logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:")
for idx, doc in enumerate(results, 1):
metadata = {
"userid": doc.get("userid", ""),
"filename": doc.get("filename", ""),
"file_path": doc.get("file_path", ""),
"upload_time": doc.get("upload_time", ""),
"file_type": doc.get("file_type", "")
}
logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}")
except Exception as e:
logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}")
continue
except Exception as e:
logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}")
raise
def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
"""
查看指定 db_type 中的向量数据可选按 userid document_id 过滤包含完整元数据和向量
"""
try:
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if userid and "_" in userid:
raise ValueError("userid 不能包含下划线")
if document_id and "_" in document_id:
raise ValueError("document_id 不能包含下划线")
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return {}
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
expr = []
if userid:
expr.append(f"userid == '{userid}'")
if document_id:
expr.append(f"document_id == '{document_id}'")
expr = " && ".join(expr) if expr else ""
results = collection.query(
expr=expr,
output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
limit=limit
)
vector_data = {}
for doc in results:
pk = doc.get("pk", str(uuid.uuid4()))
text = doc.get("text", "")
doc_id = doc.get("document_id", "")
vector = doc.get("vector", [])
metadata = {
"filename": doc.get("filename", ""),
"file_path": doc.get("file_path", ""),
"upload_time": doc.get("upload_time", ""),
"file_type": doc.get("file_type", "")
}
vector_data[pk] = {
"text": text,
"vector": vector,
"document_id": doc_id,
"metadata": metadata
}
logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")
logger.debug(f"共找到 {len(vector_data)} 条向量数据")
return vector_data
except Exception as e:
logger.error(f"查看向量数据失败: {str(e)}")
raise
def main():
userid = "testuser1"
db_type = "textdb"
# logger.info("\n测试 1带文档初始化")
# documents = [
# Document(
# page_content="深度学习是基于深层神经网络的机器学习子集。",
# metadata={
# "userid": userid,
# "document_id": str(uuid.uuid4()),
# "filename": "test_doc1.txt",
# "file_path": "/path/to/test_doc1.txt",
# "upload_time": datetime.now().isoformat(),
# "file_type": "txt"
# }
# ),
# Document(
# page_content="知识图谱是一个结构化的语义知识库。",
# metadata={
# "userid": userid,
# "document_id": str(uuid.uuid4()),
# "filename": "test_doc2.txt",
# "file_path": "/path/to/test_doc2.txt",
# "upload_time": datetime.now().isoformat(),
# "file_type": "txt"
# }
# ),
# ]
#
# try:
# vector_store = get_vector_db(userid, db_type, documents=documents)
# logger.info(f"集合: ragdb_{db_type}")
# logger.info(f"成功为 userid {userid} 在 {db_type} 中插入文档")
# except Exception as e:
# logger.error(f"失败: {str(e)}")
logger.info("\n测试 2列出所有 db_types 和文档映射")
try:
db_types = list_user_collections()
logger.info(f"可用 db_types 和文档: {db_types}")
except Exception as e:
logger.error(f"失败: {str(e)}")
logger.info(f"\n测试 3查看 userid {userid} 的所有集合")
try:
view_collection_details(userid)
except Exception as e:
logger.error(f"失败: {str(e)}")
# logger.info(f"\n测试 4查看向量数据")
# try:
# vector_data = view_vector_data(db_type, userid=userid)
# logger.info(f"向量数据: {vector_data}")
# except Exception as e:
# logger.error(f"失败: {str(e)}")
logger.info(f"\n测试 5获取 userid {userid}{db_type}数据库的文档映射")
try:
mapping = get_document_mapping(userid, db_type)
logger.info(f"文档映射: {mapping}")
except Exception as e:
logger.error(f"失败: {str(e)}")
if __name__ == "__main__":
main()

1
build/lib/rag/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = '0.0.1'

BIN
conf/Milvus/milvus.db Normal file

Binary file not shown.

60
conf/config.json Executable file
View File

@ -0,0 +1,60 @@
{
"password_key":"!@#$%^&*(*&^%$QWERTYUIqwertyui234567",
"logger":{
"name":"rag",
"levelname":"clientinfo",
"logfile":"$[workdir]$/logs/rag.log"
},
"kdbpath":"$[workdir]$/chroma",
"filesroot":"$[workdir]$/files",
"website":{
"paths":[
["$[workdir]$/wwwroot",""]
],
"client_max_size":20000,
"host":"0.0.0.0",
"port":9190,
"coding":"utf-8",
"indexes":[
"index.html",
"index.tmpl",
"index.ui",
"index.dspy",
"index.md"
],
"startswiths":[
{
"leading":"/idfile",
"registerfunction":"idFileDownload"
}
],
"processors":[
[".ws","ws"],
[".xterm","xterm"],
[".proxy","proxy"],
[".llm", "llm"],
[".llms", "llms"],
[".llma", "llma"],
[".xlsxds","xlsxds"],
[".sqlds","sqlds"],
[".tmpl.js","tmpl"],
[".tmpl.css","tmpl"],
[".html.tmpl","tmpl"],
[".bcrud", "bricks_crud"],
[".tmpl","tmpl"],
[".app","app"],
[".bui","bui"],
[".ui","bui"],
[".dspy","dspy"],
[".md","md"]
],
"session_max_time":3000,
"session_issue_time":2500
},
"langMapping":{
"zh-Hans-CN":"zh-cn",
"zh-CN":"zh-cn",
"en-us":"en",
"en-US":"en"
}
}

17766
conf/logs/milvus.log Normal file

File diff suppressed because it is too large

8
conf/milvusconfig.yaml Normal file
View File

@ -0,0 +1,8 @@
database:
milvus_db_path: /share/wangmeihua/rag/conf/Milvus/milvus.db
models:
text_embedding_model: /share/models/BAAI/bge-m3
logging:
name: rag
level: DEBUG
file: /share/wangmeihua/rag/conf/logs/milvus.log
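All of the rag modules locate this file through the CONFIG_PATH environment variable, falling back to the hard-coded /share/wangmeihua/rag path, so a deployment can repoint it before those modules are imported; the path below is a placeholder:

```python
import os
os.environ['CONFIG_PATH'] = '/home/rag/rag/conf/milvusconfig.yaml'  # placeholder path
```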

BIN
data/jishu.pdf Normal file

Binary file not shown.

599
data/kg_introduction.txt Normal file

File diff suppressed because one or more lines are too long

BIN
data/qianru.pdf Normal file

Binary file not shown.

BIN
data/test.docx Normal file

Binary file not shown.

1
data/test.txt Normal file
View File

@ -0,0 +1 @@
开元云北京科技有限公司是一家注册于2020年的高科技企业在上海、南京、深圳、济南等地设有分支机构创始团队核心成员来自一流的云计算公司及电信运营商拥有云计算、超算、智算和网络运营专业经验在企业市场均拥有超过十年以上行业经验服务客户超过2万家。公司以自主研发的业务操作支撑系统KBoss为底座打造开放算力应用服务平台open-computing将云计算、算力资源和算力应用进行整合为高校、科研、大模型、AI等政企客户提供专业算力云服务形成“云+网+算+应用”的一体化解决方案。在2021年我们荣幸地成为阿里云计算的合作伙伴致力于提供算力应用、算力网络、算网一体的产品和服务同时为芯片、教育科研等企业提供优质的算力服务。2022年我们与国家超级计算济南中心以及中信网络有限公司签署了战略合作协议并成功推出了“Kboss”算网平台。在2023年我们的平台进一步发展成功引入火山引擎、百度智能云。目前我们已成为阿里云、江苏未来网络集团的战略合作伙伴。同时我们深耕“算力+教育”赛道持续推进高校算力平台项目积极建设学校算力网络节点目前已经成功开拓了27所高校。公司提供新一代算力云应用服务模式通过自主研发的开元算力云应用服务平台整合算力资源和算法应用利用创新算力调度化和确定性网络技术针对现代社会对智能化和数字化需求形成包括算力云服务、算力网络和算力应用的全场景解决方案。旨在为政府和企业提供"技术+资源+场景+运营”的产业互联网算力云应用服务平台实现以算力云服务推动数字经济的发展。开元云科技自成立以来得到了包括工信部、教育部、全国高校学会、国家超算中心以及南京未来网络研究院等政府机构、科研机构的大力支持合作领域包括“东数西算、大科学计算、存算分离、芯算一体及国产工业软件SaaS化”覆盖人工智能、芯片仿真、生物制药、工业仿真、材料研发、精尖制造、海洋勘探以及气象监测等高科技领域。

BIN
data/zongshu.pdf Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
data/聚类结果1.xlsx Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

21
json/kdb.json Normal file
View File

@ -0,0 +1,21 @@
{
"tblname": "permission",
"uitype":"tree",
"title":"知识库",
"params":{
"idField":"id",
"textField":"name",
"sortby":"name",
"editable":true,
"browserfields":{
"alters":{}
},
"edit_exclouded_fields":[],
"parentField":"parentid",
"toolbar":{
},
"binds":[
]
}
}

32781
logs/milvus.log Normal file

File diff suppressed because it is too large

14191
logs/rag.log Normal file

File diff suppressed because it is too large

0
logs/stderr.log Normal file
View File

BIN
models/doc.xlsx Normal file

Binary file not shown.

BIN
models/kdb.xlsx Normal file

Binary file not shown.

4
pyproject.toml Normal file
View File

@ -0,0 +1,4 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

129
rag.egg-info/PKG-INFO Normal file
View File

@ -0,0 +1,129 @@
Metadata-Version: 2.4
Name: rag
Version: 0.0.1
Summary: rag
Home-page: https://github.com/yumoqing/rag
Author: yumoqing
Author-email: yumoqing@gmail.com
Platform: any
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Description-Content-Type: text/markdown
Requires-Dist: chromadb
Requires-Dist: langchain
Requires-Dist: langchain_community
Requires-Dist: unstructured
Requires-Dist: langchain-text-splitters
Requires-Dist: unstructured[all-docs]
Requires-Dist: langchain_milvus
Requires-Dist: langchain_huggingface
Requires-Dist: transformers
Requires-Dist: openai
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: pymilvus
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: platform
Dynamic: requires-dist
Dynamic: summary
# Knowledge Base Server
This system provides self-managed knowledge bases for different clients and offers knowledge retrieval on top of them.
It exposes an API through which registered servers consume knowledge services; it is not aimed at end users.
## Dependencies
Depends on [these modules](requirements.txt).
## Installation and Deployment
1. Create a rag user.
2. Log in as the rag user.
3. Run the following commands:
```
git clone git@git.kaiyuancloud.cn:yumoqing/rag
cd rag/script
./install.sh
```
This checks the project out under the user's home directory.
## Features
Manages the customer knowledge bases of client systems and provides knowledge queries.
Each client can create one or more independent knowledge bases to serve different business scenarios.
Knowledge bases are mutually independent; their data does not interfere with one another.
## HTTP API
### add
Add a document to a knowledge base.
#### path
/api/add
#### method
POST
#### Input
name: authentication
value: Bearer ${apikey}
scope: headers
name: file_name
value: path of the uploaded file
scope: data
name: userid
value: userid of the client system
scope: data
name: kdbname
value: rag kdb name
scope: data
#### Output
### query
Query a knowledge base.
#### path
/api/query
#### method
POST
#### Input
name: authentication
value: Bearer ${apikey}
scope: headers
name: prompt
value: ${prompt}
scope: data
name: userid
value: ${userid}
scope: data
name: kdbname
value: ${kdbname}
scope: data
#### Output
```
{
total: number of returned records,
rows: the returned records
}
Each row has the following attributes:
content: the text content
distances: the distance score
source: the document path
```

16
rag.egg-info/SOURCES.txt Normal file
View File

@ -0,0 +1,16 @@
README.md
setup.py
rag/__init__.py
rag/deletefile.py
rag/embed.py
rag/init.py
rag/kdb.py
rag/query.py
rag/rag.bak.py
rag/vector.py
rag/version.py
rag.egg-info/PKG-INFO
rag.egg-info/SOURCES.txt
rag.egg-info/dependency_links.txt
rag.egg-info/requires.txt
rag.egg-info/top_level.txt

1
rag.egg-info/dependency_links.txt Normal file
View File

@ -0,0 +1 @@

13
rag.egg-info/requires.txt Normal file
View File

@ -0,0 +1,13 @@
chromadb
langchain
langchain_community
unstructured
langchain-text-splitters
unstructured[all-docs]
langchain_milvus
langchain_huggingface
transformers
openai
torch
torchvision
pymilvus

1
rag.egg-info/top_level.txt Normal file
View File

@ -0,0 +1 @@
rag

2
rag/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .version import __version__

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

290
rag/allfusedsearch.py Normal file
View File

@ -0,0 +1,290 @@
import os
import logging
import yaml
import numpy as np
from typing import List, Dict, Any
from pymilvus import Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from vector import initialize_milvus_connection
from searchquery import extract_entities, match_triplets
from rerank import rerank_results
import torch
import time
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
try:
with open(config['logging']['file'], 'a', encoding='utf-8') as f:
pass
except Exception as e:
raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}")
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
# 初始化嵌入模型
embedding = HuggingFaceEmbeddings(
model_name=TEXT_EMBEDDING_MODEL,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug("嵌入模型加载成功")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
# 缓存三元组
TRIPLET_CACHE = {}
def load_triplets_to_cache(userid: str, document_id: str) -> List[Dict]:
"""加载三元组到缓存"""
cache_key = f"{document_id}_{userid}"
if cache_key in TRIPLET_CACHE:
logger.debug(f"从缓存加载三元组: {cache_key}")
return TRIPLET_CACHE[cache_key]
triplet_file = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
triplets = []
try:
with open(triplet_file, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split('\t')
if len(parts) < 3:
continue
head, type_, tail = parts[:3]
triplets.append({'head': head, 'type': type_, 'tail': tail})
TRIPLET_CACHE[cache_key] = triplets
logger.debug(f"加载三元组文件: {triplet_file}, 数量: {len(triplets)}")
return triplets
except Exception as e:
logger.error(f"加载三元组失败: {triplet_file}, 错误: {str(e)}")
return []
def fused_search(
query: str,
userid: str,
db_type: str,
file_paths: List[str],
limit: int = 5,
offset: int = 0,
use_rerank: bool = True
) -> List[Dict[str, Any]]:
"""
融合 RAG 和三元组召回文本块
- 收集所有输入文件的三元组拼接为融合文本向量化后在所有文件中搜索
- 结果去重并按 rerank_score distance 排序重排序使用融合文本
参数:
query (str): 查询文本
userid (str): 用户 ID
db_type (str): 数据库类型 (e.g., 'textdb')
file_paths (List[str]): 文件路径列表
limit (int): 返回结果数量
offset (int): 偏移量
use_rerank (bool): 是否使用重排序
返回:
List[Dict[str, Any]]: 召回结果包含 textdistancesourcemetadatarerank_score
"""
try:
logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}")
start_time = time.time()
# 参数验证
if not query or not userid or not db_type or not file_paths:
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
# 初始化 Milvus 连接
connections = initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载 Milvus 集合: {collection_name}")
# 提取实体
entity_start = time.time()
query_entities = extract_entities(query)
logger.debug(f"提取实体: {query_entities}, 耗时: {time.time() - entity_start:.3f}s")
# 收集所有文件的 document_id 和三元组
doc_id_map = {}
filenames = []
all_triplets = []
for file_path in file_paths:
filename = os.path.basename(file_path)
filenames.append(filename)
logger.debug(f"处理文件: {filename}")
# 获取 document_id
results_query = collection.query(
expr=f"userid == '{userid}' and filename == '{filename}'",
output_fields=["document_id"],
limit=1
)
if not results_query:
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
continue
document_id = results_query[0]["document_id"]
doc_id_map[filename] = document_id
load_triplets_to_cache(userid, document_id)
# 获取匹配的三元组
triplet_start = time.time()
matched_triplets = match_triplets(query, query_entities, userid, document_id)
logger.debug(
f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条, 耗时: {time.time() - triplet_start:.3f}s")
all_triplets.extend(matched_triplets)
if not doc_id_map:
logger.warning("未找到任何有效文档")
return []
# 拼接融合文本
triplet_texts = []
for triplet in all_triplets:
head = triplet['head']
type_ = triplet['type']
tail = triplet['tail']
if not head or not type_ or not tail:
logger.debug(f"无效三元组: {triplet}")
continue
triplet_texts.append(f"{head} {type_} {tail}")
# 定义融合文本
fused_text = query if not triplet_texts else f"{query} {' '.join(triplet_texts)}"
logger.debug(f"融合文本: {fused_text}, 三元组数量: {len(triplet_texts)}")
# 向量化
embed_start = time.time()
query_vector = embedding.embed_query(fused_text)
query_vector = np.array(query_vector) / np.linalg.norm(query_vector)
logger.debug(f"生成融合向量,维度: {len(query_vector)}, 耗时: {time.time() - embed_start:.3f}s")
# Milvus 搜索
expr = f"userid == '{userid}' and filename in {filenames}"
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
milvus_start = time.time()
milvus_results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=100,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
logger.debug(f"Milvus 搜索耗时: {time.time() - milvus_start:.3f}s")
results = []
for hits in milvus_results:
for hit in hits:
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"source": "fused_query" if not triplet_texts else f"fused_triplets_{len(triplet_texts)}",
"metadata": {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
}
results.append(result)
logger.debug(
f"召回: text={result['text'][:100]}..., distance={result['distance']}, filename={result['metadata']['filename']}")
# 去重
unique_results = []
seen_texts = set()
for result in results:
text = result['text']
if not text:
logger.warning(f"发现空文本结果: {result['metadata']}")
continue
if text in seen_texts:
logger.debug(f"移除重复文本: text={text[:100]}..., filename={result['metadata']['filename']}")
continue
seen_texts.add(text)
unique_results.append(result)
logger.info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(results)})")
# 可选:重排序
if use_rerank and unique_results:
logger.debug("开始重排序")
logger.debug(f"重排序查询: {fused_text}")
rerank_start = time.time()
reranked_results = rerank_results(fused_text, unique_results)
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
logger.debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
logger.debug(f"重排序耗时: {time.time() - rerank_start:.3f}s")
for i, result in enumerate(reranked_results):
logger.debug(
f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result.get('rerank_score', 'N/A')}")
logger.info(f"总耗时: {time.time() - start_time:.3f}s")
return reranked_results[:limit]
# 按 distance 降序排序
sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True)
for i, result in enumerate(sorted_results):
logger.debug(f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}")
logger.info(f"总耗时: {time.time() - start_time:.3f}s")
return sorted_results[:limit]
except Exception as e:
logger.error(f"融合搜索失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return []
if __name__ == "__main__":
query = "什么是知识抽取?"
userid = "testuser1"
db_type = "textdb"
file_paths = [
"/share/wangmeihua/rag/data/test.docx",
"/share/wangmeihua/rag/data/zongshu.pdf",
"/share/wangmeihua/rag/data/qianru.pdf",
]
try:
results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0)
for i, result in enumerate(results):
print(f"Result {i + 1}:")
print(f"Text: {result['text'][:200]}...")
print(f"Distance: {result['distance']:.3f}")
rerank_score = result.get('rerank_score')
print(f"Rerank Score: {rerank_score:.3f}" if isinstance(rerank_score, (int, float)) else "Rerank Score: N/A")
print(f"Source: {result['source']}")
print(f"Metadata: {result['metadata']}\n")
except Exception as e:
print(f"搜索失败: {str(e)}")

190
rag/combinedsearch.py Normal file
View File

@ -0,0 +1,190 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus import connections, Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from query import search_query
from searchquery import searchquery
from rerank import rerank_results
from vector import initialize_milvus_connection, cleanup_milvus_connection
import torch
from functools import lru_cache
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear() # 清除现有处理器
logger.propagate = False # 禁用传播
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
# 初始化嵌入模型(缓存)
@lru_cache(maxsize=1000)
def get_embedding(text: str) -> List[float]:
embedding = HuggingFaceEmbeddings(
model_name=TEXT_EMBEDDING_MODEL,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
vector = embedding.embed_query(text)
if len(vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(vector)} 不匹配预期 1024")
return vector
def combined_search(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 10, offset: int = 0) -> List[Dict]:
"""
结合 RAG 和三元组检索召回相关文本块使用 BGE Reranker 重排序
参数:
query (str): 查询文本
userid (str): 用户ID
db_type (str): 数据库类型
file_paths (List[str]): 文档路径列表
limit (int): 返回的最大结果数默认为 10
offset (int): 偏移量默认为 0
返回:
List[Dict]: 包含 textdistancesourcemetadata rerank_score 的结果列表
"""
try:
# 参数验证
if not query or not userid or not db_type or not file_paths:
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 或 db_type 的长度超出限制")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if offset < 0:
raise ValueError("offset 不能为负数")
if limit + offset > 16384:
raise ValueError("limit + offset 不能超过 16384")
for file_path in file_paths:
if not isinstance(file_path, str):
raise ValueError(f"file_path 必须是字符串: {file_path}")
if len(os.path.basename(file_path)) > 255:
raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
# 初始化 Milvus 连接
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
# RAG 检索,使用 search_query 的默认 limit=5
rag_results = search_query(query, userid, db_type, file_paths, offset=offset)
for result in rag_results:
result['source'] = 'rag'
logger.info(f"RAG 检索返回 {len(rag_results)} 条结果")
# 三元组检索,使用 searchquery 的默认 limit=5
triplet_results = searchquery(query, userid, db_type, file_paths, offset=offset)
for result in triplet_results:
result['source'] = 'triplet'
logger.info(f"三元组检索返回 {len(triplet_results)} 条结果")
# 记录三元组检索结果详情
for idx, result in enumerate(triplet_results, 1):
logger.debug(f"三元组结果 {idx}: text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}")
# 合并结果
all_results = rag_results + triplet_results
if not all_results:
logger.warning("RAG 和三元组检索均无结果")
return []
# 记录合并前的结果
logger.debug("合并前结果:")
for idx, result in enumerate(all_results, 1):
logger.debug(f"结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}")
# 使用 BGE Reranker 重排序
reranked_results = rerank_results(query, all_results, top_k=len(all_results))
# 按 rerank_score 排序(不去重)
sorted_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True)
# 记录排序后的结果
logger.debug("重排序后结果:")
for idx, result in enumerate(sorted_results, 1):
logger.debug(f"排序结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}")
# 去重(基于 text保留 rerank_score 最大的记录)
unique_results = []
text_to_result = {}
for result in sorted_results:
text = result['text']
if text not in text_to_result or result['rerank_score'] > text_to_result[text]['rerank_score']:
text_to_result[text] = result
unique_results = list(text_to_result.values())
# 记录去重后的结果
logger.debug("去重后结果:")
for idx, result in enumerate(unique_results, 1):
logger.debug(f"去重结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}")
# 限制结果数量
final_results = unique_results[:limit]
logger.info(f"合并后返回 {len(final_results)} 条唯一结果")
# 移除 weighted_score 字段(若存在),保留 rerank_score 和 source
for result in final_results:
result.pop('weighted_score', None)
return final_results
except Exception as e:
logger.error(f"合并搜索失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return []
finally:
cleanup_milvus_connection()
if __name__ == "__main__":
# 测试代码
query = "知识图谱构建需要什么技术?"
userid = "testuser1"
db_type = "textdb"
file_paths = [
"/share/wangmeihua/rag/data/test.docx",
"/share/wangmeihua/rag/data/zongshu.pdf",
"/share/wangmeihua/rag/data/qianru.pdf"
]
limit = 10
offset = 0
try:
results = combined_search(query, userid, db_type, file_paths, limit, offset)
print(f"搜索结果 ({len(results)} 条):")
for idx, result in enumerate(results, 1):
print(f"结果 {idx}:")
print(f"内容: {result['text'][:200]}...")
print(f"距离: {result['distance']}")
print(f"来源: {result['source']}")
print(f"重排序分数: {result['rerank_score']}")
print(f"元数据: {result['metadata']}")
print("-" * 50)
except Exception as e:
print(f"搜索失败: {str(e)}")

138
rag/deletefile.py Normal file
View File

@ -0,0 +1,138 @@
import logging
import yaml
import os
from pymilvus import connections, Collection, utility
from vector import initialize_milvus_connection
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def delete_document(db_type: str, userid: str, filename: str) -> bool:
"""
根据 db_typeuserid filename 删除用户的指定文件数据
参数:
db_type (str): 数据库类型 'textdb', 'pptdb'
userid (str): 用户 ID
filename (str): 文件名 'test.docx'
返回:
bool: 删除是否成功
说明:
参数无效或数据库操作失败时,在函数内部捕获异常、记录日志并返回 False,不向调用方抛出
"""
try:
# 参数验证
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
if not filename:
raise ValueError("filename 不能为空")
if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255:
raise ValueError("db_type、userid 或 filename 的长度超出限制")
# 初始化 Milvus 连接
initialize_milvus_connection()
logger.debug(f"已连接到 Milvus Lite路径: {MILVUS_DB_PATH}")
# 检查集合是否存在
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return False
# 加载集合
try:
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 查询匹配的 document_id
expr = f"userid == '{userid}' and filename == '{filename}'"
logger.debug(f"查询表达式: {expr}")
try:
results = collection.query(
expr=expr,
output_fields=["document_id"],
limit=1000
)
if not results:
logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录")
return False
document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
logger.debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
except Exception as e:
logger.error(f"查询 document_id 失败: {str(e)}")
raise RuntimeError(f"查询失败: {str(e)}")
# 执行删除
total_deleted = 0
for doc_id in document_ids:
try:
delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'"
logger.debug(f"删除表达式: {delete_expr}")
delete_result = collection.delete(delete_expr)
deleted_count = delete_result.delete_count
total_deleted += deleted_count
logger.info(f"成功删除 document_id={doc_id}{deleted_count} 条记录")
except Exception as e:
logger.error(f"删除 document_id={doc_id} 失败: {str(e)}")
continue
if total_deleted == 0:
logger.warning(f"没有删除任何记录userid={userid}, filename={filename}")
return False
logger.info(f"总计删除 {total_deleted} 条记录userid={userid}, filename={filename}")
return True
except ValueError as ve:
logger.error(f"参数验证失败: {str(ve)}")
return False
except RuntimeError as re:
logger.error(f"数据库操作失败: {str(re)}")
return False
except Exception as e:
logger.error(f"删除文件失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
finally:
try:
connections.disconnect("default")
logger.debug("已断开 Milvus 连接")
except Exception as e:
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
if __name__ == "__main__":
# 测试用例
db_type = "textdb"
userid = "testuser2"
filename = "test.docx"
logger.info(f"测试:删除 userid={userid}, filename={filename} 的文件")
result = delete_document(db_type, userid, filename)
print(f"删除结果: {result}")

1
rag/dict/cel.txt Normal file
View File

@ -0,0 +1 @@
DB2RDF c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

33
rag/dict/concept.txt Normal file
View File

@ -0,0 +1,33 @@
上下位关系 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
串联 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
共指消解 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
关系抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
分类研究 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
外部知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体分类体系 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体识别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
属性 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
属性抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
总结 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
拼图碎片 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
搜索引擎 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层的融合 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据挖掘 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
服务器日志 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
模式匹配 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
模式层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
歧义 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
病毒 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
症状 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识库的更新 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
算法 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
自动化本体构建过程 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
自顶向下 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
资源描述框架 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里巴巴 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
非结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

1
rag/dict/date.txt Normal file
View File

@ -0,0 +1 @@
结构化的历史数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

1
rag/dict/eve.txt Normal file
View File

@ -0,0 +1 @@
(Sri) c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

1
rag/dict/loc.txt Normal file
View File

@ -0,0 +1 @@
城市 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

20
rag/dict/media.txt Normal file
View File

@ -0,0 +1,20 @@
5 信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
万维网 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息抽取 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
关系数据库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
埃博拉病毒的症状有哪些 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体消歧 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
推理策略的一环 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
数据驱动的自动化方式 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
智能语义搜索 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
概念层 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识库 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
结构化数据源 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
统计机器学习 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
语料 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里巴巴 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
面向开放域的实体识别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

5
rag/dict/misc.txt Normal file
View File

@ -0,0 +1,5 @@
112种实体类别 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
三元组 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
信息检索 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
实体 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
知识图谱 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

7
rag/dict/org.txt Normal file
View File

@ -0,0 +1,7 @@
XML c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
微软 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
搜索引擎 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
本体构建本体 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
腾讯 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
谷歌 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
阿里 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

1
rag/dict/per.txt Normal file
View File

@ -0,0 +1 @@
比尔盖茨 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

1
rag/dict/time.txt Normal file
View File

@ -0,0 +1 @@
网的 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

3
rag/dict/unk.txt Normal file
View File

@ -0,0 +1,3 @@
Relation Extraction c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
Web c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
的知识 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac

183
rag/embed.py Normal file
View File

@ -0,0 +1,183 @@
import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from vector import get_vector_db
from filetxt.loader import fileloader
from extract import extract_and_save_triplets
from kgc import KnowledgeGraph
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
logger.error(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
def generate_document_id() -> str:
"""为文件生成唯一的 document_id"""
return str(uuid.uuid4())
def load_and_split_data(file_path: str, userid: str, document_id: str) -> List[Document]:
"""
加载文件分片并生成带有元数据的 Document 对象
"""
try:
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
if os.path.getsize(file_path) == 0:
raise ValueError(f"文件 {file_path} 为空")
logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
ext = file_path.rsplit('.', 1)[1].lower()
logger.debug(f"文件扩展名: {ext}")
logger.debug("开始加载文件")
text = fileloader(file_path)
if not text or not text.strip():
raise ValueError(f"文件 {file_path} 加载为空")
document = Document(page_content=text)
logger.debug(f"加载完成,生成 1 个文档")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000,
chunk_overlap=200,
length_function=len,
)
chunks = text_splitter.split_documents([document])
logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")
filename = os.path.basename(file_path)
upload_time = datetime.now().isoformat()
documents = []
for i, chunk in enumerate(chunks):
chunk.metadata.update({
'userid': userid,
'document_id': document_id,
'filename': filename,
'file_path': file_path,
'upload_time': upload_time,
'file_type': ext,
'chunk_index': i,
'source': file_path,
})
required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
documents.append(chunk)
logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")
logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块document_id: {document_id}")
return documents
except Exception as e:
logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
raise ValueError(f"加载或分割文件失败: {str(e)}")
def embed(file_path: str, userid: str, db_type: str) -> bool:
"""
嵌入文件到 Milvus 向量数据库抽取三元组保存到指定路径并将三元组存储到 Neo4j
"""
try:
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid:
raise ValueError("userid 不能包含下划线")
if "_" in db_type:
raise ValueError("db_type 不能包含下划线")
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
ext = file_path.rsplit('.', 1)[1].lower()
if ext not in supported_formats:
logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
document_id = generate_document_id()
logger.info(f"生成 document_id: {document_id} for file: {file_path}")
logger.info(f"开始处理文件 {file_path}userid: {userid}db_type: {db_type}")
chunks = load_and_split_data(file_path, userid, document_id)
if not chunks:
logger.error(f"文件 {file_path} 未生成任何文档块")
raise ValueError("未生成任何文档块")
logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")
db = get_vector_db(userid, db_type, documents=chunks)
if not db:
logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
raise RuntimeError(f"数据库操作失败")
try:
full_text = fileloader(file_path)
if full_text and full_text.strip():
success = extract_and_save_triplets(full_text, document_id, userid)
triplet_file_path = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
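# 注意: 该路径及命名 {document_id}_{userid}.txt 需与 extract.py 中 TRIPLES_OUTPUT_DIR 及其写文件逻辑保持一致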
if success and os.path.exists(triplet_file_path):
logger.info(f"文件 {file_path} 三元组保存到: {triplet_file_path}")
try:
kg = KnowledgeGraph(data_path=triplet_file_path, document_id=document_id)
logger.info(f"Step 1: 导入图谱节点到 Neo4jdocument_id: {document_id}")
kg.create_graphnodes()
logger.info(f"Step 2: 导入图谱边到 Neo4jdocument_id: {document_id}")
kg.create_graphrels()
logger.info(f"Step 3: 导出 Neo4j 节点数据document_id: {document_id}")
kg.export_data()
logger.info(f"文件 {file_path} 三元组成功插入 Neo4j")
except Exception as e:
logger.warning(f"将三元组插入 Neo4j 失败: {str(e)},但不影响 Milvus 嵌入")
else:
logger.warning(f"文件 {file_path} 的三元组抽取失败或文件不存在: {triplet_file_path}")
else:
logger.warning(f"文件 {file_path} 内容为空,无法抽取三元组")
except Exception as e:
logger.error(f"文件 {file_path} 三元组抽取失败: {str(e)},但不影响向量化")
logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
return True
except ValueError as ve:
logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}")
return False
except RuntimeError as re:
logger.error(f"嵌入文件 {file_path} 失败: {str(re)}")
return False
except Exception as e:
logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
if __name__ == "__main__":
test_file = "/share/wangmeihua/rag/data/test.docx"
userid = "testuser1"
db_type = "textdb"
result = embed(test_file, userid, db_type)
print(f"嵌入结果: {result}")

225
rag/extract.py Normal file
View File

@ -0,0 +1,225 @@
import os
import torch
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
import yaml
import time
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
# 三元组保存路径
TRIPLES_OUTPUT_DIR = "/share/wangmeihua/rag/triples"
os.makedirs(TRIPLES_OUTPUT_DIR, exist_ok=True)
# 加载 mREBEL 模型和分词器
local_path = "/share/models/Babelscape/mrebel-large"
try:
tokenizer = AutoTokenizer.from_pretrained(local_path, src_lang="zh_CN", tgt_lang="tp_XX")
model = AutoModelForSeq2SeqLM.from_pretrained(local_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
triplet_id = tokenizer.convert_tokens_to_ids("<triplet>")
logger.debug(f"成功加载 mREBEL 模型,分词器 triplet_id: {triplet_id}")
except Exception as e:
logger.error(f"加载 mREBEL 模型失败: {str(e)}")
raise RuntimeError(f"加载 mREBEL 模型失败: {str(e)}")
# 优化生成参数
gen_kwargs = {
"max_length": 512,
"min_length": 10,
"length_penalty": 0.5,
"num_beams": 3,
"num_return_sequences": 1,
"no_repeat_ngram_size": 2,
"early_stopping": True,
"decoder_start_token_id": triplet_id,
}
def split_document(text: str, max_chunk_size: int = 150) -> list:
"""分割文档为语义完整的块"""
sentences = re.split(r'(?<=[。!?;\n])', text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) <= max_chunk_size:
current_chunk += sentence
else:
if current_chunk:
chunks.append(current_chunk)
current_chunk = sentence
if current_chunk:
chunks.append(current_chunk)
return chunks
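# 切分示例(示意): "知识图谱是语义网络。它由实体和关系组成。" 先按 。!?;\n 切成两句,
# 两句合计约 20 字,小于 max_chunk_size=150,故贪心合并为同一个文本块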
def extract_triplets_typed(text: str) -> list:
"""解析 mREBEL 生成文本,匹配 <triplet> <entity1> <type1> <entity2> <type2> <relation> 格式"""
triplets = []
logger.debug(f"原始生成文本: {text}")
# 分割标记
tokens = []
in_tag = False
buffer = ""
for char in text:
if char == '<':
in_tag = True
if buffer:
tokens.append(buffer.strip())
buffer = ""
buffer += char
elif char == '>':
in_tag = False
buffer += char
tokens.append(buffer.strip())
buffer = ""
else:
buffer += char
if buffer:
tokens.append(buffer.strip())
# 过滤特殊标记
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
tokens = [t for t in tokens if t not in special_tokens and t]
logger.debug(f"处理后标记: {tokens}")
# 解析三元组
i = 0
while i < len(tokens):
if tokens[i] == "<triplet>" and i + 5 < len(tokens):
entity1 = tokens[i + 1]
type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
entity2 = tokens[i + 3]
type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
relation = tokens[i + 5]
if entity1 and type1 and entity2 and type2 and relation:
triplets.append({
'head': entity1.strip(),
'head_type': type1,
'type': relation.strip(),
'tail': entity2.strip(),
'tail_type': type2
})
logger.debug(f"添加三元组: {entity1}({type1}) - {relation} - {entity2}({type2})")
i += 6
else:
i += 1
return triplets
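# 解析示例(示意,按本函数约定的标记顺序构造,并非真实模型输出):
# "<s><triplet> 比尔盖茨 <per> 微软 <org> 创始人 </s>"
# -> {'head': '比尔盖茨', 'head_type': 'per', 'type': '创始人', 'tail': '微软', 'tail_type': 'org'}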
def extract_and_save_triplets(text: str, document_id: str, userid: str) -> bool:
"""
从文本中抽取三元组并保存到指定路径
参数:
text (str): 输入文本
document_id (str): 文档ID
userid (str): 用户ID
返回:
bool: 三元组抽取和保存是否成功
"""
try:
if not text or not document_id or not userid:
raise ValueError("text、document_id 和 userid 不能为空")
if "_" in document_id or "_" in userid:
raise ValueError("document_id 和 userid 不能包含下划线")
start_time = time.time()
logger.info(f"开始抽取文档 {document_id} 的三元组userid: {userid}")
# 分割文本为语义块
text_chunks = split_document(text, max_chunk_size=150)
logger.debug(f"分割为 {len(text_chunks)} 个文本块")
# 处理所有文本块
all_triplets = []
for i, chunk in enumerate(text_chunks):
logger.debug(f"处理块 {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
# 分词
model_inputs = tokenizer(
chunk,
max_length=256,
padding=True,
truncation=True,
return_tensors="pt"
).to(device)
# 生成
try:
generated_tokens = model.generate(
model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
**gen_kwargs,
)
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
for idx, sentence in enumerate(decoded_preds):
logger.debug(f"{i + 1} 生成文本: {sentence}")
triplets = extract_triplets_typed(sentence)
if triplets:
logger.debug(f"{i + 1} 提取到 {len(triplets)} 个三元组")
all_triplets.extend(triplets)
except Exception as e:
logger.warning(f"处理块 {i + 1} 时出错: {str(e)}")
continue
# 去重
unique_triplets = []
seen = set()
for t in all_triplets:
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
if identifier not in seen:
seen.add(identifier)
unique_triplets.append(t)
# 保存结果
output_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
try:
with open(output_file, "w", encoding="utf-8") as f:
for t in unique_triplets:
f.write(f"{t['head']}\t{t['type']}\t{t['tail']}\t{t['head_type']}\t{t['tail_type']}\n")
logger.info(f"文档 {document_id}{len(unique_triplets)} 个三元组已保存到: {output_file}")
except Exception as e:
logger.error(f"保存文档 {document_id} 的三元组失败: {str(e)}")
return False
end_time = time.time()
logger.info(f"文档 {document_id} 三元组抽取完成,耗时: {end_time - start_time:.2f}")
return True
except Exception as e:
logger.error(f"抽取或保存三元组失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
if __name__ == "__main__":
# 测试用例
test_text = "知识图谱是一个结构化的语义知识库。深度学习是基于深层神经网络的机器学习子集。"
document_id = "testdoc123"
userid = "testuser1"
result = extract_and_save_triplets(test_text, document_id, userid)
print(f"抽取结果: {result}")

290
rag/fusedsearch.py Normal file
View File

@ -0,0 +1,290 @@
import os
import logging
import yaml
import numpy as np
from typing import List, Dict, Any
from pymilvus import Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from vector import initialize_milvus_connection
from searchquery import extract_entities, match_triplets
from rerank import rerank_results
import torch
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
try:
with open(config['logging']['file'], 'a', encoding='utf-8') as f:
pass
except Exception as e:
raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}")
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
# 初始化嵌入模型
embedding = HuggingFaceEmbeddings(
model_name=TEXT_EMBEDDING_MODEL,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug("嵌入模型加载成功")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
def fused_search(
query: str,
userid: str,
db_type: str,
file_paths: List[str],
limit: int = 10,
offset: int = 0,
use_rerank: bool = True
) -> List[Dict[str, Any]]:
"""
融合 RAG 和三元组召回文本块
- 调用 searchquery.py extract_entities match_triplets 获取三元组
- 将所有匹配三元组拼接为融合文本向量化后在 Milvus 中搜索
参数:
query (str): 查询文本
userid (str): 用户 ID
db_type (str): 数据库类型 (e.g., 'textdb')
file_paths (List[str]): 文件路径列表
limit (int): 返回结果数量
offset (int): 偏移量
use_rerank (bool): 是否使用重排序
返回:
List[Dict[str, Any]]: 召回结果包含 textdistancemetadata
"""
try:
logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}")
# 参数验证
if not query or not userid or not db_type or not file_paths:
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
# 初始化 Milvus 连接
connections = initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载 Milvus 集合: {collection_name}")
# 提取实体
query_entities = extract_entities(query)
logger.debug(f"提取实体: {query_entities}")
# 收集所有结果
results = []
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
for file_path in file_paths:
filename = os.path.basename(file_path)
logger.debug(f"处理文件: {filename}")
# 获取 document_id
results_query = collection.query(
expr=f"userid == '{userid}' and filename == '{filename}'",
output_fields=["document_id"],
limit=1
)
if not results_query:
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
continue
document_id = results_query[0]["document_id"]
logger.debug(f"找到 document_id: {document_id}")
# 获取匹配的三元组
matched_triplets = match_triplets(query, query_entities, userid, document_id)
logger.debug(f"匹配三元组: {matched_triplets}")
# 若无三元组,使用原查询向量化
if not matched_triplets:
logger.debug(f"无匹配三元组,使用原查询: {query}")
query_vector = embedding.embed_query(query)
expr = f"userid == '{userid}' and filename == '{filename}'"
milvus_results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=limit,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
for hits in milvus_results:
for hit in hits:
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"source": "fused_query",
"metadata": {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
}
results.append(result)
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
continue
# 拼接所有三元组
triplet_texts = []
for triplet in matched_triplets:
head = triplet['head']
type = triplet['type']
tail = triplet['tail']
if not head or not type or not tail:
logger.debug(f"无效三元组: {triplet}")
continue
triplet_texts.append(f"{head} {type} {tail}")
if not triplet_texts:
logger.debug(f"无有效三元组,使用原查询: {query}")
query_vector = embedding.embed_query(query)
expr = f"userid == '{userid}' and filename == '{filename}'"
milvus_results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=5,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
for hits in milvus_results:
for hit in hits:
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"source": "fused_query",
"metadata": {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
}
results.append(result)
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
continue
# 生成融合文本
fused_text = f"{query} {' '.join(triplet_texts)}"
logger.debug(f"融合文本: {fused_text}")
# 向量化
fused_vector = embedding.embed_query(fused_text)
fused_vector = np.array(fused_vector) / np.linalg.norm(fused_vector)
logger.debug(f"生成融合向量,维度: {len(fused_vector)}")
# Milvus 搜索
expr = f"userid == '{userid}' and filename == '{filename}'"
milvus_results = collection.search(
data=[fused_vector],
anns_field="vector",
param=search_params,
limit=5,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
for hits in milvus_results:
for hit in hits:
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"source": f"fused_triplets_{len(triplet_texts)}",
"metadata": {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
}
results.append(result)
logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}")
# 去重
unique_results = []
seen_texts = set()
for result in results:
text = result['text']
if text not in seen_texts:
seen_texts.add(text)
unique_results.append(result)
logger.debug(f"去重后结果数量: {len(unique_results)}")
# 可选:重排序
if use_rerank and unique_results:
logger.debug("开始重排序")
reranked_results = rerank_results(query, unique_results)
# 按 rerank_score 降序排序
reranked_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True)
for i, result in enumerate(reranked_results):
logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result['rerank_score']}")
return reranked_results[:limit]
# 按 distance 降序排序
sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True)
for i, result in enumerate(sorted_results):
logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}")
return sorted_results[:limit]
except Exception as e:
logger.error(f"融合搜索失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return []
if __name__ == "__main__":
query = "知识图谱构建需要什么技术?"
userid = "testuser1"
db_type = "textdb"
file_paths = [
"/share/wangmeihua/rag/data/test.docx",
"/share/wangmeihua/rag/data/zongshu.pdf",
"/share/wangmeihua/rag/data/qianru.pdf"
]
results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0)
for i, result in enumerate(results):
print(f"Result {i+1}:")
print(f"Text: {result['text'][:200]}...")
print(f"Distance: {result['distance']}")
print(f"Source: {result['source']}")
print(f"Metadata: {result['metadata']}\n")

15
rag/init.py Normal file
View File

@ -0,0 +1,15 @@
from appPublic.worker import awaitify
from ahserver.serverenv import ServerEnv
from .kdb import add_kdb, add_dir, add_doc, get_all_docs
from .query import search_query
from .embed import embed
def load_rag():
env = ServerEnv()
env.add_kdb = add_kdb
env.query = awaitify(search_query)
env.embed = awaitify(embed)
env.add_dir = add_dir
env.add_doc = add_doc
env.get_all_docs = get_all_docs
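# 假设用法(示意): 在 ahserver 的启动/初始化流程中调用 load_rag(),
# 之后服务端脚本即可通过 env.query、env.embed、env.add_kdb 等访问这些能力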

81
rag/kdb.py Normal file
View File

@ -0,0 +1,81 @@
from traceback import format_exc
from appPublic.log import exception  # 假设 appPublic.log 提供 exception 日志函数
from appPublic.uniqueID import getID
from appPublic.timeUtils import curDateString
from appPublic.dictObject import DictObject
from sqlor.dbpools import DBPools
from ahserver.serverenv import get_serverenv
from ahserver.filestorage import FileStorage
async def add_kdb(kdb:dict) -> None:
"""
添加知识库
"""
kdb = DictObject(**kdb)
kdb.parentid=None
if kdb.id is None:
kdb.id = getID()
kdb.entity_type = '0'
kdb.create_date = curDateString()
if kdb.orgid is None:
e = Exception(f'Can not add none orgid kdb')
exception(f'{e}\n{format_exc()}')
raise e
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('kdb', kdb.copy())
async def add_dir(kdb:dict) -> None:
"""
添加子目录
"""
kdb = DictObject(**kdb)
if kdb.parentid is None:
e = Exception(f'Can not add root folder')
exception(f'{e}\n{format_exc()}')
raise e
if kdb.id is None:
kdb.id = getID()
kdb.entity_type = '1'
kdb.create_date = curDateString()
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('kdb', kdb.copy())
async def add_doc(doc:dict) -> None:
"""
添加文档
"""
doc = DictObject(**doc)
if doc.parentid is None:
e = Exception(f'Can not add root document')
exception(f'{e}\n{format_exc()}')
raise e
if doc.id is None:
doc.id = getID()
fs = FileStorage()
doc.realpath = fs.realPath(doc.webpath)
doc.create_date = curDateString()
f = get_serverenv('get_module_dbname')
dbname = f('rag')
db = DBPools()
async with db.sqlorContext(dbname) as sor:
await sor.C('doc', doc.copy())
async def get_all_docs(sor, kdbid):
"""
获取所有kdbid下的文档含子目录的
"""
docs = await sor.R('doc', {'parentid':kdbid})
kdbs = await sor.R('kdb', {'parentid':kdbid})
for kdb in kdbs:
docs1 = await get_all_docs(sor, kdb.id)
docs += docs1
return docs

194
rag/kgc.py Normal file
View File

@ -0,0 +1,194 @@
import os
import logging
import re
from py2neo import Graph, Node, Relationship
from typing import Set, List, Dict, Tuple
from ufw.common import share_dir
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class KnowledgeGraph:
def __init__(self, data_path: str, document_id: str = None):
self.data_path = data_path
self.document_id = document_id or os.path.basename(data_path).split('_')[0]
self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh'))
logger.info(f"开始构建知识图谱data_path: {self.data_path}, document_id: {self.document_id}")
# 验证 data_path 是否有效
if not os.path.exists(self.data_path):
logger.error(f"数据路径 {self.data_path} 不存在")
raise ValueError(f"数据路径 {self.data_path} 不存在")
def _normalize_label(self, entity_type: str) -> str:
"""规范化实体类型为 Neo4j 标签"""
if not entity_type or not entity_type.strip():
return 'Entity'
entity_type = re.sub(r'[^\w\s]', '', entity_type.strip())
words = entity_type.split()
label = '_'.join(word.capitalize() for word in words if word)
return label or 'Entity'
def _clean_relation(self, relation: str) -> Tuple[str, str]:
"""清洗关系,返回 (rel_type, rel_name)"""
relation = relation.strip()
if not relation:
return 'RELATED_TO', '相关'
if relation.startswith('<') and relation.endswith('>'):
cleaned_relation = relation[1:-1]
rel_name = cleaned_relation
rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper()
else:
rel_name = relation
rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper()
if 'instance of' in relation.lower():
rel_type = 'INSTANCE_OF'
rel_name = '实例'
elif 'subclass of' in relation.lower():
rel_type = 'SUBCLASS_OF'
rel_name = '子类'
elif 'part of' in relation.lower():
rel_type = 'PART_OF'
rel_name = '部分'
logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})")
return rel_type, rel_name
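# 清洗示例: "<instance of>" -> ('INSTANCE_OF', '实例');"part of" -> ('PART_OF', '部分');
# 无关键词的中文关系如 "创始人" -> ('创始人', '创始人')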
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
"""读取三元组数据,返回节点和关系"""
nodes_by_label = {}
relations_by_type = {}
triples = []
try:
logger.debug(f"尝试读取文件: {self.data_path}")
with open(self.data_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line or line.startswith('#'):
continue
parts = line.split('\t')
if len(parts) != 5:
logger.warning(f"无效行: {line}")
continue
head, relation, tail, head_type, tail_type = parts
head_label = self._normalize_label(head_type)
tail_label = self._normalize_label(tail_type)
logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
if head_label not in nodes_by_label:
nodes_by_label[head_label] = set()
if tail_label not in nodes_by_label:
nodes_by_label[tail_label] = set()
nodes_by_label[head_label].add(head)
nodes_by_label[tail_label].add(tail)
rel_type, rel_name = self._clean_relation(relation)
if rel_type not in relations_by_type:
relations_by_type[rel_type] = []
relations_by_type[rel_type].append({
'head': head,
'tail': tail,
'head_label': head_label,
'tail_label': tail_label,
'rel_name': rel_name
})
triples.append({
'head': head,
'relation': relation,
'tail': tail,
'head_type': head_type,
'tail_type': tail_type
})
logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())}")
logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())}")
return nodes_by_label, relations_by_type, triples
except Exception as e:
logger.error(f"读取数据失败: {str(e)}data_path: {self.data_path}")
raise RuntimeError(f"读取数据失败: {str(e)}")
def create_node(self, label: str, nodes: Set[str]):
"""创建节点,包含 document_id 属性"""
count = 0
for node_name in nodes:
query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n"
try:
if self.g.run(query).data():
continue
node = Node(label, name=node_name, document_id=self.document_id)
self.g.create(node)
count += 1
logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})")
except Exception as e:
logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
logger.info(f"创建 {label} 节点: {count}/{len(nodes)}")
return count
def create_relationship(self, rel_type: str, relations: List[Dict]):
"""创建关系"""
count = 0
total = len(relations)
seen_edges = set()
for rel in relations:
head, tail, head_label, tail_label, rel_name = (
rel['head'], rel['tail'], rel['head_label'], rel['tail_label'], rel['rel_name']
)
edge_key = f"{head_label}:{head}###{tail_label}:{tail}###{rel_type}"
if edge_key in seen_edges:
continue
seen_edges.add(edge_key)
query = (
f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), "
f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) "
f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)"
)
try:
self.g.run(query)
count += 1
logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})")
except Exception as e:
logger.error(f"创建关系失败: {query}, 错误: {str(e)}")
logger.info(f"创建 {rel_type} 关系: {count}/{total}")
return count
def create_graphnodes(self):
"""创建所有节点"""
nodes_by_label, _, _ = self.read_nodes()
total = 0
for label, nodes in nodes_by_label.items():
total += self.create_node(label, nodes)
logger.info(f"总计创建节点: {total}")
return total
def create_graphrels(self):
"""创建所有关系"""
_, relations_by_type, _ = self.read_nodes()
total = 0
for rel_type, relations in relations_by_type.items():
total += self.create_relationship(rel_type, relations)
logger.info(f"总计创建关系: {total}")
return total
def export_data(self):
"""导出节点到文件,包含 document_id"""
nodes_by_label, _, _ = self.read_nodes()
os.makedirs('dict', exist_ok=True)
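# 导出格式: 每行 "实体名<TAB>document_id",文件名为小写标签名,与仓库 rag/dict/*.txt 中的条目格式一致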
for label, nodes in nodes_by_label.items():
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes)))
logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)}")
return
if __name__ == '__main__':
data_path = '/share/wangmeihua/rag/triples/26911c68-9107-4bb4-8f31-ff776991a119_testuser2.txt'
handler = KnowledgeGraph(data_path)
logger.info("Step 1: 导入图谱节点中")
handler.create_graphnodes()
logger.info("Step 2: 导入图谱边中")
handler.create_graphrels()
logger.info("Step 3: 导出数据")
handler.export_data()

201
rag/query.py Normal file
View File

@ -0,0 +1,201 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus import connections, Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from vector import initialize_milvus_connection, cleanup_milvus_connection
import torch
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear() # 清除现有处理器,避免重复
logger.propagate = False # 禁用传播到父级
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
def search_query(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]:
"""
根据用户输入的查询文本在指定 db_type 的知识库中搜索与 userid 相关的指定文档
参数:
query (str): 用户输入的查询文本
userid (str): 用户ID用于过滤
db_type (str): 数据库类型例如 'textdb'
file_paths (List[str]): 文档路径列表支持1到多个文件
limit (int): 返回的最大结果数默认为 10
offset (int): 偏移量用于分页默认为 0
返回:
List[Dict]: 搜索结果每个元素为包含 textdistance metadata 的字典
异常:
ValueError: 参数无效
RuntimeError: 模型加载或 Milvus 操作失败
"""
try:
# 参数验证
if not query:
raise ValueError("查询文本不能为空")
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 或 db_type 的长度超出限制")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if offset < 0:
raise ValueError("offset 不能为负数")
if limit + offset > 16384:
raise ValueError("limit + offset 不能超过 16384")
if not file_paths:
raise ValueError("file_paths 不能为空")
for file_path in file_paths:
if not isinstance(file_path, str):
raise ValueError(f"file_path 必须是字符串: {file_path}")
if len(os.path.basename(file_path)) > 255:
raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
if "_" in os.path.basename(file_path):
raise ValueError(f"文件名 {file_path} 不能包含下划线")
# 初始化嵌入模型
model_path = TEXT_EMBEDDING_MODEL
if not os.path.exists(model_path):
raise ValueError(f"模型路径 {model_path} 不存在")
embedding = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug("嵌入模型加载成功")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
# 将查询转换为向量
query_vector = embedding.embed_query(query)
logger.debug(f"查询向量维度: {len(query_vector)}")
# 连接到 Milvus
initialize_milvus_connection()
# 检查集合是否存在
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
# 加载集合
try:
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 构造搜索参数
search_params = {
"metric_type": "COSINE", # 与 vector.py 一致
"params": {"nprobe": 10} # 优化搜索性能
}
# 构造过滤表达式,限制在指定文件
filenames = [os.path.basename(file_path) for file_path in file_paths]
filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
expr = f"userid == '{userid}' and ({filename_expr})"
logger.debug(f"搜索参数: {search_params}, 表达式: {expr}, limit: {limit}, offset: {offset}")
# 执行搜索
try:
results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=limit,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
except Exception as e:
logger.error(f"搜索失败: {str(e)}")
raise RuntimeError(f"搜索失败: {str(e)}")
# 处理搜索结果
search_results = []
for hits in results:
for hit in hits:
metadata = {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"metadata": metadata
}
search_results.append(result)
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
logger.debug(f"搜索完成,返回 {len(search_results)} 条结果")
return search_results
except Exception as e:
logger.error(f"搜索失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
raise
finally:
cleanup_milvus_connection()
if __name__ == "__main__":
# 测试代码
query = "知识图谱的知识融合是什么?"
userid = "testuser2"
db_type = "textdb"
file_paths = [
"/share/wangmeihua/rag/data/test.docx",
"/share/wangmeihua/rag/data/test.txt"
]
limit = 5
offset = 0
try:
results = search_query(query, userid, db_type, file_paths, limit, offset)
print(f"搜索结果 ({len(results)} 条):")
for idx, result in enumerate(results, 1):
print(f"结果 {idx}:")
print(f"内容: {result['text'][:200]}...")
print(f"距离: {result['distance']}")
print(f"元数据: {result['metadata']}")
print("-" * 50)
except Exception as e:
print(f"搜索失败: {str(e)}")

53
rag/rag.bak.py Normal file
View File

@ -0,0 +1,53 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from query import search_query
from embed import embed
from deletefile import delete_document
import logging
import os
# 配置日志(先确保日志目录存在,否则 FileHandler 会初始化失败)
os.makedirs(os.path.expanduser('~/rag/logs'), exist_ok=True)
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def get_module_dbname(name):
"""
获取默认数据库名称优先使用环境变量 RAG_DB_TYPE未设置时返回 'textdb'
"""
return os.getenv('RAG_DB_TYPE', 'textdb')
def init():
"""
初始化 RAG 系统绑定核心函数到 ServerEnv
"""
try:
logger.info("初始化 RAG 系统")
g = ServerEnv()
# 绑定核心函数为异步版本
g.embed = awaitify(embed)
g.query = awaitify(search_query)
g.delete = awaitify(delete_document)
g.get_module_dbname = get_module_dbname
logger.info("RAG 系统初始化完成")
return g
except Exception as e:
logger.error(f"初始化失败: {str(e)}")
raise
if __name__ == '__main__':
logger.info("启动 RAG Web 服务器")
webapp(init)

80
rag/rerank.py Normal file
View File

@ -0,0 +1,80 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus.model.reranker import BGERerankFunction
import torch
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear() # 清除现有处理器
logger.propagate = False # 禁用传播
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
def rerank_results(query: str, results: List[Dict], top_k: int = 10) -> List[Dict]:
"""
使用 BGE Reranker 模型对查询和文本块进行重排序
参数:
query (str): 查询文本
results (List[Dict]): 包含 textdistancesource metadata 的结果列表
top_k (int): 返回的最大结果数默认为 10
返回:
List[Dict]: 重排序后的结果列表包含 textdistancesourcemetadata rerank_score
"""
try:
# 初始化 BGE Reranker
bge_rf = BGERerankFunction(
model_name="/share/models/BAAI/bge-reranker-v2-m3",
device="cuda:0" if torch.cuda.is_available() else "cpu"
)
logger.debug(f"BGE Reranker 初始化成功,模型路径: /share/models/BAAI/bge-reranker-v2-m3, 设备: {'cuda:0' if torch.cuda.is_available() else 'cpu'}")
# 提取文本块
documents = [result['text'] for result in results]
if not documents:
logger.warning("无文本块可重排序")
return results
# 重排序
rerank_output = bge_rf(
query=query,
documents=documents,
top_k=min(top_k, len(documents))
)
# 构建重排序结果
reranked = []
for result in rerank_output:
original_result = results[result.index].copy()
original_result['rerank_score'] = result.score
reranked.append(original_result)
logger.debug(f"重排序结果: text={result.text[:200]}..., rerank_score={result.score:.6f}, source={original_result['source']}")
logger.info(f"重排序返回 {len(reranked)} 条结果")
return reranked
except Exception as e:
logger.error(f"重排序失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
# 回退到原始结果,并补充 rerank_score=0.0,保证下游按 rerank_score 排序/去重时字段存在
for result in results:
result.setdefault('rerank_score', 0.0)
return results

363
rag/searchquery.py Normal file
View File

@ -0,0 +1,363 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus import connections, Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
import numpy as np
from scipy.spatial.distance import cosine
from ltp import LTP
from vector import initialize_milvus_connection, cleanup_milvus_connection
import torch
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"加载配置文件失败: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear() # 清理现有处理器,避免重复
logger.propagate = False # 禁用传播到父级
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
# 三元组保存路径
TRIPLES_OUTPUT_DIR = '/share/wangmeihua/rag/triples'
# 初始化嵌入模型
embedding = HuggingFaceEmbeddings(
model_name=TEXT_EMBEDDING_MODEL,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug("嵌入模型加载成功")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
# 初始化 LTP 模型
try:
model_path = "/share/models/LTP/small"
if not os.path.isdir(model_path):
logger.warning(f"本地模型路径 {model_path} 不存在,尝试使用 Hugging Face 模型 'hit-scir/ltp-small'")
model_path = "hit-scir/ltp-small"
ltp = LTP(pretrained_model_name_or_path=model_path)
if torch.cuda.is_available():
ltp.to("cuda")
logger.debug("LTP 模型加载成功")
except Exception as e:
logger.error(f"加载 LTP 模型失败: {str(e)}")
raise RuntimeError(f"加载 LTP 模型失败: {str(e)}")
def extract_entities(query: str) -> List[str]:
"""
从查询文本中抽取实体包括
- LTP NER 识别的实体所有类型
- LTP POS 标注为名词'n'的词
- LTP POS 标注为动词'v'的词
- 连续名词合并 '苹果 公司' -> '苹果公司'移除子词
"""
try:
if not query:
raise ValueError("查询文本不能为空")
# 使用 LTP pipeline 获取分词、词性、NER 结果
result = ltp.pipeline([query], tasks=["cws", "pos", "ner"])
words = result.cws[0]
pos_list = result.pos[0]
ner = result.ner[0]
entities = []
subword_set = set() # 记录连续名词的子词
# 提取 1NER 实体(所有类型)
logger.debug(f"NER 结果: {ner}")
for entity_type, entity, start, end in ner:
entities.append(entity)
# 提取 2合并连续名词
combined = ""
combined_words = [] # 记录当前连续名词的单词
for i in range(len(words)):
if pos_list[i] == 'n':
combined += words[i]
combined_words.append(words[i])
if i + 1 < len(words) and pos_list[i + 1] == 'n':
continue
if combined:
entities.append(combined)
subword_set.update(combined_words)
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
combined = ""
combined_words = []
else:
combined = ""
combined_words = []
logger.debug(f"连续名词子词集合: {subword_set}")
# 提取 3POS 名词('n'),排除子词
for word, pos in zip(words, pos_list):
if pos == 'n' and word not in subword_set:
entities.append(word)
# 提取 4POS 动词('v'
for word, pos in zip(words, pos_list):
if pos == 'v':
entities.append(word)
# 去重
unique_entities = list(dict.fromkeys(entities))
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
return unique_entities
except Exception as e:
logger.error(f"实体抽取失败: {str(e)}")
return []
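# 抽取示例(示意,实际结果取决于 LTP 的分词/词性/NER 输出):
# query="知识图谱构建需要什么技术?" 可能得到 ["知识图谱", "构建", "需要", "技术"] 之类的实体列表,
# 其中连续名词被合并为整体,名词与动词均作为候选实体,被合并的子词不再单独保留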
def load_triplets_from_file(triplet_file: str) -> List[Dict]:
"""从三元组文件中加载"""
triplets = []
try:
if not os.path.exists(triplet_file):
logger.warning(f"三元组文件 {triplet_file} 不存在")
return []
with open(triplet_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
parts = line.strip().split('\t')
if len(parts) >= 5:
head, relation, tail, head_type, tail_type = parts[:5]
triplets.append({
'head': head,
'head_type': head_type,
'type': relation,
'tail': tail,
'tail_type': tail_type
})
logger.debug(f"{triplet_file} 加载 {len(triplets)} 个三元组")
return triplets
except Exception as e:
logger.error(f"加载三元组文件 {triplet_file} 失败: {str(e)}")
return []
def match_triplets(query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
"""
匹配查询实体与文档三元组使用语义嵌入
- 初始匹配实体与 head tail 相似度 0.8
- 返回匹配的三元组
"""
matched_triplets = []
ENTITY_SIMILARITY_THRESHOLD = 0.8 # 实体与 head/tail 相似度阈值
try:
# 加载三元组
triplet_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt")
doc_triplets = load_triplets_from_file(triplet_file)
if not doc_triplets:
logger.debug(f"文档 document_id={document_id} 无三元组")
return []
# 缓存查询实体嵌入
entity_vectors = {entity: embedding.embed_query(entity) for entity in query_entities}
# 初始匹配
for entity in query_entities:
entity_vec = entity_vectors[entity]
for d_triplet in doc_triplets:
d_head_vec = embedding.embed_query(d_triplet['head'])
d_tail_vec = embedding.embed_query(d_triplet['tail'])
head_similarity = 1 - cosine(entity_vec, d_head_vec)
tail_similarity = 1 - cosine(entity_vec, d_tail_vec)
if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD:
matched_triplets.append(d_triplet)
logger.debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
# 去重
unique_matched = []
seen = set()
for t in matched_triplets:
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
if identifier not in seen:
seen.add(identifier)
unique_matched.append(t)
logger.info(f"找到 {len(unique_matched)} 个匹配的三元组")
return unique_matched
except Exception as e:
logger.error(f"匹配三元组失败: {str(e)}")
return []
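# 说明: 嵌入已做 normalize_embeddings=True 归一化,1 - cosine(a, b) 即余弦相似度,
# ENTITY_SIMILARITY_THRESHOLD=0.8 为经验阈值,可按召回/精度需求调整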
def searchquery(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]:
"""
根据查询抽取实体匹配指定文档的三元组并在 Milvus 中搜索相关文档片段
"""
try:
if not query or not userid or not db_type or not file_paths:
raise ValueError("query、userid、db_type 和 file_paths 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 或 db_type 的长度超出限制")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if offset < 0:
raise ValueError("offset 不能为负数")
if limit + offset > 16384:
raise ValueError("limit + offset 不能超过 16384")
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return []
collection = Collection(collection_name)
collection.load()
documents = []
for file_path in file_paths:
filename = os.path.basename(file_path)
results = collection.query(
expr=f"userid == '{userid}' and filename == '{filename}'",
output_fields=["document_id", "filename"],
limit=1
)
if not results:
logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
continue
documents.append(results[0])
if not documents:
logger.warning("没有找到任何有效文档")
return []
logger.info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")
query_entities = extract_entities(query)
if not query_entities:
logger.warning("未从查询中提取到实体")
return []
search_results = []
for doc in documents:
document_id = doc["document_id"]
filename = doc["filename"]
logger.debug(f"处理文档: document_id={document_id}, filename={filename}")
matched_triplets = match_triplets(query, query_entities, userid, document_id)
if not matched_triplets:
logger.debug(f"文档 document_id={document_id} 未找到匹配的三元组")
continue
for triplet in matched_triplets:
head = triplet['head']
type = triplet['type']
tail = triplet['tail']
if not head or not type or not tail:
logger.debug(f"无效三元组: head={head}, type={type}, tail={tail}")
continue
triplet_text = f"{head} {type} {tail}"
logger.debug(f"搜索三元组: {triplet_text} (文档: {filename})")
try:
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
query_vector = embedding.embed_query(triplet_text)
expr = f"userid == '{userid}' and filename == '{filename}' and text like '%{head}%{tail}%'"
logger.debug(f"搜索表达式: {expr}")
results = collection.search(
data=[query_vector],
anns_field="vector",
param=search_params,
limit=limit,
expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
offset=offset
)
for hits in results:
for hit in hits:
metadata = {
"userid": hit.entity.get("userid"),
"document_id": hit.entity.get("document_id"),
"filename": hit.entity.get("filename"),
"file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type")
}
result = {
"text": hit.entity.get("text"),
"distance": hit.distance,
"metadata": metadata
}
search_results.append(result)
logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}")
except Exception as e:
logger.warning(f"三元组 {triplet_text} 在文档 {filename} 搜索失败: {str(e)}")
continue
unique_results = []
seen_texts = set()
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
if result['text'] not in seen_texts:
unique_results.append(result)
seen_texts.add(result['text'])
if len(unique_results) >= limit:
break
logger.info(f"返回 {len(unique_results)} 条唯一结果")
return unique_results
except Exception as e:
logger.error(f"搜索失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return []
finally:
cleanup_milvus_connection()
if __name__ == "__main__":
query = "什么是知识图谱的知识抽取?"
userid = "testuser1"
db_type = "textdb"
file_paths = [
"/share/wangmeihua/rag/data/test.docx",
"/share/wangmeihua/rag/data/zongshu.pdf",
"/share/wangmeihua/rag/data/qianru.pdf"
]
limit = 5
offset = 0
try:
results = searchquery(query, userid, db_type, file_paths, limit, offset)
print(f"搜索结果 ({len(results)} 条):")
for idx, result in enumerate(results, 1):
print(f"结果 {idx}:")
print(f"内容: {result['text'][:200]}...")
print(f"距离: {result['distance']}")
print(f"元数据: {result['metadata']}")
print("-" * 50)
except Exception as e:
print(f"搜索失败: {str(e)}")

9
rag/test.py Normal file
View File

@ -0,0 +1,9 @@
from py2neo import Graph,Node,Relationship,NodeMatcher
username = 'neo4j'
password = '261229..wmh'
auth = (username, password)
graph=Graph("bolt://10.18.34.18:7687", auth = auth)
book_node=Node('经名',name='十三经')
graph.create(book_node)
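# NodeMatcher is imported above but unused; a minimal sketch of reading the node back
# (assumes the same Neo4j instance and credentials as the connection above):
matcher = NodeMatcher(graph)
found = matcher.match('经名', name='十三经').first()
print(found)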

539
rag/vector.py Normal file
View File

@ -0,0 +1,539 @@
import os
import uuid
import json
import yaml
from datetime import datetime
from typing import List, Dict, Optional
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
import torch
import logging
import time
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def ensure_milvus_directory() -> None:
"""确保 Milvus 数据库目录存在"""
db_dir = os.path.dirname(MILVUS_DB_PATH)
if not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
logger.debug(f"创建 Milvus 目录: {db_dir}")
if not os.access(db_dir, os.W_OK):
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
def initialize_milvus_connection() -> None:
"""初始化 Milvus 连接,确保单一连接"""
try:
if not connections.has_connection("default"):
connections.connect("default", uri=MILVUS_DB_PATH)
logger.debug(f"已连接到 Milvus Lite路径: {MILVUS_DB_PATH}")
else:
logger.debug("已存在 Milvus 连接,跳过重复连接")
except Exception as e:
logger.error(f"连接 Milvus 失败: {str(e)}")
raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
def cleanup_milvus_connection() -> None:
"""清理 Milvus 连接,确保资源释放"""
try:
if connections.has_connection("default"):
connections.disconnect("default")
logger.debug("已断开 Milvus 连接")
time.sleep(3)
except Exception as e:
logger.warning(f"断开 Milvus 连接失败: {str(e)}")
def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
"""
初始化或访问 Milvus Lite 向量数据库集合 db_type 组织利用 userid 区分用户document_id 区分文档并插入文档
"""
try:
# 参数验证
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid or "_" in db_type:
raise ValueError("userid 和 db_type 不能包含下划线")
if len(userid) > 100 or len(db_type) > 100:
raise ValueError("userid 和 db_type 的长度应小于 100")
if not documents or not all(isinstance(doc, Document) for doc in documents):
raise ValueError("documents 不能为空且必须是 Document 对象列表")
required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
for doc in documents:
if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}")
if doc.metadata["userid"] != userid:
raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致")
ensure_milvus_directory()
initialize_milvus_connection()
# 初始化嵌入模型
model_path = TEXT_EMBEDDING_MODEL
if not os.path.exists(model_path):
raise ValueError(f"模型路径 {model_path} 不存在")
embedding = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
try:
test_vector = embedding.embed_query("test")
if len(test_vector) != 1024:
raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024")
logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}")
except Exception as e:
logger.error(f"嵌入模型加载失败: {str(e)}")
raise RuntimeError(f"加载模型失败: {str(e)}")
# 集合名称
collection_name = f"ragdb_{db_type}"
if len(collection_name) > 255:
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
logger.debug(f"集合名称: {collection_name}")
# 定义 schema,包含所有固定字段
fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
]
schema = CollectionSchema(
fields=fields,
description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段",
auto_id=True,
primary_field="pk",
)
# 检查集合是否存在
if utility.has_collection(collection_name):
try:
collection = Collection(collection_name)
existing_schema = collection.schema
expected_fields = {f.name for f in fields}
actual_fields = {f.name for f in existing_schema.fields}
vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None)
schema_compatible = False
if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR:
dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None
schema_compatible = dim == 1024
logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, "
f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else ''}, "
f"dim={dim if dim is not None else '未定义'}")
if not schema_compatible:
logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: "
f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or ''}, "
f"vector_field: {vector_field is not None}, "
f"dtype: {vector_field.dtype if vector_field else ''}, "
f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}")
utility.drop_collection(collection_name)
else:
collection.load()
logger.debug(f"集合 {collection_name} 已存在并加载成功")
except Exception as e:
logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"加载集合失败: {str(e)}")
# 创建新集合
if not utility.has_collection(collection_name):
try:
collection = Collection(collection_name, schema)
collection.create_index(
field_name="vector",
index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
)
collection.create_index(
field_name="userid",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="document_id",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="filename",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="file_path",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="upload_time",
index_params={"index_type": "INVERTED"}
)
collection.create_index(
field_name="file_type",
index_params={"index_type": "INVERTED"}
)
collection.load()
logger.debug(f"成功创建并加载集合: {collection_name}")
except Exception as e:
logger.error(f"创建集合 {collection_name} 失败: {str(e)}")
raise RuntimeError(f"创建集合失败: {str(e)}")
# 初始化 Milvus 向量存储
try:
vector_store = Milvus(
embedding_function=embedding,
collection_name=collection_name,
connection_args={"uri": MILVUS_DB_PATH},
drop_old=False,
auto_id=True,
primary_field="pk",
)
logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}")
except Exception as e:
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
raise RuntimeError(f"初始化向量存储失败: {str(e)}")
# 插入文档
try:
logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}")
for doc in documents:
logger.debug(f"插入文档元数据: {doc.metadata}")
vector_store.add_documents(documents=documents)
logger.debug(f"成功插入 {len(documents)} 个文档")
# 立即查询验证
collection = Collection(collection_name)
collection.load()
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=10
)
for result in results:
logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, "
f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
except Exception as e:
logger.error(f"插入文档失败: {str(e)}")
raise RuntimeError(f"插入文档失败: {str(e)}")
return vector_store
except Exception as e:
logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
raise
finally:
cleanup_milvus_connection()
def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
"""
获取指定 userid db_type 下的 document_id 与元数据的映射
"""
try:
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return {}
collection = Collection(collection_name)
collection.load()
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=100
)
mapping = {}
for result in results:
doc_id = result.get("document_id")
if doc_id:
mapping[doc_id] = {
"userid": result.get("userid", ""),
"filename": result.get("filename", ""),
"file_path": result.get("file_path", ""),
"upload_time": result.get("upload_time", ""),
"file_type": result.get("file_type", "")
}
logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")
logger.debug(f"找到 {len(mapping)} 个文档的映射")
return mapping
except Exception as e:
logger.error(f"获取文档映射失败: {str(e)}")
raise RuntimeError(f"获取文档映射失败: {str(e)}")
def list_user_collections() -> Dict[str, Dict]:
"""
列出所有数据库类型db_type及其包含的用户userid与对应的文档document_id映射
"""
try:
ensure_milvus_directory()
initialize_milvus_connection()
collections = utility.list_collections()
db_types_with_data = {}
for col in collections:
if col.startswith("ragdb_"):
db_type = col[len("ragdb_"):]
logger.debug(f"处理集合: {col} (db_type: {db_type})")
collection = Collection(col)
collection.load()
batch_size = 1000
offset = 0
user_document_map = {}
while True:
try:
results = collection.query(
expr="",
output_fields=["userid", "document_id"],
limit=batch_size,
offset=offset
)
if not results:
break
for result in results:
userid = result.get("userid")
doc_id = result.get("document_id")
if userid and doc_id:
if userid not in user_document_map:
user_document_map[userid] = set()
user_document_map[userid].add(doc_id)
offset += batch_size
except Exception as e:
logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}")
break
# 转换为列表以便序列化
user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}")
db_types_with_data[db_type] = {
"userids": user_document_map
}
logger.debug(f"可用 db_types 和数据: {db_types_with_data}")
return db_types_with_data
except Exception as e:
logger.error(f"列出集合和用户失败: {str(e)}")
raise
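# Illustrative shape of the return value (one entry per ragdb_* collection):
#   {"textdb": {"userids": {"testuser1": ["<document_id>", ...]}}}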
def view_collection_details(userid: str) -> None:
"""
查看特定 userid 在所有集合中的内容和容量包含 document_id 和元数据
"""
try:
if not userid or "_" in userid:
raise ValueError("userid 不能为空且不能包含下划线")
logger.debug(f"正在查看 userid {userid} 的集合")
ensure_milvus_directory()
initialize_milvus_connection()
collections = utility.list_collections()
db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]
if not db_types:
logger.debug(f"未找到任何集合")
return
for db_type in db_types:
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
continue
collection = Collection(collection_name)
collection.load()
try:
all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=10000)
num_entities = len(all_pks)
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["userid","text", "document_id", "filename", "file_path", "upload_time", "file_type"],
limit=10
)
logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}")
if num_entities == 0:
logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档")
continue
logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:")
for idx, doc in enumerate(results, 1):
metadata = {
"userid": doc.get("userid", ""),
"filename": doc.get("filename", ""),
"file_path": doc.get("file_path", ""),
"upload_time": doc.get("upload_time", ""),
"file_type": doc.get("file_type", "")
}
logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}")
except Exception as e:
logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}")
continue
except Exception as e:
logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}")
raise
def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
"""
查看指定 db_type 中的向量数据可选按 userid document_id 过滤包含完整元数据和向量
"""
try:
if not db_type or "_" in db_type:
raise ValueError("db_type 不能为空且不能包含下划线")
if limit <= 0 or limit > 16384:
raise ValueError("limit 必须在 1 到 16384 之间")
if userid and "_" in userid:
raise ValueError("userid 不能包含下划线")
if document_id and "_" in document_id:
raise ValueError("document_id 不能包含下划线")
initialize_milvus_connection()
collection_name = f"ragdb_{db_type}"
if not utility.has_collection(collection_name):
logger.warning(f"集合 {collection_name} 不存在")
return {}
collection = Collection(collection_name)
collection.load()
logger.debug(f"加载集合: {collection_name}")
expr = []
if userid:
expr.append(f"userid == '{userid}'")
if document_id:
expr.append(f"document_id == '{document_id}'")
expr = " && ".join(expr) if expr else ""
results = collection.query(
expr=expr,
output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
limit=limit
)
vector_data = {}
for doc in results:
pk = doc.get("pk", str(uuid.uuid4()))
text = doc.get("text", "")
doc_id = doc.get("document_id", "")
vector = doc.get("vector", [])
metadata = {
"filename": doc.get("filename", ""),
"file_path": doc.get("file_path", ""),
"upload_time": doc.get("upload_time", ""),
"file_type": doc.get("file_type", "")
}
vector_data[pk] = {
"text": text,
"vector": vector,
"document_id": doc_id,
"metadata": metadata
}
logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")
logger.debug(f"共找到 {len(vector_data)} 条向量数据")
return vector_data
except Exception as e:
logger.error(f"查看向量数据失败: {str(e)}")
raise
def main():
userid = "testuser1"
db_type = "textdb"
# logger.info("\n测试 1带文档初始化")
# documents = [
# Document(
# page_content="深度学习是基于深层神经网络的机器学习子集。",
# metadata={
# "userid": userid,
# "document_id": str(uuid.uuid4()),
# "filename": "test_doc1.txt",
# "file_path": "/path/to/test_doc1.txt",
# "upload_time": datetime.now().isoformat(),
# "file_type": "txt"
# }
# ),
# Document(
# page_content="知识图谱是一个结构化的语义知识库。",
# metadata={
# "userid": userid,
# "document_id": str(uuid.uuid4()),
# "filename": "test_doc2.txt",
# "file_path": "/path/to/test_doc2.txt",
# "upload_time": datetime.now().isoformat(),
# "file_type": "txt"
# }
# ),
# ]
#
# try:
# vector_store = get_vector_db(userid, db_type, documents=documents)
# logger.info(f"集合: ragdb_{db_type}")
# logger.info(f"成功为 userid {userid} 在 {db_type} 中插入文档")
# except Exception as e:
# logger.error(f"失败: {str(e)}")
logger.info("\n测试 2列出所有 db_types 和文档映射")
try:
db_types = list_user_collections()
logger.info(f"可用 db_types 和文档: {db_types}")
except Exception as e:
logger.error(f"失败: {str(e)}")
logger.info(f"\n测试 3查看 userid {userid} 的所有集合")
try:
view_collection_details(userid)
except Exception as e:
logger.error(f"失败: {str(e)}")
# logger.info(f"\n测试 4查看向量数据")
# try:
# vector_data = view_vector_data(db_type, userid=userid)
# logger.info(f"向量数据: {vector_data}")
# except Exception as e:
# logger.error(f"失败: {str(e)}")
logger.info(f"\n测试 5获取 userid {userid}{db_type}数据库的文档映射")
try:
mapping = get_document_mapping(userid, db_type)
logger.info(f"文档映射: {mapping}")
except Exception as e:
logger.error(f"失败: {str(e)}")
if __name__ == "__main__":
main()

1
rag/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = '0.0.1'

4
requirements.txt Normal file
View File

@ -0,0 +1,4 @@
filetxt
apppublic
sqlor
ahserver

52
script/install.sh Executable file
View File

@ -0,0 +1,52 @@
#!/bin/bash
# 检查操作系统
OS=$(uname -s)
if [[ "$OS" != "Darwin" && "$OS" != "Linux" ]]; then
echo "错误:此脚本仅支持 macOS 和 Linux"
exit 1
fi
# 检查依赖文件
SERVICE_FILE="rag.service"
NGINX_FILE="rag.nginx"
if [[ ! -f "$SERVICE_FILE" || ! -f "$NGINX_FILE" ]]; then
echo "错误:缺少 $SERVICE_FILE$NGINX_FILE 文件"
exit 1
fi
# 1. 配置服务
if [[ "$OS" == "Darwin" ]]; then
# macOS: 使用 launchd
mkdir -p ~/Library/LaunchAgents
cp rag.service ~/Library/LaunchAgents/
launchctl load ~/Library/LaunchAgents/rag.service
launchctl start rag.service
elif [[ "$OS" == "Linux" ]]; then
# Linux: 使用 Systemd
sudo cp rag.service /etc/systemd/system/
sudo systemctl daemon-reload
sudo systemctl enable rag.service
sudo systemctl start rag.service
fi
# 2. 配置 Nginx
if ! command -v nginx &> /dev/null; then
echo "安装 Nginx..."
if [[ "$OS" == "Darwin" ]]; then
brew install nginx
elif [[ "$OS" == "Linux" ]]; then
sudo apt-get update && sudo apt-get install -y nginx
fi
fi
# 动态检测 Nginx 配置路径
NGINX_CONF_DIR="/etc/nginx/sites-enabled"
if [[ "$OS" == "Darwin" ]]; then
NGINX_CONF_DIR="/usr/local/etc/nginx/sites-enabled"
fi
mkdir -p "$NGINX_CONF_DIR"
cp rag.nginx "$NGINX_CONF_DIR/"
nginx -t && nginx -s reload || echo "错误Nginx 配置重载失败"
echo "安装完成!"

20
script/killname Executable file
View File

@ -0,0 +1,20 @@
#!/bin/sh
if [ -z "$1" ]; then
echo "错误:请提供进程名称"
exit 1
fi
# 查找进程并终止
PIDS=$(ps -ef | grep "$1" | grep -v grep | awk '{print $2}')
if [ -z "$PIDS" ]; then
echo "未找到匹配的进程:$1"
exit 0
fi
for PID in $PIDS; do
echo "终止进程 $PID"
kill -9 "$PID"
done
exit 0

31
script/rag.nginx Normal file
View File

@ -0,0 +1,31 @@
server {
listen 80;
server_name rag.opencomputing.cn;
autoindex on;
client_max_body_size 20m;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-server $host;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Scheme $scheme;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-Forwarded-Url "$scheme://$host:$server_port$request_uri";
index index.html index.htm;
location ~^/ip$ {
return 200 "$remote_addr";
}
location / {
add_header Access-Control-Allow-Origin *;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-server $host;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Scheme $scheme;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header X-real-ip $remote_addr;
proxy_send_timeout 600s;
proxy_read_timeout 600s;
proxy_pass http://localhost:10098/;
}
}

19
script/rag.service Normal file
View File

@ -0,0 +1,19 @@
[Unit]
Description=RAG Service
Documentation=RAG service to control RAG application
After=network.target nginx.service
Requires=nginx.service
[Service]
User=wangmeihua
Group=wangmeihua
# Type=forking
WorkingDirectory=/share/wangmeihua/rag
ExecStart=/bin/bash /share/wangmeihua/rag/script/rag.sh
ExecStop=/bin/bash /share/wangmeihua/rag/script/killname app/ragapp.py
Restart=on-failure
StandardOutput=append:/var/log/rag/rag.log
StandardError=append:/var/log/rag/error.log
[Install]
WantedBy=multi-user.target

18
script/rag.sh Executable file
View File

@ -0,0 +1,18 @@
#!/bin/bash
User=wangmeihua
Group=wangmeihua
PYTHON=python3
RAG_PY="/d/wangmeihua/rag/app/ragapp.py"
LOG_DIR="/d/wangmeihua/rag/logs"
# 验证文件存在
if [[ ! -f "$RAG_PY" ]]; then
echo "错误:$RAG_PY 不存在"
exit 1
fi
# 终止旧进程
"/d/wangmeihua/rag/script/killname" $RAG_PY
# 启动新进程
"$PYTHON" "$RAG_PY" -w "/d/wangmeihua/rag"

46
script/set_env.sh Normal file
View File

@ -0,0 +1,46 @@
#!/bin/bash
HOME_DIR="/share/wangmeihua"
RAG_DIR="/share/wangmeihua/rag"
PYTHON_VERSION="python3"
# 检查 Python 版本
if ! command -v "$PYTHON_VERSION" &> /dev/null; then
echo "错误:未找到 Python3"
exit 1
fi
# 检查 requirements.txt
if [[ ! -f "${RAG_DIR}/requirements.txt" ]]; then
echo "错误:${RAG_DIR}/requirements.txt 不存在"
exit 1
fi
# 创建虚拟环境
mkdir -p "${HOME_DIR}/bin"
"$PYTHON_VERSION" -m venv "${HOME_DIR}/py3"
source "${HOME_DIR}/py3/bin/activate"
# 备份 .bashrc
if [[ -f "${HOME_DIR}/.bashrc" ]]; then
cp "${HOME_DIR}/.bashrc" "${HOME_DIR}/.bashrc.bak"
fi
# 配置环境变量
cat >> "${HOME_DIR}/.bashrc" << EOF
export PATH="${HOME_DIR}/bin:${HOME_DIR}/py3/bin:\$PATH"
source "${HOME_DIR}/py3/bin/activate"
EOF
# 安装依赖
pip install -r "${RAG_DIR}/requirements.txt"
if [[ $? -ne 0 ]]; then
echo "错误:依赖安装失败"
exit 1
fi
# 复制并授权 killname
cp killname "${HOME_DIR}/bin"
chmod +x "${HOME_DIR}/bin/killname"
echo "环境配置完成!"

16
setup.cfg Normal file
View File

@ -0,0 +1,16 @@
[metadata]
name = rag
version = 0.0.1
description = Your project description
author = yu moqing
author_email = yumoqing@gmail.com
long_description = file: README.md
long_description_content_type = text/markdown
license = MIT

[options]
packages = find:
python_requires = >=3.8
install_requires =
    filetxt
    apppublic
    sqlor
    ahserver

52
setup.py Executable file
View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
from rag.version import __version__
try:
from setuptools import setup
except ImportError:
from distutils.core import setup
required = []
with open('requirements.txt', 'r') as f:
ls = f.read()
required = ls.split('\n')
version = __version__
with open("README.md", "r") as fh:
long_description = fh.read()
name = "rag"
description = "rag"
author = "yumoqing"
email = "yumoqing@gmail.com"
package_data = {}
setup(
name="rag",
version=version,
# uncomment the following lines if you fill them out in release.py
description=description,
author=author,
author_email=email,
platforms='any',
install_requires=required ,
packages=[
"rag"
],
package_data=package_data,
keywords = [
],
url="https://github.com/yumoqing/rag",
long_description=long_description,
long_description_content_type="text/markdown",
classifiers = [
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'License :: OSI Approved :: MIT License',
],
)

View File

@ -0,0 +1,41 @@
谷歌 industry 搜索引擎 org concept
知识图谱 Web 3.0 万维网 concept media
Web is a list of 网的 unk time
自顶向下 百科类网站 结构化数据源 concept media
结构化数据 <org> 关系数据库 concept media
非结构化数据 subclass of XML concept org
模式层 subclass of 知识图谱 concept media
结构化知识库 subclass of 知识图谱 concept misc
比尔盖茨 employer 微软 per org
5 信息抽取 facet of 数据层 media concept
信息抽取 part of 知识图谱 concept media
实体识别 subclass of 信息抽取 concept media
实体分类体系 part of 112种实体类别 concept misc
分类研究 实体类别 面向开放域的实体识别 concept media
服务器日志 特征建模 搜索引擎 concept org
关系抽取 subclass of Relation Extraction concept unk
模式匹配 实体 语料 concept media
属性抽取 <misc> 统计机器学习 concept media
属性 subclass of 实体 concept misc
数据挖掘 subclass of 结构化数据 concept media
拼图碎片 非结构化 信息抽取 concept media
歧义 used by 实体消歧 concept media
共指消解 自然语言处理 信息检索 concept misc
外部知识库 结构化数据 知识图谱 concept media
数据层的融合 模式层 关系数据库 concept media
资源描述框架 <media> 本体构建本体 concept org
DB2RDF subclass of 结构化的历史数据 cel date
自动化本体构建过程 本体库 数据驱动的自动化方式 concept media
阿里 owned by 阿里巴巴 org media
上下位关系 阿里巴巴 图谱 concept media
腾讯 owned by 阿里巴巴 org concept
知识图谱 location 城市 concept loc
串联 规则 推理策略的一环 concept media
算法 part of 知识库 concept media
知识库的更新 subclass of 概念层 concept media
知识图谱 part of 数据层 concept media
总结 part of 知识图谱 concept media
知识图谱 移动个人助理(Siri 智能语义搜索 concept media
(Sri) subclass of 的知识 eve unk
病毒 知识图谱 埃博拉病毒的症状有哪些 concept media
症状 part of 三元组 concept misc

View File

@ -0,0 +1,41 @@
谷歌 industry 搜索引擎 org concept
知识图谱 Web 3.0 万维网 concept media
Web is a list of 网的 unk time
自顶向下 百科类网站 结构化数据源 concept media
结构化数据 <org> 关系数据库 concept media
非结构化数据 subclass of XML concept org
模式层 subclass of 知识图谱 concept media
结构化知识库 subclass of 知识图谱 concept misc
比尔盖茨 employer 微软 per org
5 信息抽取 facet of 数据层 media concept
信息抽取 part of 知识图谱 concept media
实体识别 subclass of 信息抽取 concept media
实体分类体系 part of 112种实体类别 concept misc
分类研究 实体类别 面向开放域的实体识别 concept media
服务器日志 特征建模 搜索引擎 concept org
关系抽取 subclass of Relation Extraction concept unk
模式匹配 实体 语料 concept media
属性抽取 <misc> 统计机器学习 concept media
属性 subclass of 实体 concept misc
数据挖掘 subclass of 结构化数据 concept media
拼图碎片 非结构化 信息抽取 concept media
歧义 used by 实体消歧 concept media
共指消解 自然语言处理 信息检索 concept misc
外部知识库 结构化数据 知识图谱 concept media
数据层的融合 模式层 关系数据库 concept media
资源描述框架 <media> 本体构建本体 concept org
DB2RDF subclass of 结构化的历史数据 cel date
自动化本体构建过程 本体库 数据驱动的自动化方式 concept media
阿里 owned by 阿里巴巴 org media
上下位关系 阿里巴巴 图谱 concept media
腾讯 owned by 阿里巴巴 org concept
知识图谱 location 城市 concept loc
串联 规则 推理策略的一环 concept media
算法 part of 知识库 concept media
知识库的更新 subclass of 概念层 concept media
知识图谱 part of 数据层 concept media
总结 part of 知识图谱 concept media
知识图谱 移动个人助理(Siri 智能语义搜索 concept media
(Sri) subclass of 的知识 eve unk
病毒 知识图谱 埃博拉病毒的症状有哪些 concept media
症状 part of 三元组 concept misc

35
wwwroot/add.ui Normal file
View File

@ -0,0 +1,35 @@
{
"widgettype": "Form",
"options": {
"height": "70%",
"title": "向知识库添加文件",
"description": "支持的文件类型:.txt, .csv, .xlsx, .docx, .pptx, .pdf",
"method": "POST",
"submit_url": "{{entire_url('api/add')}}",
"fields": [
{
"name": "file_path",
"uitype": "file",
"label": "选择文件",
"required": true,
"description": "支持格式txt, csv, xlsx, docx, pptx, pdf"
},
{
"name": "userid",
"uitype": "str",
"label": "用户 ID",
"value": "user1",
"required": true,
"description": "请输入用户 ID不超过 100 字符)"
},
{
"name": "db_type",
"uitype": "str",
"label": "知识库名称",
"required": true,
"value": "textdb",
"description": "请输入知识库名称(如 textdb, pptdb"
}
]
}
}
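A minimal sketch of submitting this form programmatically (illustrative only: the field names come from the form definition above, the port is taken from script/rag.nginx, and any required authentication headers are omitted):

```python
import requests

# Hypothetical client call against the /api/add endpoint targeted by this form.
with open("/path/to/test.docx", "rb") as f:
    resp = requests.post(
        "http://localhost:10098/api/add",
        files={"file_path": f},                         # the form's file field
        data={"userid": "user1", "db_type": "textdb"},  # defaults shown in the form
    )
print(resp.status_code, resp.text)
```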

Some files were not shown because too many files have changed in this diff.