- CLIP embedding (9086) + Milvus VDB (8886) + NetworkX graph (9092) - BGE-Reranker (9090) for result reranking - Hybrid retrieval: vector search + graph expansion + RRF fusion - API: /api/ingest, /api/search, /api/pipelines, /api/plugins, /api/status - Two pipelines: kg-rag-standard (full) and kg-rag-lite (vector only) - Tested E2E: ingest + search with rerank_score=0.99
271 lines
8.8 KiB
Python
271 lines
8.8 KiB
Python
# -*- coding:utf-8 -*-
|
|
"""RAG Pipeline orchestration - ingest and search workflows."""
|
|
import json
|
|
import time
|
|
import uuid
|
|
from typing import Dict, List, Optional
|
|
from traceback import format_exc
|
|
|
|
from plugins.registry import get_pipeline, call_plugin
|
|
from core.chunker import chunk_document
|
|
from core.extractor import extract_entities_relations
|
|
from core.retriever import hybrid_retrieve, rerank_results
|
|
|
|
|
|
def _get_embedding(texts: List[str], embedding_plugin: Dict) -> List[List[float]]:
|
|
"""Get text embeddings from CLIP service."""
|
|
endpoint = embedding_plugin.get("endpoint")
|
|
if not endpoint:
|
|
return []
|
|
|
|
result = call_plugin(endpoint, "/api/text", {"texts": texts})
|
|
if "error" in result:
|
|
return []
|
|
|
|
return result.get("embeddings", [])
|
|
|
|
|
|
def _ensure_collection(vdb_plugin: Dict, collection: str, dim: int = 1024) -> Dict:
|
|
"""Ensure VDB collection exists with correct schema."""
|
|
endpoint = vdb_plugin.get("endpoint")
|
|
if not endpoint:
|
|
return {"error": "VDB endpoint not configured"}
|
|
|
|
# Check if collection exists
|
|
list_result = call_plugin(endpoint, "/v1/listcollections", {})
|
|
existing = list_result.get("collections", [])
|
|
|
|
if collection in existing:
|
|
return {"status": "exists"}
|
|
|
|
# Create with proper field schema
|
|
fields = [
|
|
{"name": "id", "type": "str", "is_primary": True, "max_length": 256},
|
|
{"name": "text", "type": "str", "max_length": 65535},
|
|
{"name": "description", "type": "str", "max_length": 65535},
|
|
{"name": "type", "type": "str", "max_length": 128},
|
|
{"name": "embedding", "type": "fvector", "dim": dim}
|
|
]
|
|
|
|
result = call_plugin(endpoint, "/v1/createcollection", {
|
|
"colname": collection,
|
|
"fields": fields,
|
|
"description": "RAG knowledge collection",
|
|
"metric": "COSINE"
|
|
})
|
|
|
|
return result
|
|
|
|
|
|
def _store_in_vdb(items: List[Dict], embeddings: List[List[float]],
|
|
vdb_plugin: Dict, collection: str = "knowledge") -> Dict:
|
|
"""Store items with embeddings in VDB."""
|
|
endpoint = vdb_plugin.get("endpoint")
|
|
if not endpoint:
|
|
return {"error": "VDB endpoint not configured"}
|
|
|
|
# Ensure collection exists
|
|
ensure_result = _ensure_collection(vdb_plugin, collection)
|
|
if "error" in ensure_result:
|
|
return ensure_result
|
|
|
|
# Prepare batch insert - data must match schema fields
|
|
rows = []
|
|
for item, emb in zip(items, embeddings):
|
|
row = {
|
|
"id": item.get("id", str(uuid.uuid4())[:16]),
|
|
"text": item.get("text", ""),
|
|
"description": item.get("description", ""),
|
|
"type": item.get("type", "chunk"),
|
|
"embedding": emb
|
|
}
|
|
rows.append(row)
|
|
|
|
result = call_plugin(endpoint, "/v1/batchinsert", {
|
|
"colname": collection,
|
|
"data": rows
|
|
})
|
|
|
|
return result
|
|
|
|
|
|
def _store_in_graph(entities: List[Dict], relations: List[Dict],
|
|
graph_plugin: Dict, graph_name: str = "knowledge") -> Dict:
|
|
"""Store entities and relations in graph."""
|
|
endpoint = graph_plugin.get("endpoint")
|
|
if not endpoint or graph_plugin.get("type") == "none":
|
|
return {"status": "skipped", "reason": "graph disabled"}
|
|
|
|
added_nodes = 0
|
|
added_edges = 0
|
|
|
|
for entity in entities:
|
|
name = entity.get("name", "")
|
|
node_id = f"entity_{name}".replace(" ", "_")
|
|
result = call_plugin(endpoint, "/api/graph/add_node", {
|
|
"graph": graph_name,
|
|
"node_id": node_id,
|
|
"attrs": {
|
|
"name": name,
|
|
"type": entity.get("type", "unknown"),
|
|
"description": entity.get("description", "")
|
|
}
|
|
})
|
|
if "error" not in result:
|
|
added_nodes += 1
|
|
|
|
for rel in relations:
|
|
source = f"entity_{rel.get('source', '')}".replace(" ", "_")
|
|
target = f"entity_{rel.get('target', '')}".replace(" ", "_")
|
|
result = call_plugin(endpoint, "/api/graph/add_edge", {
|
|
"graph": graph_name,
|
|
"source": source,
|
|
"target": target,
|
|
"attrs": {
|
|
"relation": rel.get("relation", "related_to"),
|
|
"description": rel.get("description", "")
|
|
}
|
|
})
|
|
if "error" not in result:
|
|
added_edges += 1
|
|
|
|
# Save graph to disk
|
|
call_plugin(endpoint, "/api/graph/save", {"graph": graph_name})
|
|
|
|
return {"status": "ok", "nodes_added": added_nodes, "edges_added": added_edges}
|
|
|
|
|
|
def ingest(document: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           llm_func=None) -> Dict:
    """Full ingest pipeline: chunk -> embed -> extract -> store.

    Args:
        document: Raw text to ingest.
        pipeline_name: Name of the pipeline configuration to load.
        collection: VDB collection that receives the chunks.
        graph_name: Graph that receives extracted entities and relations.
        llm_func: Optional callable passed to the entity/relation extractor;
            without it the extraction and graph steps are skipped.

    Returns:
        ``{"error": ..., "chunks": n}`` when the document is empty or no
        embeddings could be obtained. Otherwise a summary dict with
        ``status``, ``pipeline``, ``chunks``, ``embeddings``, ``entities``,
        ``relations``, ``vdb_result``, ``graph_result`` and ``elapsed``.
        ``status`` is ``"FAILED"`` and an ``error`` key is added when the VDB
        store step reported an error.
    """
    # An empty document yields nothing to chunk or embed; fail fast with a
    # precise message rather than a misleading embedding error.
    if not isinstance(document, str) or not document.strip():
        return {"error": "Empty document", "chunks": 0}

    start_time = time.time()

    pipeline = get_pipeline(pipeline_name)
    plugins = pipeline.get("plugins", {})

    # Step 1: Chunk
    chunker_config = plugins.get("chunker", {})
    chunks = chunk_document(
        document,
        strategy=chunker_config.get("type", "recursive"),
        chunk_size=chunker_config.get("chunk_size", 512),
        overlap=chunker_config.get("overlap", 64),
    )
    chunk_texts = [c["text"] for c in chunks]

    # Step 2: Embed
    embeddings = _get_embedding(chunk_texts, plugins.get("embedding", {}))
    if not embeddings:
        return {"error": "Failed to get embeddings", "chunks": len(chunks)}

    # Step 3: Store chunks + embeddings in VDB
    vdb_result = _store_in_vdb(chunks, embeddings, plugins.get("vdb", {}), collection)
    vdb_failed = isinstance(vdb_result, dict) and "error" in vdb_result

    # Step 4: Extract entities and relations
    extractor_config = plugins.get("extractor", {})
    extraction = {"entities": [], "relations": []}

    if extractor_config.get("type") != "none" and llm_func:
        # Only the first 3000 characters are handed to the extractor.
        extraction = extract_entities_relations(document[:3000], llm_func)

        # Step 5: Store in graph. .get() keeps this tolerant of an extraction
        # result that lacks one of the keys, matching the summary below.
        graph_result = _store_in_graph(
            extraction.get("entities", []),
            extraction.get("relations", []),
            plugins.get("graph", {}),
            graph_name,
        )
    else:
        graph_result = {"status": "skipped"}

    elapsed = round(time.time() - start_time, 3)

    summary = {
        # A failed VDB write must not be reported as a successful ingest.
        "status": "FAILED" if vdb_failed else "SUCCEEDED",
        "pipeline": pipeline_name,
        "chunks": len(chunks),
        "embeddings": len(embeddings),
        "entities": len(extraction.get("entities", [])),
        "relations": len(extraction.get("relations", [])),
        "vdb_result": vdb_result,
        "graph_result": graph_result,
        "elapsed": elapsed,
    }
    if vdb_failed:
        summary["error"] = "Failed to store chunks in VDB"
    return summary
