rag-pipeline/pipeline.py
yumoqing 4aaeb42035 feat: initial rag-pipeline service - pluggable RAG with KG support
- CLIP embedding (9086) + Milvus VDB (8886) + NetworkX graph (9092)
- BGE-Reranker (9090) for result reranking
- Hybrid retrieval: vector search + graph expansion + RRF fusion
- API: /api/ingest, /api/search, /api/pipelines, /api/plugins, /api/status
- Two pipelines: kg-rag-standard (full) and kg-rag-lite (vector only)
- Tested E2E: ingest + search with rerank_score=0.99
2026-06-15 20:42:33 +08:00

271 lines
8.8 KiB
Python

# -*- coding:utf-8 -*-
"""RAG Pipeline orchestration - ingest and search workflows."""
import json
import time
import uuid
from typing import Dict, List, Optional
from traceback import format_exc
from plugins.registry import get_pipeline, call_plugin
from core.chunker import chunk_document
from core.extractor import extract_entities_relations
from core.retriever import hybrid_retrieve, rerank_results
def _get_embedding(texts: List[str], embedding_plugin: Dict) -> List[List[float]]:
"""Get text embeddings from CLIP service."""
endpoint = embedding_plugin.get("endpoint")
if not endpoint:
return []
result = call_plugin(endpoint, "/api/text", {"texts": texts})
if "error" in result:
return []
return result.get("embeddings", [])
def _ensure_collection(vdb_plugin: Dict, collection: str, dim: int = 1024) -> Dict:
"""Ensure VDB collection exists with correct schema."""
endpoint = vdb_plugin.get("endpoint")
if not endpoint:
return {"error": "VDB endpoint not configured"}
# Check if collection exists
list_result = call_plugin(endpoint, "/v1/listcollections", {})
existing = list_result.get("collections", [])
if collection in existing:
return {"status": "exists"}
# Create with proper field schema
fields = [
{"name": "id", "type": "str", "is_primary": True, "max_length": 256},
{"name": "text", "type": "str", "max_length": 65535},
{"name": "description", "type": "str", "max_length": 65535},
{"name": "type", "type": "str", "max_length": 128},
{"name": "embedding", "type": "fvector", "dim": dim}
]
result = call_plugin(endpoint, "/v1/createcollection", {
"colname": collection,
"fields": fields,
"description": "RAG knowledge collection",
"metric": "COSINE"
})
return result
def _store_in_vdb(items: List[Dict], embeddings: List[List[float]],
vdb_plugin: Dict, collection: str = "knowledge") -> Dict:
"""Store items with embeddings in VDB."""
endpoint = vdb_plugin.get("endpoint")
if not endpoint:
return {"error": "VDB endpoint not configured"}
# Ensure collection exists
ensure_result = _ensure_collection(vdb_plugin, collection)
if "error" in ensure_result:
return ensure_result
# Prepare batch insert - data must match schema fields
rows = []
for item, emb in zip(items, embeddings):
row = {
"id": item.get("id", str(uuid.uuid4())[:16]),
"text": item.get("text", ""),
"description": item.get("description", ""),
"type": item.get("type", "chunk"),
"embedding": emb
}
rows.append(row)
result = call_plugin(endpoint, "/v1/batchinsert", {
"colname": collection,
"data": rows
})
return result
def _store_in_graph(entities: List[Dict], relations: List[Dict],
graph_plugin: Dict, graph_name: str = "knowledge") -> Dict:
"""Store entities and relations in graph."""
endpoint = graph_plugin.get("endpoint")
if not endpoint or graph_plugin.get("type") == "none":
return {"status": "skipped", "reason": "graph disabled"}
added_nodes = 0
added_edges = 0
for entity in entities:
name = entity.get("name", "")
node_id = f"entity_{name}".replace(" ", "_")
result = call_plugin(endpoint, "/api/graph/add_node", {
"graph": graph_name,
"node_id": node_id,
"attrs": {
"name": name,
"type": entity.get("type", "unknown"),
"description": entity.get("description", "")
}
})
if "error" not in result:
added_nodes += 1
for rel in relations:
source = f"entity_{rel.get('source', '')}".replace(" ", "_")
target = f"entity_{rel.get('target', '')}".replace(" ", "_")
result = call_plugin(endpoint, "/api/graph/add_edge", {
"graph": graph_name,
"source": source,
"target": target,
"attrs": {
"relation": rel.get("relation", "related_to"),
"description": rel.get("description", "")
}
})
if "error" not in result:
added_edges += 1
# Save graph to disk
call_plugin(endpoint, "/api/graph/save", {"graph": graph_name})
return {"status": "ok", "nodes_added": added_nodes, "edges_added": added_edges}
def ingest(document: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           llm_func=None, max_extract_chars: Optional[int] = 3000) -> Dict:
    """Run the ingest pipeline: chunk -> embed -> store vectors -> extract -> store graph.

    Args:
        document: Raw document text.
        pipeline_name: Pipeline whose ``plugins`` config drives each step.
        collection: VDB collection that receives the chunks.
        graph_name: Graph that receives extracted entities/relations.
        llm_func: Callable used for entity/relation extraction. Extraction and
            graph storage are skipped when it is falsy or when the extractor
            type is ``"none"``.
        max_extract_chars: Only this many leading characters of the document
            are sent to extraction. Pass ``None`` to send the whole document.

    Returns:
        ``{"error": ...}`` (with a ``chunks`` count) for an empty document,
        no chunks, or failed embedding. ``{"error": ...}`` for an unknown
        pipeline. Otherwise a summary dict whose ``status`` is ``"FAILED"``
        if the VDB store reported an error and ``"SUCCEEDED"`` otherwise.
    """
    start_time = time.time()
    if not document or not document.strip():
        return {"error": "Empty document", "chunks": 0}
    pipeline = get_pipeline(pipeline_name)
    if not pipeline:
        return {"error": f"Unknown pipeline: {pipeline_name}"}
    plugins = pipeline.get("plugins", {})

    # Step 1: Chunk
    chunker_config = plugins.get("chunker", {})
    chunks = chunk_document(document,
                            strategy=chunker_config.get("type", "recursive"),
                            chunk_size=chunker_config.get("chunk_size", 512),
                            overlap=chunker_config.get("overlap", 64))
    if not chunks:
        return {"error": "Document produced no chunks", "chunks": 0}
    chunk_texts = [c["text"] for c in chunks]

    # Step 2: Embed
    embedding_plugin = plugins.get("embedding", {})
    embeddings = _get_embedding(chunk_texts, embedding_plugin)
    if not embeddings:
        return {"error": "Failed to get embeddings", "chunks": len(chunks)}

    # Step 3: Store chunks + embeddings in VDB
    vdb_plugin = plugins.get("vdb", {})
    vdb_result = _store_in_vdb(chunks, embeddings, vdb_plugin, collection)

    # Step 4: Extract entities and relations
    extractor_config = plugins.get("extractor", {})
    entities: List[Dict] = []
    relations: List[Dict] = []
    if extractor_config.get("type") != "none" and llm_func:
        extraction = extract_entities_relations(
            document[:max_extract_chars], llm_func) or {}
        # .get: a partial extractor result must not crash the whole ingest.
        entities = extraction.get("entities") or []
        relations = extraction.get("relations") or []
        # Step 5: Store in graph
        graph_plugin = plugins.get("graph", {})
        graph_result = _store_in_graph(entities, relations, graph_plugin, graph_name)
    else:
        graph_result = {"status": "skipped"}

    # If the vector store failed, the chunks are not searchable, so the ingest
    # must not be reported as successful.
    status = "FAILED" if "error" in vdb_result else "SUCCEEDED"
    elapsed = round(time.time() - start_time, 3)
    return {
        "status": status,
        "pipeline": pipeline_name,
        "chunks": len(chunks),
        "embeddings": len(embeddings),
        "entities": len(entities),
        "relations": len(relations),
        "vdb_result": vdb_result,
        "graph_result": graph_result,
        "elapsed": elapsed
    }
