rag
This commit is contained in:
parent
40625583c8
commit
35d0eb2234
@ -99,6 +99,7 @@ where a.orgid = b.orgid
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await get_service_params(orgid)
|
||||
debug(f"服务参数是:{service_params}")
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
rollback_context["service_params"] = service_params
|
||||
|
||||
@ -10,7 +10,7 @@ from .ragllm_utils import get_ragllms_by_catelog
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from sqlor.dbpools import DBPools
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from rag.ragapi import docs, get_kdbs, fusedsearch
|
||||
from rag.ragapi import docs, get_kdbs, fusedsearch, textinsert
|
||||
|
||||
async def get_user_kdbs(request):
|
||||
env = request._run_ns
|
||||
@ -40,6 +40,7 @@ def load_rag():
|
||||
rf.register('docs', docs)
|
||||
rf.register('get_kdbs', get_kdbs)
|
||||
rf.register('fusedsearch', fusedsearch)
|
||||
rf.register('textinsert', textinsert)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -131,6 +131,7 @@ class RagOperations:
|
||||
start_milvus = time.time()
|
||||
for i in range(0, len(chunks_data), 10):
|
||||
batch_chunks = chunks_data[i:i + 10]
|
||||
debug(f"传入的数据是:{batch_chunks}")
|
||||
result = await self.api_service.milvus_insert_document(
|
||||
request=request,
|
||||
chunks=batch_chunks,
|
||||
@ -181,6 +182,37 @@ class RagOperations:
|
||||
|
||||
return chunks_data
|
||||
|
||||
async def insert_to_vector_text(self, request,
|
||||
db_type: str, fields: Dict, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
|
||||
"""插入单一纯文本到向量数据库,支持动态 schema"""
|
||||
|
||||
chunk_data = {}
|
||||
debug("准备单一纯文本数据并调用插入端点")
|
||||
start = time.time()
|
||||
for key, value in fields.items():
|
||||
chunk_data[key] = value
|
||||
|
||||
chunks_data = [chunk_data]
|
||||
debug(f"向量库插入传入的数据是:{chunks_data}")
|
||||
|
||||
# 调用 Milvus 插入
|
||||
result = await self.api_service.milvus_insert_document(
|
||||
request=request,
|
||||
chunks=chunks_data,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="milvus/insertdocument",
|
||||
user=userid,
|
||||
db_type=db_type
|
||||
)
|
||||
if result.get("status") != "success":
|
||||
raise ValueError(result.get("message", "Milvus 插入失败"))
|
||||
|
||||
debug(f"成功插入纯文本到集合 {result.get('collection_name')}")
|
||||
timings["textinsert"] = time.time() - start
|
||||
debug(f"插入纯文本耗时: {timings['textinsert']:.2f} 秒")
|
||||
|
||||
return chunks_data
|
||||
|
||||
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
|
||||
userid: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
|
||||
129
rag/ragapi.py
129
rag/ragapi.py
@ -5,6 +5,7 @@ import time
|
||||
import traceback
|
||||
import json
|
||||
import math
|
||||
import uuid
|
||||
from rag.service_opts import get_service_params, sor_get_service_params
|
||||
from rag.rag_operations import RagOperations
|
||||
|
||||
@ -24,13 +25,42 @@ headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"query": "什么是知识抽取。"
|
||||
"fiids":["1"]
|
||||
"query": "什么是知识抽取。",
|
||||
"fiids":["1"],
|
||||
"limit":5
|
||||
}
|
||||
|
||||
3、docs文档
|
||||
path: /v1/docs
|
||||
|
||||
4. 纯文本插入接口:
|
||||
path: /v1/textinsert
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"text": "要插入的纯文本内容",
|
||||
"fiid": "知识库ID",
|
||||
"db_type": "数据库类型(如 textdb)"
|
||||
}
|
||||
response: {
|
||||
"status": "success",
|
||||
"userid": "用户组织ID",
|
||||
"collection_name": "ragdb_{dbtype}",
|
||||
"message": "文本成功嵌入并处理三元组",
|
||||
"status_code": 200,
|
||||
"unique_triples": ["提取的三元组列表"],
|
||||
"timings": {"total": 0.123}
|
||||
}
|
||||
error response: {
|
||||
"status": "error",
|
||||
"userid": "用户组织ID",
|
||||
"collection_name": "ragdb_{dbtype}",
|
||||
"message": "错误信息",
|
||||
"status_code": 400,
|
||||
"timings": {"total": 0.123}
|
||||
}
|
||||
|
||||
4. 添加用户消息到记忆:
|
||||
path: /v1/add_user_messages
|
||||
headers: {
|
||||
@ -90,7 +120,10 @@ async def get_kdbs(request, params_kw, *params, **kw):
|
||||
return result
|
||||
|
||||
async def fusedsearch(request, params_kw, *params):
|
||||
"""融合搜索,调用服务化端点"""
|
||||
"""
|
||||
融合搜索,调用服务化端点
|
||||
|
||||
"""
|
||||
kw = request._run_ns
|
||||
f = kw.get('get_userorgid')
|
||||
orgid = await f()
|
||||
@ -193,6 +226,96 @@ async def fusedsearch(request, params_kw, *params):
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
|
||||
async def textinsert(request, params_kw, *params):
|
||||
kw = request._run_ns
|
||||
f = kw.get('get_userorgid')
|
||||
orgid = await f()
|
||||
debug(f"orgid: {orgid},{f=}")
|
||||
f = kw.get('get_user')
|
||||
userid = await f()
|
||||
text = params_kw.get('text', '')
|
||||
fiid = params_kw.get('fiid')
|
||||
db_type = params_kw.get('db_type')
|
||||
id = str(uuid.uuid4())
|
||||
debug(f"params_kw: {params_kw}")
|
||||
# orgid = "04J6VbxLqB_9RPMcgOv_8"
|
||||
# userid = "04J6VbxLqB_9RPMcgOv_8"
|
||||
# text = text
|
||||
# fiid = fiid
|
||||
# orgid = orgid
|
||||
# db_type = db_type
|
||||
# id = str(uuid.uuid4())
|
||||
# debug(f'Inserting document: text={text},userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}')
|
||||
|
||||
timings = {}
|
||||
start_total = time.time()
|
||||
|
||||
result = {
|
||||
"status": "error",
|
||||
"userid": orgid,
|
||||
"collection_name": "ragdb_{dbtype}",
|
||||
"message": "",
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
try:
|
||||
# 验证必填字段
|
||||
if not orgid or not fiid or not text or not db_type:
|
||||
raise ValueError("orgid、fiid、db_type 和 text 不能为空")
|
||||
if len(orgid) > 32 or len(fiid) > 255:
|
||||
raise ValueError("orgid 或 fiid 的长度超出限制")
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
rag_ops = RagOperations()
|
||||
# 生成嵌入向量
|
||||
embedding = await rag_ops.generate_query_vector(request, text, service_params, userid, timings)
|
||||
|
||||
# 插入 Milvus
|
||||
fields = {
|
||||
"text": text,
|
||||
"fiid": fiid,
|
||||
"orgid": orgid,
|
||||
"vector": embedding,
|
||||
"id": id
|
||||
}
|
||||
chunks_data = await rag_ops.insert_to_vector_text(request, db_type, fields, service_params, userid, timings)
|
||||
|
||||
# 抽取三元组
|
||||
document = Document(page_content=text)
|
||||
chunks = [document]
|
||||
triples = await rag_ops.extract_triples(
|
||||
request, chunks, service_params, userid, timings
|
||||
)
|
||||
|
||||
# 插入 Neo4j
|
||||
await rag_ops.insert_to_graph_db(
|
||||
request, triples, id, fiid, orgid, service_params, userid, timings
|
||||
)
|
||||
|
||||
timings["total"] = time.time() - start_total
|
||||
result.update({
|
||||
"status": "success",
|
||||
"unique_triples": triples,
|
||||
"message": f"文本成功嵌入并处理三元组",
|
||||
"status_code": 200
|
||||
})
|
||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||
|
||||
except Exception as e:
|
||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
timings["total"] = time.time() - start_total
|
||||
result.update({
|
||||
"message": f"插入文档失败: {str(e)}",
|
||||
"timings": timings
|
||||
})
|
||||
raise ValueError(str(e)) from e
|
||||
debug(f"最终结果是:{result}")
|
||||
return result
|
||||
|
||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||
if fiids:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from ahserver.serverenv import get_serverenv
|
||||
from sqlor.dbpools import DBPools
|
||||
from appPublic.log import debug, error, info
|
||||
|
||||
async def sor_get_service_params(sor, orgid):
|
||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||
@ -70,6 +71,7 @@ async def sor_get_service_params(sor, orgid):
|
||||
|
||||
async def get_service_params(orgid):
|
||||
db = DBPools()
|
||||
debug(f"传入的orgid是:{orgid}")
|
||||
dbname = get_serverenv('get_module_dbname')('rag')
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
return await sor_get_service_params(sor, orgid)
|
||||
|
||||
@ -275,7 +275,7 @@ class APIService:
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"chunks": chunks,
|
||||
"dbtype": db_type
|
||||
"db_type": db_type
|
||||
}
|
||||
b = await uapi.call(upappid, apiname, user, params_kw)
|
||||
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档插入服务", request_id)
|
||||
@ -296,7 +296,7 @@ class APIService:
|
||||
"file_path": file_path,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"document_id": document_id,
|
||||
"dbtype": db_type
|
||||
"db_type": db_type
|
||||
}
|
||||
b = await uapi.call(upappid, apiname, user, params_kw)
|
||||
d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档删除服务", request_id)
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
fmgr = RagFileMgr(params_kw.fiid)
|
||||
try:
|
||||
stat = await fmgr.add_file(request, params_kw)
|
||||
return UiMessage(title='Add file', message='file add success')
|
||||
stat = await fmgr.add_file(request, params_kw)
|
||||
if stat.get('status_code') == 200:
|
||||
return UiMessage(title='Add file', message='file add success')
|
||||
else:
|
||||
error_message = stat.get('message', 'Unknown error occurred')
|
||||
return UiError(title='Add file', message=f'file add failed: {error_message}')
|
||||
except Exception as e:
|
||||
return UiError(title='Add file', message=f'file add failed({e})')
|
||||
return UiError(title='Add file', message=f'file add failed({e})')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user