# rag-pipeline/plugins/registry.py
# (130 lines, 5.0 KiB, Python; captured from a repository web viewer.)
# NOTE: the viewer warned that this file contains ambiguous Unicode characters
# that might be confused with other characters. If that is intentional, the
# warning can safely be ignored.

# -*- coding:utf-8 -*-
"""Plugin registry - manages available RAG plugins and pipeline configs."""
import copy
import json
import logging
import os
import urllib.request
from typing import Dict, List, Optional
# Directory holding user-defined pipeline configs, one "<name>.json" file per
# pipeline. It can be overridden through the RAG_PIPELINE_CONFIG_DIR environment
# variable so the registry can run on a host other than the original one. An
# unset or empty variable keeps the historical default path.
CONFIG_DIR = os.environ.get("RAG_PIPELINE_CONFIG_DIR") or "/data/ymq/rag-pipeline/pipelines"

# Default pipeline config: the standard knowledge-graph RAG setup, with hybrid
# vector + graph retrieval.
DEFAULT_PIPELINE = {
    "name": "kg-rag-standard",
    "description": "标准知识图谱RAG",  # "standard knowledge-graph RAG"
    "plugins": {
        "embedding": {"type": "clip-vith14", "endpoint": "http://localhost:9086"},
        "vdb": {"type": "vdb-milvus", "endpoint": "http://localhost:8886"},
        "graph": {"type": "networkx", "endpoint": "http://localhost:9092"},
        # "internal" is not an HTTP URL; presumably the LLM is invoked
        # in-process rather than over HTTP -- TODO confirm with the caller.
        "llm": {"type": "harnessed", "endpoint": "internal"},
        "reranker": {"type": "bge-reranker", "endpoint": "http://localhost:9090"},
        "chunker": {"type": "recursive", "chunk_size": 512, "overlap": 64},
        "extractor": {"type": "llm-structured"},
        "retriever": {"type": "hybrid", "vector_top_k": 20, "graph_hops": 2},
    },
}

# Lightweight pipeline: vector-only retrieval, graph disabled ("none"). It
# defines no "extractor" entry at all.
LITE_PIPELINE = {
    "name": "kg-rag-lite",
    "description": "轻量版RAG纯向量无图谱",  # "lite RAG, pure vector, no knowledge graph"
    "plugins": {
        "embedding": {"type": "clip-vith14", "endpoint": "http://localhost:9086"},
        "vdb": {"type": "vdb-milvus", "endpoint": "http://localhost:8886"},
        "graph": {"type": "none"},
        "llm": {"type": "harnessed", "endpoint": "internal"},
        "reranker": {"type": "bge-reranker", "endpoint": "http://localhost:9090"},
        "chunker": {"type": "recursive", "chunk_size": 512, "overlap": 64},
        "retriever": {"type": "vector_only", "top_k": 10},
    },
}
def _ensure_config_dir():
    """Make sure the pipeline config directory exists on disk.

    Any missing parent directories are created as well. A directory that
    already exists is not treated as an error.
    """
    target_dir = CONFIG_DIR
    os.makedirs(target_dir, exist_ok=True)
def list_pipelines() -> List[Dict]:
    """Return every available pipeline config, built-ins first, then custom ones.

    Custom pipelines are read from the ``*.json`` files in ``CONFIG_DIR``
    (the directory is created if missing), in sorted filename order so the
    result is deterministic. A file is skipped, with a logged warning, when it
    cannot be read or parsed, or when it does not hold a JSON object with a
    ``"name"`` key. A single bad file therefore never breaks the listing.

    Returns:
        A new list of config dicts. The two built-in entries are deep copies,
        so callers may modify them without altering ``DEFAULT_PIPELINE`` or
        ``LITE_PIPELINE``.
    """
    _ensure_config_dir()
    log = logging.getLogger(__name__)
    pipelines = [copy.deepcopy(DEFAULT_PIPELINE), copy.deepcopy(LITE_PIPELINE)]
    for filename in sorted(os.listdir(CONFIG_DIR)):
        if not filename.endswith(".json"):
            continue
        path = os.path.join(CONFIG_DIR, filename)
        try:
            # Configs are written with ensure_ascii=False (non-ASCII text kept
            # verbatim), so decode explicitly as UTF-8 instead of relying on
            # the locale's default encoding.
            with open(path, "r", encoding="utf-8") as fh:
                config = json.load(fh)
        except (OSError, ValueError) as exc:
            # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
            log.warning("Skipping unreadable pipeline config %s: %s", path, exc)
            continue
        if not isinstance(config, dict) or "name" not in config:
            log.warning("Skipping pipeline config without a name: %s", path)
            continue
        pipelines.append(config)
    return pipelines
def get_pipeline(name: str) -> Dict:
    """Look up a pipeline config by its ``"name"`` value.

    Candidates are taken from ``list_pipelines()`` in order, and the first
    match wins. Because the built-ins come first, a custom config that reuses a
    built-in name is shadowed.

    This never returns ``None``. An unknown name falls back to the standard
    ``DEFAULT_PIPELINE`` config, which is kept for compatibility with existing
    callers.

    Returns:
        A deep copy of the matching config (or of the default), so the caller
        may modify it without affecting the module-level defaults.
    """
    for pipeline in list_pipelines():
        # Tolerate malformed custom entries (non-dicts, or dicts without a
        # name) instead of raising KeyError/TypeError mid-lookup.
        if isinstance(pipeline, dict) and pipeline.get("name") == name:
            return copy.deepcopy(pipeline)
    return copy.deepcopy(DEFAULT_PIPELINE)