def search(query: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           top_k: int = 5, llm_func=None) -> Dict:
    """Run the search pipeline: embed -> retrieve -> rerank -> (optionally) generate.

    Args:
        query: User question.
        pipeline_name: Pipeline whose ``plugins`` config drives each step.
        collection: VDB collection to search.
        graph_name: Graph used for entity-based expansion.
        top_k: Passed to the reranker. The returned results and the
            generation context are also sliced to this many entries.
        llm_func: Optional callable. When given, it is used to extract query
            entities and to generate ``answer`` from the top results.

    Returns:
        ``{"error": ...}`` for an empty query, an unknown pipeline, or a failed
        query embedding. Otherwise a result dict. If ``llm_func`` raises during
        generation, ``answer`` is None and ``answer_error`` holds the message.
    """
    start_time = time.time()
    if not query or not query.strip():
        return {"error": "Empty query"}
    pipeline = get_pipeline(pipeline_name)
    if not pipeline:
        return {"error": f"Unknown pipeline: {pipeline_name}"}
    plugins = pipeline.get("plugins", {})

    # Step 1: Embed query
    embedding_plugin = plugins.get("embedding", {})
    query_embeddings = _get_embedding([query], embedding_plugin)
    if not query_embeddings:
        return {"error": "Failed to embed query"}
    query_embedding = query_embeddings[0]

    # Step 2: Extract entities from query (for graph expansion)
    extracted_entities = []
    if llm_func:
        quick_extraction = extract_entities_relations(query, llm_func) or {}
        # Skip nameless entities instead of raising KeyError on e["name"].
        extracted_entities = [
            e.get("name") for e in (quick_extraction.get("entities") or [])
            if e.get("name")
        ]

    # Step 3: Hybrid retrieval
    results = hybrid_retrieve(
        query_embedding,
        extracted_entities,
        pipeline,
        collection,
        graph_name
    )

    # Step 4: Rerank
    reranker_plugin = plugins.get("reranker", {})
    ranked_results = rerank_results(query, results, reranker_plugin, top_k)
    top_results = ranked_results[:top_k]

    # Step 5: Generate answer (optional)
    answer = None
    answer_error = None
    if llm_func and top_results:
        context = "\n\n".join(
            r.get("text", "") or r.get("description", "") or str(r)
            for r in top_results
        )
        # Built by concatenation so source indentation can never leak into
        # the prompt text.
        gen_prompt = (
            "基于以下上下文回答问题。如果上下文中没有相关信息,请如实说明。\n"
            "上下文:\n"
            f"{context[:2000]}\n"
            f"问题:{query}\n"
            "请用中文回答:"
        )
        try:
            answer = llm_func(gen_prompt)
        except Exception as exc:
            # Generation is optional, so retrieval results are still returned.
            # A bare except here used to also swallow KeyboardInterrupt and
            # SystemExit.
            answer_error = str(exc)

    elapsed = round(time.time() - start_time, 3)
    return {
        "status": "SUCCEEDED",
        "pipeline": pipeline_name,
        "query": query,
        "results": top_results,
        "total_retrieved": len(results),
        "answer": answer,
        "answer_error": answer_error,
        "extracted_entities": extracted_entities,
        "elapsed": elapsed
    }