|
|
|
|
|
|
def search(query: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           top_k: int = 5, llm_func=None) -> Dict:
    """Full search pipeline: embed -> retrieve -> rerank -> generate.

    Args:
        query: User question.
        pipeline_name: Name of the pipeline configuration to load.
        collection: VDB collection to search.
        graph_name: Graph used for entity-based expansion.
        top_k: Number of reranked results to return and to use as context.
        llm_func: Optional callable taking a prompt string. When given it is
            used for query entity extraction and for answer generation.

    Returns:
        ``{"error": ...}`` when the query is empty or cannot be embedded.
        Otherwise a dict with ``status``, ``pipeline``, ``query``,
        ``results``, ``total_retrieved``, ``answer`` (``None`` when no answer
        was generated), ``extracted_entities`` and ``elapsed``.
    """
    # An empty query cannot be embedded or answered meaningfully.
    if not isinstance(query, str) or not query.strip():
        return {"error": "Empty query"}

    start_time = time.time()

    pipeline = get_pipeline(pipeline_name)
    plugins = pipeline.get("plugins", {})

    # Step 1: Embed query
    query_embeddings = _get_embedding([query], plugins.get("embedding", {}))
    if not query_embeddings:
        return {"error": "Failed to embed query"}
    query_embedding = query_embeddings[0]

    # Step 2: Extract entities from query (for graph expansion)
    extracted_entities = []
    if llm_func:
        quick_extraction = extract_entities_relations(query, llm_func)
        # Entities without a usable name are skipped instead of raising.
        extracted_entities = [
            e["name"]
            for e in quick_extraction.get("entities", [])
            if e.get("name")
        ]

    # Step 3: Hybrid retrieval
    results = hybrid_retrieve(
        query_embedding,
        extracted_entities,
        pipeline,
        collection,
        graph_name,
    )

    # Step 4: Rerank
    ranked_results = rerank_results(query, results, plugins.get("reranker", {}), top_k)

    # Step 5: Generate answer (optional)
    answer = None
    if llm_func and ranked_results:
        context = "\n\n".join(
            r.get("text", "") or r.get("description", "") or str(r)
            for r in ranked_results[:top_k]
        )

        # The prompt is in Chinese: answer from the context below, say so if
        # the context has no relevant information, and reply in Chinese.
        # The context is capped at 2000 characters.
        gen_prompt = f"""基于以下上下文回答问题。如果上下文中没有相关信息,请如实说明。

上下文:
{context[:2000]}

问题:{query}

请用中文回答:"""

        try:
            answer = llm_func(gen_prompt)
        except Exception:
            # Answer generation is best-effort: retrieval results are still
            # returned. Unlike a bare except, KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            answer = None

    elapsed = round(time.time() - start_time, 3)

    return {
        "status": "SUCCEEDED",
        "pipeline": pipeline_name,
        "query": query,
        "results": ranked_results[:top_k],
        "total_retrieved": len(results),
        "answer": answer,
        "extracted_entities": extracted_entities,
        "elapsed": elapsed,
    }
|