- CLIP embedding (9086) + Milvus VDB (8886) + NetworkX graph (9092) - BGE-Reranker (9090) for result reranking - Hybrid retrieval: vector search + graph expansion + RRF fusion - API: /api/ingest, /api/search, /api/pipelines, /api/plugins, /api/status - Two pipelines: kg-rag-standard (full) and kg-rag-lite (vector only) - Tested E2E: ingest + search with rerank_score=0.99
164 lines
5.5 KiB
Python
164 lines
5.5 KiB
Python
# -*- coding:utf-8 -*-
|
|
"""Hybrid retrieval: vector search + graph expansion + RRF fusion."""
|
|
from typing import List, Dict
|
|
|
|
def rrf_fusion(results_list: List[List[Dict]], k: int = 60) -> List[Dict]:
    """Reciprocal Rank Fusion - merge multiple ranked lists.

    Every item contributes ``1 / (k + rank + 1)`` for each list it appears
    in, where ``rank`` is its 0-based position in that list. Contributions
    for the same document are summed and documents are returned by
    descending total.

    A document is identified by its ``"id"`` value, falling back to
    ``"node_id"``. An item carrying neither identifier is treated as its own
    document and is never merged with any other item.

    Args:
        results_list: Ranked result lists, best match first in each list.
        k: Constant added to every rank before taking the reciprocal.

    Returns:
        The first-seen dict of every distinct document, ordered by
        descending fused score (ties keep first-seen order). The returned
        dicts are the caller's own objects, updated in place with an
        ``"rrf_score"`` key holding the fused score rounded to 6 decimals.
    """
    # Private marker: keys built for identifier-less items can never be
    # equal to a real document id.
    anonymous = object()
    fused = {}

    for list_index, results in enumerate(results_list):
        for rank, item in enumerate(results):
            doc_id = item.get("id")
            if doc_id is None:
                doc_id = item.get("node_id")

            # A purely positional fallback key (e.g. the rank) would merge
            # unrelated id-less items that share a rank across lists, so the
            # key also includes the list index.
            if doc_id is None:
                key = (anonymous, list_index, rank)
            else:
                key = doc_id

            entry = fused.setdefault(key, {"item": item, "score": 0.0})
            entry["score"] += 1.0 / (k + rank + 1)

    # sorted() is stable, also with reverse=True, so ties keep insertion order.
    ordered = sorted(fused.values(), key=lambda entry: entry["score"], reverse=True)

    for entry in ordered:
        entry["item"]["rrf_score"] = round(entry["score"], 6)

    return [entry["item"] for entry in ordered]
|
|
|
|
|
|
def vector_search(query_embedding: List[float], vdb_plugin: Dict,
                  collection: str = "knowledge", top_k: int = 20) -> List[Dict]:
    """Query the vector database plugin for vectors similar to the query.

    Posts the first page of a similarity query to the plugin's ``/v1/query``
    route through ``call_plugin``.

    Args:
        query_embedding: Embedding of the query, sent as ``"vector"``.
        vdb_plugin: Plugin configuration; its ``"endpoint"`` is the target.
        collection: Collection name, sent as ``"colname"``.
        top_k: Page size, sent as ``"pagerows"``.

    Returns:
        The ``"rows"`` of the response's ``"data"`` object. An empty list is
        returned when no endpoint is configured, the response contains
        ``"error"``, its status is not ``"SUCCEEDED"``, or ``"data"`` is not
        a dict.
    """
    from plugins.registry import call_plugin

    vdb_endpoint = vdb_plugin.get("endpoint")
    if not vdb_endpoint:
        return []

    request_body = {
        "colname": collection,
        "vector": query_embedding,
        "pagerows": top_k,
        "page": 1,
    }
    response = call_plugin(vdb_endpoint, "/v1/query", request_body)

    succeeded = "error" not in response and response.get("status") == "SUCCEEDED"
    if not succeeded:
        return []

    # VDB returns {"status": "SUCCEEDED", "data": {"rows": [...]}}
    payload = response.get("data", {})
    if not isinstance(payload, dict):
        return []
    return payload.get("rows", [])
|
|
|
|
|
|
def graph_expand(entity_ids: List[str], graph_plugin: Dict,
                 graph_name: str = "knowledge", hops: int = 2) -> List[Dict]:
    """Collect the graph neighbors of the given entities.

    For each entity, the plugin's ``/api/graph/neighbors`` route is queried
    through ``call_plugin``. Responses containing ``"error"`` are skipped.

    Args:
        entity_ids: Entity names to expand.
        graph_plugin: Plugin configuration; needs an ``"endpoint"`` and a
            ``"type"`` other than ``"none"``.
        graph_name: Graph to query, sent as ``"graph"``.
        hops: Traversal depth, sent as ``"depth"``.

    Returns:
        Neighbor dicts in discovery order, each ``"node_id"`` appearing at
        most once. Every returned dict is the response object itself,
        extended in place with ``"source_entity"`` (the entity whose lookup
        found it first) and ``"id"`` (a copy of its ``"node_id"``). An empty
        list is returned when the graph plugin is not usable.
    """
    from plugins.registry import call_plugin

    graph_endpoint = graph_plugin.get("endpoint")
    if not graph_endpoint:
        return []
    if graph_plugin.get("type") == "none":
        return []

    expanded = []
    visited = set()

    for entity_id in entity_ids:
        # Graph uses entity_{name} format from ingest
        start_node = f"entity_{entity_id}".replace(" ", "_")
        lookup = {
            "graph": graph_name,
            "node_id": start_node,
            "depth": hops,
        }
        response = call_plugin(graph_endpoint, "/api/graph/neighbors", lookup)

        if "error" in response:
            continue

        for neighbor in response.get("neighbors", []):
            neighbor_id = neighbor.get("node_id", "")
            # Skip neighbors without an id and ones already collected.
            if not neighbor_id or neighbor_id in visited:
                continue
            visited.add(neighbor_id)
            neighbor["source_entity"] = entity_id
            neighbor["id"] = neighbor_id
            expanded.append(neighbor)

    return expanded
