# -*- coding:utf-8 -*-
"""RAG Pipeline orchestration - ingest and search workflows."""

import json
import logging
import time
import uuid
from typing import Callable, Dict, List, Optional
from traceback import format_exc

from plugins.registry import get_pipeline, call_plugin
from core.chunker import chunk_document
from core.extractor import extract_entities_relations
from core.retriever import hybrid_retrieve, rerank_results

logger = logging.getLogger(__name__)

# Only the first N characters of a document are sent to the LLM extractor.
_EXTRACTION_CHAR_LIMIT = 3000
# Only the first N characters of the joined context go into the answer prompt.
_CONTEXT_CHAR_LIMIT = 2000


def _get_embedding(texts: List[str], embedding_plugin: Dict) -> List[List[float]]:
    """Get text embeddings from CLIP service.

    Posts ``{"texts": texts}`` to the ``/api/text`` route of the plugin's
    ``endpoint``.

    Args:
        texts: Strings to embed.
        embedding_plugin: Plugin config; its ``"endpoint"`` key is read.

    Returns:
        The response's ``"embeddings"`` value. An empty list is returned when
        there is nothing to embed, no endpoint is configured, the response
        contains an ``"error"`` key, or no embeddings are present.
    """
    if not texts:
        # Nothing to embed: skip the remote call entirely.
        return []

    endpoint = embedding_plugin.get("endpoint")
    if not endpoint:
        return []

    result = call_plugin(endpoint, "/api/text", {"texts": texts})
    if "error" in result:
        return []
    # ``or []`` also covers a response that carries "embeddings": None.
    return result.get("embeddings") or []


def _ensure_collection(vdb_plugin: Dict, collection: str, dim: int = 1024) -> Dict:
    """Ensure VDB collection exists with correct schema.

    Args:
        vdb_plugin: Plugin config; its ``"endpoint"`` key is read.
        collection: Name of the collection to look up or create.
        dim: Dimension declared for the ``embedding`` vector field when the
            collection has to be created.

    Returns:
        ``{"error": ...}`` when no endpoint is configured,
        ``{"status": "exists"}`` when the collection is already listed,
        otherwise the raw response of the create call.
    """
    endpoint = vdb_plugin.get("endpoint")
    if not endpoint:
        return {"error": "VDB endpoint not configured"}

    # Check if collection exists. A listing without a usable "collections"
    # value is treated as "nothing exists", so creation is attempted and its
    # response (including any error) is handed back to the caller.
    list_result = call_plugin(endpoint, "/v1/listcollections", {})
    existing = list_result.get("collections") or []
    if collection in existing:
        return {"status": "exists"}

    # Create with proper field schema
    fields = [
        {"name": "id", "type": "str", "is_primary": True, "max_length": 256},
        {"name": "text", "type": "str", "max_length": 65535},
        {"name": "description", "type": "str", "max_length": 65535},
        {"name": "type", "type": "str", "max_length": 128},
        {"name": "embedding", "type": "fvector", "dim": dim},
    ]
    return call_plugin(endpoint, "/v1/createcollection", {
        "colname": collection,
        "fields": fields,
        "description": "RAG knowledge collection",
        "metric": "COSINE",
    })


def _store_in_vdb(items: List[Dict], embeddings: List[List[float]],
                  vdb_plugin: Dict, collection: str = "knowledge") -> Dict:
    """Store items with embeddings in VDB.

    Args:
        items: Records to store; ``id``, ``text``, ``description`` and
            ``type`` are read from each, with defaults for missing keys.
        embeddings: One vector per item, in the same order as ``items``.
        vdb_plugin: Plugin config; its ``"endpoint"`` key is read.
        collection: Target collection name.

    Returns:
        ``{"error": ...}`` when no endpoint is configured, when the number of
        embeddings differs from the number of items, or when ensuring the
        collection fails; otherwise the raw response of the batch insert.
    """
    endpoint = vdb_plugin.get("endpoint")
    if not endpoint:
        return {"error": "VDB endpoint not configured"}

    # zip() would silently drop the unmatched tail, and with differing counts
    # there is no way to know which vector belongs to which item.
    if len(items) != len(embeddings):
        return {
            "error": (
                f"Embedding count mismatch: {len(items)} items, "
                f"{len(embeddings)} embeddings"
            )
        }

    # Ensure collection exists. Declare the vector field with the dimension
    # of the vectors actually being inserted rather than the fixed default.
    if embeddings:
        ensure_result = _ensure_collection(
            vdb_plugin, collection, dim=len(embeddings[0])
        )
    else:
        ensure_result = _ensure_collection(vdb_plugin, collection)
    if "error" in ensure_result:
        return ensure_result

    # Prepare batch insert - data must match schema fields
    rows = [
        {
            "id": item.get("id", str(uuid.uuid4())[:16]),
            "text": item.get("text", ""),
            "description": item.get("description", ""),
            "type": item.get("type", "chunk"),
            "embedding": emb,
        }
        for item, emb in zip(items, embeddings)
    ]

    return call_plugin(endpoint, "/v1/batchinsert", {
        "colname": collection,
        "data": rows,
    })


