rag
This commit is contained in:
parent
40625583c8
commit
35d0eb2234
@ -99,6 +99,7 @@ where a.orgid = b.orgid
|
|||||||
|
|
||||||
# 获取服务参数
|
# 获取服务参数
|
||||||
service_params = await get_service_params(orgid)
|
service_params = await get_service_params(orgid)
|
||||||
|
debug(f"服务参数是:{service_params}")
|
||||||
if not service_params:
|
if not service_params:
|
||||||
raise ValueError("无法获取服务参数")
|
raise ValueError("无法获取服务参数")
|
||||||
rollback_context["service_params"] = service_params
|
rollback_context["service_params"] = service_params
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from .ragllm_utils import get_ragllms_by_catelog
|
|||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
from sqlor.dbpools import DBPools
|
from sqlor.dbpools import DBPools
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
from rag.ragapi import docs, get_kdbs, fusedsearch
|
from rag.ragapi import docs, get_kdbs, fusedsearch, textinsert
|
||||||
|
|
||||||
async def get_user_kdbs(request):
|
async def get_user_kdbs(request):
|
||||||
env = request._run_ns
|
env = request._run_ns
|
||||||
@ -40,6 +40,7 @@ def load_rag():
|
|||||||
rf.register('docs', docs)
|
rf.register('docs', docs)
|
||||||
rf.register('get_kdbs', get_kdbs)
|
rf.register('get_kdbs', get_kdbs)
|
||||||
rf.register('fusedsearch', fusedsearch)
|
rf.register('fusedsearch', fusedsearch)
|
||||||
|
rf.register('textinsert', textinsert)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -131,6 +131,7 @@ class RagOperations:
|
|||||||
start_milvus = time.time()
|
start_milvus = time.time()
|
||||||
for i in range(0, len(chunks_data), 10):
|
for i in range(0, len(chunks_data), 10):
|
||||||
batch_chunks = chunks_data[i:i + 10]
|
batch_chunks = chunks_data[i:i + 10]
|
||||||
|
debug(f"传入的数据是:{batch_chunks}")
|
||||||
result = await self.api_service.milvus_insert_document(
|
result = await self.api_service.milvus_insert_document(
|
||||||
request=request,
|
request=request,
|
||||||
chunks=batch_chunks,
|
chunks=batch_chunks,
|
||||||
@ -181,6 +182,37 @@ class RagOperations:
|
|||||||
|
|
||||||
return chunks_data
|
return chunks_data
|
||||||
|
|
||||||
|
async def insert_to_vector_text(self, request,
                                db_type: str, fields: Dict, service_params: Dict,
                                userid: str, timings: Dict) -> List[Dict]:
    """Insert one plain-text record into the vector database.

    The entries of ``fields`` are copied into a fresh dict, wrapped in a
    one-element list and sent through
    ``self.api_service.milvus_insert_document``.

    Args:
        request: Request object forwarded unchanged to the API service.
        db_type: Forwarded as the ``db_type`` argument of the insert call.
        fields: Key/value pairs making up the single record to insert.
        service_params: Service parameters; ``service_params['vdb']`` is
            passed as ``upappid``.
        userid: Passed as the ``user`` argument of the insert call.
        timings: Mutable mapping; the elapsed seconds are stored under
            ``"textinsert"``.

    Returns:
        The one-element list that was sent to the insert endpoint.

    Raises:
        ValueError: If the service result's ``status`` is not ``"success"``.
    """
    debug("准备单一纯文本数据并调用插入端点")
    began = time.time()

    # Shallow-copy the caller's fields so the payload is a separate dict.
    payload = [{name: content for name, content in fields.items()}]
    debug(f"向量库插入传入的数据是:{payload}")

    # Call the Milvus insert endpoint.
    response = await self.api_service.milvus_insert_document(
        request=request,
        chunks=payload,
        upappid=service_params['vdb'],
        apiname="milvus/insertdocument",
        user=userid,
        db_type=db_type
    )
    succeeded = response.get("status") == "success"
    if not succeeded:
        raise ValueError(response.get("message", "Milvus 插入失败"))

    debug(f"成功插入纯文本到集合 {response.get('collection_name')}")
    elapsed = time.time() - began
    timings["textinsert"] = elapsed
    debug(f"插入纯文本耗时: {elapsed:.2f} 秒")

    return payload
|
||||||
|
|
||||||
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
|
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
|
||||||
userid: str, timings: Dict,
|
userid: str, timings: Dict,
|
||||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||||
|
|||||||
129
rag/ragapi.py
129
rag/ragapi.py
@ -5,6 +5,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
import uuid
|
||||||
from rag.service_opts import get_service_params, sor_get_service_params
|
from rag.service_opts import get_service_params, sor_get_service_params
|
||||||
from rag.rag_operations import RagOperations
|
from rag.rag_operations import RagOperations
|
||||||
|
|
||||||
@ -24,13 +25,42 @@ headers: {
|
|||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
data: {
|
data: {
|
||||||
"query": "什么是知识抽取。"
|
"query": "什么是知识抽取。",
|
||||||
"fiids":["1"]
|
"fiids":["1"],
|
||||||
|
"limit":5
|
||||||
}
|
}
|
||||||
|
|
||||||
3、docs文档
|
3、docs文档
|
||||||
path: /v1/docs
|
path: /v1/docs
|
||||||
|
|
||||||
|
4. 纯文本插入接口:
|
||||||
|
path: /v1/textinsert
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
data: {
|
||||||
|
"text": "要插入的纯文本内容",
|
||||||
|
"fiid": "知识库ID",
|
||||||
|
"db_type": "数据库类型(如 textdb)"
|
||||||
|
}
|
||||||
|
response: {
|
||||||
|
"status": "success",
|
||||||
|
"userid": "用户组织ID",
|
||||||
|
"collection_name": "ragdb_{dbtype}",
|
||||||
|
"message": "文本成功嵌入并处理三元组",
|
||||||
|
"status_code": 200,
|
||||||
|
"unique_triples": ["提取的三元组列表"],
|
||||||
|
"timings": {"total": 0.123}
|
||||||
|
}
|
||||||
|
error response: {
|
||||||
|
"status": "error",
|
||||||
|
"userid": "用户组织ID",
|
||||||
|
"collection_name": "ragdb_{dbtype}",
|
||||||
|
"message": "错误信息",
|
||||||
|
"status_code": 400,
|
||||||
|
"timings": {"total": 0.123}
|
||||||
|
}
|
||||||
|
|
||||||
4. 添加用户消息到记忆:
|
4. 添加用户消息到记忆:
|
||||||
path: /v1/add_user_messages
|
path: /v1/add_user_messages
|
||||||
headers: {
|
headers: {
|
||||||
@ -90,7 +120,10 @@ async def get_kdbs(request, params_kw, *params, **kw):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def fusedsearch(request, params_kw, *params):
|
async def fusedsearch(request, params_kw, *params):
|
||||||
"""融合搜索,调用服务化端点"""
|
"""
|
||||||
|
融合搜索,调用服务化端点
|
||||||
|
|
||||||
|
"""
|
||||||
kw = request._run_ns
|
kw = request._run_ns
|
||||||
f = kw.get('get_userorgid')
|
f = kw.get('get_userorgid')
|
||||||
orgid = await f()
|
orgid = await f()
|
||||||
@ -193,6 +226,96 @@ async def fusedsearch(request, params_kw, *params):
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
|
||||||
|
async def textinsert(request, params_kw, *params):
    """Embed one plain-text snippet, store it in the vector DB and index its triples.

    ``text``, ``fiid`` and ``db_type`` are read from ``params_kw``. The
    organisation id and user id come from the ``get_userorgid`` and
    ``get_user`` callables found in ``request._run_ns``.

    Steps: validate input, load service parameters, generate the embedding,
    insert the record into the vector database, extract triples from the
    text, then insert the triples into the graph database.

    Args:
        request: Request object; must expose ``_run_ns``.
        params_kw: Mapping with the keys ``text``, ``fiid`` and ``db_type``.
        *params: Unused.

    Returns:
        dict with ``status`` ("success"), ``userid`` (the organisation id),
        ``collection_name``, ``message``, ``status_code`` (200),
        ``unique_triples`` and ``timings``.

    Raises:
        ValueError: On any failure (validation or a downstream step). The
            original exception is chained as ``__cause__``. Failures are
            raised, not returned as an error dict.
    """
    env = request._run_ns
    get_orgid = env.get('get_userorgid')
    orgid = await get_orgid()
    # Emits the same text as the former ``{f=}`` debug expression.
    debug(f"orgid: {orgid},f={get_orgid!r}")
    get_user = env.get('get_user')
    userid = await get_user()
    text = params_kw.get('text', '')
    fiid = params_kw.get('fiid')
    db_type = params_kw.get('db_type')
    # Named doc_id so the builtin ``id`` is not shadowed.
    doc_id = str(uuid.uuid4())
    debug(f"params_kw: {params_kw}")

    timings = {}
    start_total = time.time()

    try:
        # Validate required fields.
        if not orgid or not fiid or not text or not db_type:
            raise ValueError("orgid、fiid、db_type 和 text 不能为空")
        if len(orgid) > 32 or len(fiid) > 255:
            raise ValueError("orgid 或 fiid 的长度超出限制")

        # Load service parameters for this organisation.
        service_params = await get_service_params(orgid)
        if not service_params:
            raise ValueError("无法获取服务参数")
        rag_ops = RagOperations()

        # Generate the embedding vector for the text.
        embedding = await rag_ops.generate_query_vector(
            request, text, service_params, userid, timings
        )

        # Insert the record into the vector database.
        fields = {
            "text": text,
            "fiid": fiid,
            "orgid": orgid,
            "vector": embedding,
            "id": doc_id,
        }
        await rag_ops.insert_to_vector_text(
            request, db_type, fields, service_params, userid, timings
        )

        # Extract triples, treating the whole text as a single chunk.
        chunks = [Document(page_content=text)]
        triples = await rag_ops.extract_triples(
            request, chunks, service_params, userid, timings
        )

        # Insert the triples into the graph database.
        await rag_ops.insert_to_graph_db(
            request, triples, doc_id, fiid, orgid, service_params, userid, timings
        )
    except Exception as e:
        error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        # Callers receive a ValueError; no error dict is built because it
        # could never be returned past this raise.
        raise ValueError(str(e)) from e

    timings["total"] = time.time() - start_total
    debug(f"总耗时: {timings['total']:.2f} 秒")

    result = {
        "status": "success",
        "userid": orgid,
        # Previously the literal text "ragdb_{dbtype}" (placeholder never
        # substituted). NOTE(review): assumes the collection is named
        # ragdb_<db_type>; confirm against the vector service.
        "collection_name": f"ragdb_{db_type}",
        "message": "文本成功嵌入并处理三元组",
        "status_code": 200,
        "unique_triples": triples,
        # Included so the success payload carries the same timing data the
        # error branch used to attach.
        "timings": timings,
    }
    debug(f"最终结果是:{result}")
    return result
|
||||||
|
|
||||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||||
if fiids:
|
if fiids:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
from sqlor.dbpools import DBPools
|
from sqlor.dbpools import DBPools
|
||||||
|
from appPublic.log import debug, error, info
|
||||||
|
|
||||||
async def sor_get_service_params(sor, orgid):
|
async def sor_get_service_params(sor, orgid):
|
||||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||||
@ -70,6 +71,7 @@ async def sor_get_service_params(sor, orgid):
|
|||||||
|
|
||||||
async def get_service_params(orgid):
    """Return the service parameters for ``orgid``.

    Opens a sqlor context on the database name resolved for the ``'rag'``
    module and delegates the lookup to ``sor_get_service_params``.
    """
    pools = DBPools()
    debug(f"传入的orgid是:{orgid}")
    resolve_dbname = get_serverenv('get_module_dbname')
    rag_dbname = resolve_dbname('rag')
    async with pools.sqlorContext(rag_dbname) as sor:
        params = await sor_get_service_params(sor, orgid)
    return params
|
||||||
|
|||||||
@ -275,7 +275,7 @@ class APIService:
|
|||||||
uapi = UAPI(request, DictObject(**globals()))
|
uapi = UAPI(request, DictObject(**globals()))
|
||||||
params_kw = {
|
params_kw = {
|
||||||
"chunks": chunks,
|
"chunks": chunks,
|
||||||
"dbtype": db_type
|
"db_type": db_type
|
||||||
}
|
}
|
||||||
b = await uapi.call(upappid, apiname, user, params_kw)
|
b = await uapi.call(upappid, apiname, user, params_kw)
|
||||||
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档插入服务", request_id)
|
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档插入服务", request_id)
|
||||||
@ -296,7 +296,7 @@ class APIService:
|
|||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"knowledge_base_id": knowledge_base_id,
|
"knowledge_base_id": knowledge_base_id,
|
||||||
"document_id": document_id,
|
"document_id": document_id,
|
||||||
"dbtype": db_type
|
"db_type": db_type
|
||||||
}
|
}
|
||||||
b = await uapi.call(upappid, apiname, user, params_kw)
|
b = await uapi.call(upappid, apiname, user, params_kw)
|
||||||
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档删除服务", request_id)
|
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档删除服务", request_id)
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
fmgr = RagFileMgr(params_kw.fiid)
|
fmgr = RagFileMgr(params_kw.fiid)
|
||||||
try:
|
try:
|
||||||
stat = await fmgr.add_file(request, params_kw)
|
stat = await fmgr.add_file(request, params_kw)
|
||||||
|
if stat.get('status_code') == 200:
|
||||||
return UiMessage(title='Add file', message='file add success')
|
return UiMessage(title='Add file', message='file add success')
|
||||||
|
else:
|
||||||
|
error_message = stat.get('message', 'Unknown error occurred')
|
||||||
|
return UiError(title='Add file', message=f'file add failed: {error_message}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return UiError(title='Add file', message=f'file add failed({e})')
|
return UiError(title='Add file', message=f'file add failed({e})')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user