rag-pipeline/core/retriever.py
yumoqing 4aaeb42035 feat: initial rag-pipeline service - pluggable RAG with KG support
- CLIP embedding (9086) + Milvus VDB (8886) + NetworkX graph (9092)
- BGE-Reranker (9090) for result reranking
- Hybrid retrieval: vector search + graph expansion + RRF fusion
- API: /api/ingest, /api/search, /api/pipelines, /api/plugins, /api/status
- Two pipelines: kg-rag-standard (full) and kg-rag-lite (vector only)
- Tested E2E: ingest + search with rerank_score=0.99
2026-06-15 20:42:33 +08:00

164 lines
5.5 KiB
Python

# -*- coding:utf-8 -*-
"""Hybrid retrieval: vector search + graph expansion + RRF fusion."""
from typing import List, Dict
def rrf_fusion(results_list: List[List[Dict]], k: int = 60) -> List[Dict]:
    """Merge several ranked result lists with Reciprocal Rank Fusion.

    Every item contributes ``1 / (k + rank + 1)`` (``rank`` is 0-based) for
    each list it appears in.  Contributions are summed per document and the
    documents are returned in descending order of that sum.

    A document is identified by its ``"id"`` value, falling back to
    ``"node_id"``.  Items that carry neither are treated as distinct
    documents: they are never merged with each other, nor with an item whose
    real id happens to look like a rank number.

    Args:
        results_list: Ranked lists of result dicts, best match first.
        k: Smoothing constant of the RRF formula.

    Returns:
        The first-seen dict of every distinct document, ordered by fused
        score (ties keep first-seen order).  Each returned dict is annotated
        in place with ``"rrf_score"``, rounded to 6 decimals.
    """
    fused = {}
    for list_index, results in enumerate(results_list):
        for rank, item in enumerate(results):
            doc_id = item.get("id")
            if doc_id is None:
                doc_id = item.get("node_id")
            if doc_id is None:
                # No usable identifier: key on the item's own position so
                # unrelated id-less items are not fused into one document.
                doc_id = ("__no_id__", list_index, rank)
            entry = fused.setdefault(doc_id, {"item": item, "score": 0.0})
            entry["score"] += 1.0 / (k + rank + 1)

    # sorted() is stable, so equal scores keep their first-seen order.
    ordered = sorted(fused.values(), key=lambda entry: entry["score"], reverse=True)
    for entry in ordered:
        entry["item"]["rrf_score"] = round(entry["score"], 6)
    return [entry["item"] for entry in ordered]
def vector_search(query_embedding: List[float], vdb_plugin: Dict,
                  collection: str = "knowledge", top_k: int = 20) -> List[Dict]:
    """Query the vector database plugin for rows similar to an embedding.

    Args:
        query_embedding: The query vector sent as ``"vector"``.
        vdb_plugin: Plugin config; its ``"endpoint"`` is the service address.
        collection: Collection name sent as ``"colname"``.
        top_k: Page size sent as ``"pagerows"`` (first page only).

    Returns:
        The ``rows`` of a successful reply, or an empty list when no endpoint
        is configured, the call reports an error, the status is not
        ``"SUCCEEDED"``, or the reply's ``data`` is not a dict.
    """
    from plugins.registry import call_plugin

    vdb_url = vdb_plugin.get("endpoint")
    if not vdb_url:
        return []

    payload = {
        "colname": collection,
        "vector": query_embedding,
        "pagerows": top_k,
        "page": 1,
    }
    response = call_plugin(vdb_url, "/v1/query", payload)

    succeeded = "error" not in response and response.get("status") == "SUCCEEDED"
    if not succeeded:
        return []

    # A successful reply looks like {"status": "SUCCEEDED", "data": {"rows": [...]}}
    body = response.get("data", {})
    if not isinstance(body, dict):
        return []
    return body.get("rows", [])
def graph_expand(entity_ids: List[str], graph_plugin: Dict,
                 graph_name: str = "knowledge", hops: int = 2) -> List[Dict]:
    """Collect the graph neighbors of the given entities.

    Args:
        entity_ids: Entity names to expand.
        graph_plugin: Plugin config; needs an ``"endpoint"`` and must not
            have ``"type"`` set to ``"none"``.
        graph_name: Graph name sent as ``"graph"``.
        hops: Traversal depth sent as ``"depth"``.

    Returns:
        Neighbor dicts, de-duplicated by ``"node_id"`` across all entities.
        Each one is annotated in place with ``"source_entity"`` (the entity
        that first reached it) and ``"id"`` (a copy of its ``"node_id"``).
        Entities whose lookup reports an error are skipped.
    """
    from plugins.registry import call_plugin

    graph_url = graph_plugin.get("endpoint")
    if not graph_url or graph_plugin.get("type") == "none":
        return []

    expanded = []
    visited = set()
    for entity in entity_ids:
        # Ingest stores entity nodes as "entity_<name>"; spaces become underscores.
        lookup_id = f"entity_{entity}".replace(" ", "_")
        reply = call_plugin(graph_url, "/api/graph/neighbors", {
            "graph": graph_name,
            "node_id": lookup_id,
            "depth": hops,
        })
        if "error" in reply:
            continue
        for node in reply.get("neighbors", []):
            neighbor_id = node.get("node_id", "")
            if not neighbor_id or neighbor_id in visited:
                continue
            visited.add(neighbor_id)
            node["source_entity"] = entity
            node["id"] = neighbor_id
            expanded.append(node)
    return expanded