def _entity_node_id(name: str) -> str:
    """Build the graph node id for an entity name (spaces become underscores)."""
    return f"entity_{name}".replace(" ", "_")


def _store_in_graph(entities: List[Dict], relations: List[Dict],
                    graph_plugin: Dict, graph_name: str = "knowledge") -> Dict:
    """Store entities and relations in graph.

    Args:
        entities: Entity dicts; ``name``, ``type`` and ``description`` are read.
        relations: Relation dicts; ``source``, ``target``, ``relation`` and
            ``description`` are read.
        graph_plugin: Plugin config; its ``"endpoint"`` and ``"type"`` keys
            are read.
        graph_name: Name of the graph to write to.

    Returns:
        ``{"status": "skipped", ...}`` when no endpoint is configured or the
        plugin type is ``"none"``; otherwise ``{"status": "ok", ...}`` with
        the number of node and edge calls whose response had no ``"error"``.
    """
    endpoint = graph_plugin.get("endpoint")
    if not endpoint or graph_plugin.get("type") == "none":
        return {"status": "skipped", "reason": "graph disabled"}

    added_nodes = 0
    added_edges = 0

    for entity in entities:
        name = entity.get("name", "")
        if not name:
            # Every unnamed entity would map to the same node id "entity_".
            continue
        result = call_plugin(endpoint, "/api/graph/add_node", {
            "graph": graph_name,
            "node_id": _entity_node_id(name),
            "attrs": {
                "name": name,
                "type": entity.get("type", "unknown"),
                "description": entity.get("description", ""),
            },
        })
        if "error" not in result:
            added_nodes += 1

    for rel in relations:
        source_name = rel.get("source", "")
        target_name = rel.get("target", "")
        if not source_name or not target_name:
            # An edge with a missing endpoint would attach to "entity_".
            continue
        result = call_plugin(endpoint, "/api/graph/add_edge", {
            "graph": graph_name,
            "source": _entity_node_id(source_name),
            "target": _entity_node_id(target_name),
            "attrs": {
                "relation": rel.get("relation", "related_to"),
                "description": rel.get("description", ""),
            },
        })
        if "error" not in result:
            added_edges += 1

    # Save graph to disk
    call_plugin(endpoint, "/api/graph/save", {"graph": graph_name})

    return {"status": "ok", "nodes_added": added_nodes, "edges_added": added_edges}


def ingest(document: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           llm_func: Optional[Callable] = None) -> Dict:
    """Full ingest pipeline: chunk -> embed -> extract -> store.

    Args:
        document: Raw text to ingest.
        pipeline_name: Pipeline to look up through ``get_pipeline``.
        collection: VDB collection that receives the chunks.
        graph_name: Graph that receives extracted entities and relations.
        llm_func: Callable handed to the entity/relation extractor. When it
            is omitted, or the extractor type is ``"none"``, extraction and
            graph storage are skipped.

    Returns:
        ``{"error": ..., "chunks": n}`` when no embeddings were obtained.
        Otherwise a summary dict with counts, the raw VDB and graph results
        and the elapsed seconds. ``status`` is ``"SUCCEEDED"`` unless the VDB
        result contains an ``"error"`` key, in which case it is ``"FAILED"``
        and an ``"error"`` message is added.
    """
    start_time = time.time()
    pipeline = get_pipeline(pipeline_name)
    plugins = pipeline.get("plugins", {})

    # Step 1: Chunk
    chunker_config = plugins.get("chunker", {})
    chunks = chunk_document(
        document,
        strategy=chunker_config.get("type", "recursive"),
        chunk_size=chunker_config.get("chunk_size", 512),
        overlap=chunker_config.get("overlap", 64),
    )
    chunk_texts = [c["text"] for c in chunks]

    # Step 2: Embed
    embedding_plugin = plugins.get("embedding", {})
    embeddings = _get_embedding(chunk_texts, embedding_plugin)
    if not embeddings:
        return {"error": "Failed to get embeddings", "chunks": len(chunks)}

    # Step 3: Store chunks + embeddings in VDB
    vdb_plugin = plugins.get("vdb", {})
    vdb_result = _store_in_vdb(chunks, embeddings, vdb_plugin, collection)

    # Step 4: Extract entities and relations
    extractor_config = plugins.get("extractor", {})
    extraction = {"entities": [], "relations": []}
    if extractor_config.get("type") != "none" and llm_func:
        extraction = extract_entities_relations(
            document[:_EXTRACTION_CHAR_LIMIT], llm_func
        )

        # Step 5: Store in graph
        graph_plugin = plugins.get("graph", {})
        graph_result = _store_in_graph(
            extraction.get("entities", []),
            extraction.get("relations", []),
            graph_plugin,
            graph_name,
        )
    else:
        graph_result = {"status": "skipped"}

    elapsed = round(time.time() - start_time, 3)

    summary = {
        "status": "SUCCEEDED",
        "pipeline": pipeline_name,
        "chunks": len(chunks),
        "embeddings": len(embeddings),
        "entities": len(extraction.get("entities", [])),
        "relations": len(extraction.get("relations", [])),
        "vdb_result": vdb_result,
        "graph_result": graph_result,
        "elapsed": elapsed,
    }
    # Do not report success when the chunks never reached the vector store.
    if "error" in vdb_result:
        summary["status"] = "FAILED"
        summary["error"] = "Failed to store chunks in VDB"
    return summary