def save_pipeline(config: Dict) -> Dict:
    """Persist a custom pipeline config as ``CONFIG_DIR/<name>.json``.

    An existing file with the same name is overwritten. The name is validated
    before anything touches the filesystem.

    Args:
        config: Pipeline config. Its ``"name"`` value becomes the file name.

    Returns:
        ``{"status": "ok", "filepath": <path that was written>}``.

    Raises:
        KeyError: if ``config`` has no ``"name"`` key.
        ValueError: if the name is empty or contains a path separator. Either
            would otherwise let ``os.path.join`` place the file outside
            ``CONFIG_DIR``; an absolute name replaces the directory entirely.
    """
    name = str(config["name"])
    if not name or os.sep in name or (os.altsep is not None and os.altsep in name):
        raise ValueError(f"invalid pipeline name: {name!r}")
    _ensure_config_dir()
    filepath = os.path.join(CONFIG_DIR, f"{name}.json")
    # ensure_ascii=False keeps non-ASCII text (e.g. Chinese descriptions)
    # verbatim, so pin the file encoding instead of using the locale default.
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)
    return {"status": "ok", "filepath": filepath}
def list_plugins() -> Dict:
    """Describe every plugin this registry knows about, grouped by capability.

    The catalogue is a hard-coded snapshot. Each capability name maps to a
    list of option dicts, each identified by its ``"type"``. ``"status"``
    (where present) records the deployment state, and ``"endpoint"`` appears
    only for options served over HTTP. A brand-new dict is built on every call.
    """
    embedding = [
        dict(type="clip-vith14", model="BAAI/CLIP-ViT-H-14", dim=1024,
             endpoint="http://localhost:9086", status="available"),
        dict(type="bge-m3", model="BAAI/bge-m3", dim=1024, status="not_deployed"),
    ]
    vector_stores = [
        dict(type="vdb-milvus", backend="Milvus",
             endpoint="http://localhost:8886", status="available"),
        dict(type="qdrant", backend="Qdrant", status="not_deployed"),
    ]
    graph_backends = [
        dict(type="networkx", backend="NetworkX",
             endpoint="http://localhost:9092", status="available"),
        dict(type="falkordb", backend="FalkorDB+Redis", status="blocked_redis_module"),
        dict(type="none", description="Disable graph"),
    ]
    language_models = [
        dict(type="harnessed", description="harnessed_agent", status="available"),
    ]
    rerankers = [
        dict(type="bge-reranker", model="BAAI/bge-reranker-v2-m3",
             endpoint="http://localhost:9090", status="available"),
        dict(type="none", description="Skip reranking"),
    ]
    chunkers = [
        dict(type="recursive", description="Recursive character splitter"),
        dict(type="sentence", description="Sentence-based splitter"),
    ]
    extractors = [
        dict(type="llm-structured", description="LLM-based entity/relation extraction"),
        dict(type="none", description="Skip extraction"),
    ]
    retrievers = [
        dict(type="hybrid", description="Vector + Graph hybrid retrieval"),
        dict(type="vector_only", description="Pure vector retrieval"),
    ]
    face_models = [
        dict(type="insightface", model="buffalo_l", dim=512,
             status="available", endpoint="http://localhost:9091"),
    ]

    return dict(
        embedding=embedding,
        vdb=vector_stores,
        graph=graph_backends,
        llm=language_models,
        reranker=rerankers,
        chunker=chunkers,
        extractor=extractors,
        retriever=retrievers,
        face=face_models,
    )
def call_plugin(endpoint: str, path: str, data: Dict, timeout: float = 30) -> Dict:
    """POST ``data`` as JSON to ``endpoint + path`` and return the decoded JSON reply.

    The call is best-effort and never raises. Every failure is reported as an
    ``{"error": "<message>"}`` dict instead.

    Args:
        endpoint: Base URL of the plugin service, e.g. ``"http://localhost:9086"``.
            Only ``http://`` and ``https://`` URLs are accepted. Anything else,
            such as the ``"internal"`` placeholder or a ``file:`` URL coming from
            a user-supplied config, is rejected without making any request.
        path: Suffix appended verbatim to ``endpoint`` (no separator is added).
        data: JSON-serialisable request body.
        timeout: Timeout in seconds for the blocking network operations.

    Returns:
        The decoded JSON response body, or ``{"error": ...}`` on any failure.
    """
    url = f"{endpoint}{path}"
    if not url.lower().startswith(("http://", "https://")):
        return {"error": f"unsupported URL scheme (expected http:// or https://): {url!r}"}
    try:
        # Serialisation and Request construction are inside the try too: they
        # raise TypeError (non-JSON data) / ValueError (malformed URL), and
        # this function's contract is to report failures as an error dict.
        payload = json.dumps(data).encode("utf-8")
        req = urllib.request.Request(
            url, data=payload, headers={"Content-Type": "application/json"}
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except Exception as e:  # deliberate catch-all at the plugin I/O boundary
        return {"error": str(e)}