def hybrid_retrieve(query_embedding: List[float],
                    extracted_entities: List[str],
                    pipeline: Dict,
                    collection: str = "knowledge",
                    graph_name: str = "knowledge") -> List[Dict]:
    """Run vector search plus optional graph expansion and fuse the results.

    Settings are read from ``pipeline["plugins"]["retriever"]``:
    ``vector_top_k`` (default 20), ``graph_hops`` (default 2) and ``type``
    (default ``"hybrid"``).  Graph expansion runs only when the type is
    ``"hybrid"`` and at least one entity was extracted.

    Returns:
        An empty list when neither source produced results, the single
        non-empty result list unchanged when only one source did, and the
        RRF-fused list when both did.
    """
    plugin_map = pipeline.get("plugins", {})
    settings = plugin_map.get("retriever", {})
    top_k = settings.get("vector_top_k", 20)
    hops = settings.get("graph_hops", 2)
    mode = settings.get("type", "hybrid")

    candidates = []

    # Vector search always runs.
    vector_hits = vector_search(query_embedding, plugin_map.get("vdb", {}),
                                collection, top_k)
    if vector_hits:
        candidates.append(vector_hits)

    # Graph expansion needs hybrid mode and something to expand from.
    if mode == "hybrid" and extracted_entities:
        graph_hits = graph_expand(extracted_entities, plugin_map.get("graph", {}),
                                  graph_name, hops)
        if graph_hits:
            candidates.append(graph_hits)

    if len(candidates) > 1:
        return rrf_fusion(candidates)
    return candidates[0] if candidates else []
def _rerank_document(result: Dict) -> str:
    """Return the text that represents one result in a rerank request."""
    text = (result.get("text", "") or result.get("description", "")
            or result.get("transcript", ""))
    if text:
        return text
    # No free-text field: fall back to the name plus the attributes.
    return str(result.get("name", "")) + " " + str(result.get("attrs", {}))


def _match_ranked_doc(doc_text, documents: List, used: set):
    """Return the index of the unused document ``doc_text`` refers to, or None."""
    # Prefer an exact echo of what was sent.
    for index, sent in enumerate(documents):
        if index not in used and sent == doc_text:
            return index
    # Otherwise allow containment, in case the service returns a shortened or
    # extended copy.  An empty string is contained in everything, so it must
    # never be allowed to match.
    if not doc_text or not isinstance(doc_text, str):
        return None
    for index, sent in enumerate(documents):
        if index in used or not isinstance(sent, str):
            continue
        if doc_text in sent or sent in doc_text:
            return index
    return None


def rerank_results(query: str, results: List[Dict], reranker_plugin: Dict,
                   top_k: int = 5) -> List[Dict]:
    """Reorder ``results`` with the reranker plugin and keep the best ``top_k``.

    Each result is represented by its ``text``, ``description`` or
    ``transcript`` (first non-empty), or by its name and attributes.  The
    documents returned under ``"ranked_docs"`` are mapped back to the results
    using exactly those strings, and every result is used at most once.

    Args:
        query: The user query sent as ``"query"``.
        results: Candidate result dicts, in their current order.
        reranker_plugin: Plugin config; reranking is skipped when ``"type"``
            is ``"none"`` or no ``"endpoint"`` is set.
        top_k: Maximum number of results to return.

    Returns:
        The matched results in reranker order, each annotated in place with
        ``"rerank_score"``, capped at ``top_k``.  Falls back to
        ``results[:top_k]`` when reranking is skipped, the call reports an
        error, or no returned document can be mapped back.
    """
    if not results or reranker_plugin.get("type") == "none":
        return results[:top_k]
    endpoint = reranker_plugin.get("endpoint")
    if not endpoint:
        return results[:top_k]

    # Imported only once a rerank call is actually going to be made.
    from plugins.registry import call_plugin

    documents = [_rerank_document(result) for result in results]
    response = call_plugin(endpoint, "/api/rerank", {
        "query": query,
        "documents": documents,
        "top_k": top_k,
    })
    if "error" in response:
        return results[:top_k]

    ranked = []
    used = set()
    for ranked_doc in response.get("ranked_docs", []):
        index = _match_ranked_doc(ranked_doc.get("doc", ""), documents, used)
        if index is None:
            continue
        used.add(index)
        results[index]["rerank_score"] = ranked_doc.get("score", 0)
        ranked.append(results[index])
    return ranked[:top_k] if ranked else results[:top_k]