def search(query: str, pipeline_name: str = "kg-rag-standard",
           collection: str = "knowledge", graph_name: str = "knowledge",
           top_k: int = 5, llm_func: Optional[Callable] = None) -> Dict:
    """Full search pipeline: embed -> retrieve -> rerank -> generate.

    Args:
        query: User question.
        pipeline_name: Pipeline to look up through ``get_pipeline``.
        collection: VDB collection passed to the retriever.
        graph_name: Graph passed to the retriever.
        top_k: Number of reranked results to return and to use as context.
        llm_func: Optional callable used both for query entity extraction and
            for answer generation; it is called with the prompt string.

    Returns:
        ``{"error": "Failed to embed query"}`` when the query could not be
        embedded. Otherwise a dict with the ranked results, the number of
        retrieved candidates, the generated answer (``None`` when no LLM was
        given, nothing was retrieved, or generation raised), the entity names
        extracted from the query and the elapsed seconds.
    """
    start_time = time.time()
    pipeline = get_pipeline(pipeline_name)
    plugins = pipeline.get("plugins", {})

    # Step 1: Embed query
    embedding_plugin = plugins.get("embedding", {})
    query_embeddings = _get_embedding([query], embedding_plugin)
    if not query_embeddings:
        return {"error": "Failed to embed query"}
    query_embedding = query_embeddings[0]

    # Step 2: Extract entities from query (for graph expansion)
    extracted_entities = []
    if llm_func:
        quick_extraction = extract_entities_relations(query, llm_func)
        # Entities without a usable name are dropped instead of raising.
        extracted_entities = [
            e.get("name")
            for e in quick_extraction.get("entities", [])
            if e.get("name")
        ]

    # Step 3: Hybrid retrieval
    results = hybrid_retrieve(
        query_embedding, extracted_entities, pipeline, collection, graph_name
    )

    # Step 4: Rerank
    reranker_plugin = plugins.get("reranker", {})
    ranked_results = rerank_results(query, results, reranker_plugin, top_k)

    # Step 5: Generate answer (optional)
    answer = None
    if llm_func and ranked_results:
        context = "\n\n".join([
            r.get("text", "") or r.get("description", "") or str(r)
            for r in ranked_results[:top_k]
        ])
        context = context[:_CONTEXT_CHAR_LIMIT]
        # NOTE(review): the prompt text is reproduced exactly as it appears in
        # this file; confirm the intended line breaks between its sections.
        gen_prompt = f"""基于以下上下文回答问题。如果上下文中没有相关信息,请如实说明。 上下文: {context} 问题:{query} 请用中文回答:"""
        try:
            answer = llm_func(gen_prompt)
        except Exception:
            # Generation is best-effort: keep the retrieval results, but do
            # not swallow KeyboardInterrupt/SystemExit and leave a trace.
            logger.exception("Answer generation failed")
            answer = None

    elapsed = round(time.time() - start_time, 3)

    return {
        "status": "SUCCEEDED",
        "pipeline": pipeline_name,
        "query": query,
        "results": ranked_results[:top_k],
        "total_retrieved": len(results),
        "answer": answer,
        "extracted_entities": extracted_entities,
        "elapsed": elapsed,
    }