|
|
|
|
|
|
def hybrid_retrieve(query_embedding: List[float],
                    extracted_entities: List[str],
                    pipeline: Dict,
                    collection: str = "knowledge",
                    graph_name: str = "knowledge") -> List[Dict]:
    """Run vector search plus optional graph expansion and fuse the results.

    Settings are read from ``pipeline["plugins"]``: the ``"retriever"`` entry
    supplies ``"vector_top_k"`` (default 20), ``"graph_hops"`` (default 2)
    and ``"type"`` (default ``"hybrid"``); the ``"vdb"`` and ``"graph"``
    entries are handed to ``vector_search`` and ``graph_expand``.

    Args:
        query_embedding: Embedding of the query for the vector search.
        extracted_entities: Entity names used as graph expansion seeds.
        pipeline: Pipeline configuration dict.
        collection: Vector collection name.
        graph_name: Graph name.

    Returns:
        An empty list when neither source produced results, the single
        non-empty result list unchanged when only one source produced
        results, otherwise the ``rrf_fusion`` of both lists.
    """
    plugin_map = pipeline.get("plugins", {})
    retriever_cfg = plugin_map.get("retriever", {})

    top_k = retriever_cfg.get("vector_top_k", 20)
    hops = retriever_cfg.get("graph_hops", 2)
    mode = retriever_cfg.get("type", "hybrid")

    ranked_lists = []

    # Vector search always runs.
    vector_hits = vector_search(
        query_embedding, plugin_map.get("vdb", {}), collection, top_k
    )
    if vector_hits:
        ranked_lists.append(vector_hits)

    # Graph expansion runs only in hybrid mode with at least one entity.
    if mode == "hybrid" and extracted_entities:
        graph_hits = graph_expand(
            extracted_entities, plugin_map.get("graph", {}), graph_name, hops
        )
        if graph_hits:
            ranked_lists.append(graph_hits)

    if len(ranked_lists) > 1:
        return rrf_fusion(ranked_lists)
    # A single list is returned as-is, without fusion scores.
    return ranked_lists[0] if ranked_lists else []
|
|
|
|
|
|
def rerank_results(query: str, results: List[Dict], reranker_plugin: Dict,
                   top_k: int = 5) -> List[Dict]:
    """Rerank results using BGE-Reranker.

    One text per result is posted to the reranker plugin's ``/api/rerank``
    route, and the results are reordered to follow the ``"ranked_docs"`` of
    the response. Each ranked document is mapped back to the result whose
    text was sent for it, and every result is used at most once.

    Args:
        query: Query text the documents are scored against.
        results: Candidate result dicts.
        reranker_plugin: Plugin configuration (``"type"``, ``"endpoint"``).
        top_k: Maximum number of results to return; also sent to the plugin.

    Returns:
        At most ``top_k`` results. On success these are the matched result
        dicts in reranker order, each updated in place with a
        ``"rerank_score"`` key. ``results[:top_k]`` in the original order is
        returned instead when there is nothing to rerank, the plugin type is
        ``"none"``, no endpoint is configured, the response contains
        ``"error"``, or no ranked document could be matched to a result.
    """
    if not results or reranker_plugin.get("type") == "none":
        return results[:top_k]

    endpoint = reranker_plugin.get("endpoint")
    if not endpoint:
        return results[:top_k]

    # Deferred until a request is certain, so the pass-through paths above
    # do not need the plugin registry.
    from plugins.registry import call_plugin

    # documents[i] is the text sent for results[i]. The same list is used
    # below to map the reranker's answer back, so the text that was sent and
    # the text that is matched can never diverge.
    documents = [_rerank_document_text(result) for result in results]

    rerank_result = call_plugin(endpoint, "/api/rerank", {
        "query": query,
        "documents": documents,
        "top_k": top_k,
    })

    if "error" in rerank_result:
        return results[:top_k]

    # Map reranked docs back to original results
    ranked = []
    used = set()
    for rd in rerank_result.get("ranked_docs", []):
        index = _match_document(rd.get("doc", ""), documents, used)
        if index is None:
            continue
        used.add(index)
        results[index]["rerank_score"] = rd.get("score", 0)
        ranked.append(results[index])

    if not ranked:
        return results[:top_k]
    # Honor top_k even if the plugin returned more documents than requested.
    return ranked[:top_k]


def _rerank_document_text(result: Dict) -> str:
    """Return the text that represents one result in the rerank request."""
    doc = result.get("text", "") or result.get("description", "") or result.get("transcript", "")
    if not doc:
        doc = str(result.get("name", "")) + " " + str(result.get("attrs", {}))
    return doc


def _match_document(doc_text, documents: List[str], used: set):
    """Return the index of the unused document matching doc_text, or None."""
    # An empty string is contained in every document, so it identifies nothing.
    if not doc_text:
        return None

    candidates = [(i, doc) for i, doc in enumerate(documents) if i not in used]

    # Prefer an exact match.
    for i, doc in candidates:
        if doc == doc_text:
            return i

    # Fall back to containment in either direction, for a reranker that
    # echoes a shortened or extended form of the submitted text.
    for i, doc in candidates:
        if doc_text in doc or doc in doc_text:
            return i

    return None
|