From 1d9e2f3efc6e4f4020f7bf50bf78c059f6b71517 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Sun, 14 Jun 2026 20:14:59 +0800 Subject: [PATCH] feat: v2.0 - standalone support, COSINE metric, batch ops, fix query bug, new APIs --- .gitignore | 8 + README.md | 327 +++++++++++++------------------ build.sh | 107 +++++++---- conf/config.json | 37 +++- pyproject.toml | 12 +- vdb/basevdb.py | 60 ++++-- vdb/init.py | 251 ++++++++++++++---------- vdb/milvus.py | 487 ++++++++++++++++++++++++++++------------------- 8 files changed, 712 insertions(+), 577 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2578641 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +ah.pid +nohup.out +__pycache__/ +*.pyc +db/ +logs/ +files/ +wwwroot/ diff --git a/README.md b/README.md index 2a6d356..a946d04 100644 --- a/README.md +++ b/README.md @@ -1,208 +1,137 @@ -# vdb +# VDB - Vector Database Service -## 安装 +基于 ahserver + pymilvus 的向量数据库 HTTP 服务,支持 Milvus Lite 和 Milvus Standalone 两种模式。 -执行 +## 特性 + +- 支持 Milvus Lite(本地文件)和 Milvus Standalone(Docker 部署,千万级) +- COSINE / IP / L2 三种距离度量 +- HNSW / IVF_FLAT / IVF_PQ 索引类型可配 +- 批量插入(batch_insert)和批量搜索(batch_query) +- 向量搜索 + 标量过滤组合查询 +- 集合管理(list/stats/drop) + +## 部署 + +```bash +cd /data/ymq/vdb +bash build.sh deploy # 启动 +bash build.sh stop # 停止 +bash build.sh status # 状态 ``` -git clone https://git.opencomputing.cn/yumoqing/vdb.git -cd vdb -bash ./build.sh -``` -执行完成后: -1 添加了一个vdb.service, 重启后会自动启动向量数据库服务 -2 使用sudo systemctl XXX vdb.service开启动,停止, 重启服务 -3 app/vdbapp.py 正在执行 +### 切换到 Milvus Standalone(千万级) + +编辑 `conf/config.json`: +```json +{ + "milvus_mode": "standalone", + "milvus_host": "127.0.0.1:19530" +} +``` + +部署 Milvus Standalone(Docker): +```yaml +# docker-compose.yml +services: + etcd: + image: quay.io/coreos/etcd:v3.5.18 + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + volumes: ["etcd:/etcd"] + command: etcd --advertise-client-urls=http://127.0.0.1:2379 --listen-client-urls=http://0.0.0.0:2379 --data-dir=/etcd + + minio: + image: minio/minio:latest + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + volumes: ["minio:/minio_data"] + command: minio server /minio_data --console-address ":9001" + + milvus: + image: milvusdb/milvus:v2.4-latest + depends_on: [etcd, minio] + ports: ["19530:19530", "9091:9091"] + volumes: ["milvus:/var/lib/milvus"] + environment: + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 +``` + +## API 接口 + +### GET /v1/listcollections +列出所有集合。 + +### POST /v1/createcollection +创建集合。 +```json +{ + "colname": "entities", + "fields": [ + {"name": "id", "type": "str", "max_length": 32, "is_primary": true}, + {"name": "embedding", "type": "fvector", "dim": 1024} + ], + "metric": "COSINE", + "index_type": "HNSW" +} +``` + +### POST /v1/collectionstats +获取集合统计信息(行数、字段、索引)。 + +### POST /v1/upsert +插入或更新(自动删除旧记录再插入,适合少量)。 + +### POST /v1/batchinsert +批量插入(不逐条flush,适合大批量入库)。 +```json +{"colname": "entities", "data": [...], "flush": true} +``` + +### POST /v1/query +向量搜索 + 标量过滤 + 分页。 +```json +{ + "colname": "entities", + "vector": [0.1, 0.2, ...], + "expr": "type == \"animal\"", + "pagerows": 20, + "page": 1, + "metric": "COSINE" +} +``` + +### POST /v1/batchquery +批量向量搜索(多个查询向量同时搜索)。 + +### POST /v1/delete +按主键删除。 + +### POST /v1/dropcollection +删除整个集合。 ## 字段类型 -以下字段类型可用在集合的话数据类型中 -* "str": 可变长字符串, 需要max_length属性 -* "int": 整数类型 -* "bool": 逻辑值类型 -* "float": 浮点数 -* "fvector": 浮点数向量,常规向量都用此类型, -* "bvector": 二进制向量 二分向量使用, -* "json": json格式的数据 +| 类型 | 说明 | 备注 | +|------|------|------| +| str | 字符串 | 需要 max_length | +| int | 32位整数 | | +| int64 | 64位整数 | | +| bool | 布尔值 | | +| float | 浮点数 | | +| fvector | 浮点向量 | 需要 dim | +| bvector | 二进制向量 | 需要 dim | +| json | JSON | 灵活结构 | -一些例子 -``` -[ - # 1. 主键字段 (必选): str 类型,非自增(手动指定ID以便与业务系统关联) - { - "name": "id", - "type": "str", - "max_length": 32 - "is_primary": True, - "auto_id": False - }, +## 配置 (conf/config.json) - # 2. 向量字段 (核心): 存储 CLIP 提取的特征,768 维 - { - "name": "video_embedding", - "type": "fvector", - "dim": 768 - }, - - # 3. 变长字符串: 存储视频存储路径,需指定最大长度 - { - "name": "file_path", - "type": "str", - "max_length": 512 - }, - - # 4. 浮点数: 存储评估得分 (如 VBench 的综合分数) - { - "name": "quality_score", - "type": "float" - }, - - # 5. 布尔值: 标记是否已完成人工复核 - { - "name": "is_reviewed", - "type": "bool" - }, - - # 6. 整数: 存储视频的时长(秒) - { - "name": "duration_sec", - "type": "int" - }, - - # 7. JSON 字段: 存储非结构化的元数据(如 VBench 的 16 个子维度细节) - # Milvus 2.4+ 支持 JSON 动态解析查询 - { - "name": "meta_data", - "type": "json" - } -] -``` - -## API -对外提供http的接口,用nginx做反向代理,提供安全,并在nginx.conf中提供客户端ip过滤,只有登记过的客户端才能使用 - -### 创建集合 - -path:/v1/createcollection -method:POST -headers:{ - "Content-Type": "application/json" -} -data:{ - colname: 集合名字,必须提供 - fields: 集合字段集,请参照字段类型提供 - description:可选, 集合描述 -} - -成功返回 -{ - "status":"SUCCEEDED" -} - -失败返回 -{ - "status":"FAILED", - "error": 错误信息 -} - -### 删除集合 - -path:/v1/dropcollection -method:POST -headers:{ - "Content-Type": "application/json" -} -data:{ - colname: 集合名字,必须提供 -} - -成功返回 -{ - "status":"SUCCEEDED" -} - -失败返回 -{ - "status":"FAILED", - "error": 错误信息 -} - -### 向集合插入一到多条记录 - -path:/v1/upsert -method:POST -headers:{ - "Content-Type": "application/json" -} -data:{ - colname: 集合名字,必须提供 - data: 数据字典或数据字典数组 -} - -成功返回 -{ - "status":"SUCCEEDED" -} - -失败返回 -{ - "status":"FAILED", - "error": 错误信息 -} - -### 删除集合一到多条记录 - -path:/v1/delete -method:POST -headers:{ - "Content-Type": "application/json" -} -data:{ - colname: 集合名字,必须提供 - pks: 主键或主键数组 - description:可选, 集合描述 -} - -成功返回 -{ - "status":"SUCCEEDED" -} - -失败返回 -{ - "status":"FAILED", - "error": 错误信息 -} - -### 查询集合数据 - -path:/v1/createcollection -method:POST -headers:{ - "Content-Type": "application/json" -} -data:{ - colname: 集合名字,必须提供 - fields: 集合字段集,请参照字段类型提供 - description:可选, 集合描述 -} - -成功返回 -{ - "status":"SUCCEEDED" - "data": 返回数据 -} - -失败返回 -{ - "status":"FAILED", - "error": 错误信息 -} - -返回数据有如下结构 -{ - "total": -1, # 不知道总共多少条符合条件的数据 - "page": 当前页(起始1) - "pagerows": 每页记录数 - "rows": 记录数据 -} +| 字段 | 默认值 | 说明 | +|------|--------|------| +| milvus_mode | lite | lite 或 standalone | +| milvus_db | $[workdir]$/db/milvus.db | lite模式文件路径 | +| milvus_host | 127.0.0.1:19530 | standalone模式地址 | +| milvus_metric | COSINE | 默认距离度量 | +| client_max_size | 100MB | 请求体大小上限 | diff --git a/build.sh b/build.sh index 43f46b1..d2e1300 100755 --- a/build.sh +++ b/build.sh @@ -1,45 +1,68 @@ #!/usr/bin/env bash -# clone from git@git.opencomputing.cn/yumoqing/vdb -# git clone https://git.opencomputing.cn/yumoqing/vdb -cdir=$(pwd) -uname=$(id -un) -gname=$(id -gn) -sudo apt install redis-server -python3 -m venv py3 -source py3/bin/activate -pip install . -mkdir $cdir/logs -cd $cdir -cat > $cdir/vdb.service < $cdir/start.sh < $cdir/stop.sh </dev/null; then + kill $(cat ah.pid) 2>/dev/null || true + sleep 2 + fi + if [ -d .git ]; then + git pull origin master 2>/dev/null || git pull origin main 2>/dev/null || true + fi + mkdir -p logs db + export PYTHONPATH="$(pwd)" + nohup $PY app/vdbapp.py > nohup.out 2>&1 & + echo $! > ah.pid + echo "Started PID $(cat ah.pid) on port $PORT" + sleep 3 + if curl -s http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then + echo "Service healthy" + else + echo "WARNING: not responding yet, check nohup.out" + tail -20 nohup.out + fi + ;; + stop) + if [ -f ah.pid ]; then + kill $(cat ah.pid) 2>/dev/null || true + rm -f ah.pid + echo "Stopped" + else + echo "Not running" + fi + ;; + start) + mkdir -p logs db + export PYTHONPATH="$(pwd)" + nohup $PY app/vdbapp.py > nohup.out 2>&1 & + echo $! > ah.pid + echo "Started PID $(cat ah.pid)" + ;; + status) + echo "=== $SERVICE_NAME Status ===" + if [ -f ah.pid ] && kill -0 $(cat ah.pid) 2>/dev/null; then + echo "Process: running (PID $(cat ah.pid))" + else + echo "Process: not running" + fi + echo "Port: $PORT" + if curl -s --max-time 3 http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then + echo "HTTP: OK" + else + echo "HTTP: not responding" + fi + ;; + *) + echo "Usage: $0 {deploy|update|stop|start|status}" + exit 1 + ;; +esac diff --git a/conf/config.json b/conf/config.json index 250c5f2..f939bbe 100644 --- a/conf/config.json +++ b/conf/config.json @@ -1,7 +1,12 @@ { - "vdb_type": "milvus", - "filesroot": "$[workdir]$/files", + "vdb_type": "milvus", + "milvus_mode": "lite", "milvus_db": "$[workdir]$/db/milvus.db", + "milvus_host": "127.0.0.1:19530", + "milvus_token": "", + "milvus_dbname": "default", + "milvus_metric": "COSINE", + "filesroot": "$[workdir]$/files", "logger": { "name": "vdb", "levelname": "info", @@ -11,13 +16,13 @@ "paths": [ ["$[workdir]$/wwwroot", ""] ], - "client_max_size": 10000, + "client_max_size": 104857600, "host": "0.0.0.0", "port": 8886, "coding": "utf-8", "indexes": [ "index.html", - "index.dspy", + "index.dspy", "index.ui" ], "startswiths": [ @@ -34,16 +39,32 @@ "registerfunction": "drop_collection" }, { - "leading": "/v1/iupsert", + "leading": "/v1/listcollections", + "registerfunction": "list_collections" + }, + { + "leading": "/v1/collectionstats", + "registerfunction": "collection_stats" + }, + { + "leading": "/v1/upsert", "registerfunction": "upsert" }, + { + "leading": "/v1/batchinsert", + "registerfunction": "batch_insert" + }, { "leading": "/v1/delete", "registerfunction": "delete" }, { - "leading": "/v1/query", - "registerfunction": "query" + "leading": "/v1/query", + "registerfunction": "query" + }, + { + "leading": "/v1/batchquery", + "registerfunction": "batch_query" }, { "leading": "/docs", @@ -57,5 +78,5 @@ [".dspy", "dspy"], [".md", "md"] ] - } + } } diff --git a/pyproject.toml b/pyproject.toml index d443478..87c5dac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,18 +4,18 @@ build-backend = "setuptools.build_meta" [project] name = "vdb" -version = "0.1.0" -description = "向量数据库服务" +version = "0.2.0" +description = "Vector Database Service (Milvus Lite + Standalone)" readme = "README.md" requires-python = ">=3.10" dependencies = [ "apppublic", "sqlor", "ahserver", - "milvus-lite", - "pymilvus" + "milvus-lite", + "pymilvus" ] [tool.setuptools.packages.find] -where = ["."] # 声明在哪个目录下查找包,默认是当前目录 -include = ["vdb"] # 包含哪些包 +where = ["."] +include = ["vdb"] diff --git a/vdb/basevdb.py b/vdb/basevdb.py index 70127d3..37c46af 100644 --- a/vdb/basevdb.py +++ b/vdb/basevdb.py @@ -1,24 +1,44 @@ +from abc import ABC, abstractmethod -class BaseVDB: - def query(self, collection_name, - vector=None, - expr=None, - pagerows=80, - page=1, - output_fields=None): - pass - def get_db_type(dtype: str): - return self.dbtypes.get(dtype) +class BaseVDB(ABC): + """向量数据库抽象基类""" - def upsert(self, collection_name, data_dicts, partition_name=None): - pass + @abstractmethod + def create_collection(self, collection_name, fields_config, description="", **kwargs): + pass - def delete(self, collection_name, pks, partition_name=None): - pass - - def drop_collection(self, collection_name): - pass - - def create_collection(self, collection_name, fields_config, description=""): - pass + @abstractmethod + def drop_collection(self, collection_name): + pass + + @abstractmethod + def list_collections(self): + pass + + @abstractmethod + def collection_stats(self, collection_name): + pass + + @abstractmethod + def upsert(self, collection_name, data_dicts, **kwargs): + pass + + @abstractmethod + def batch_insert(self, collection_name, data_dicts, **kwargs): + pass + + @abstractmethod + def delete(self, collection_name, pks, **kwargs): + pass + + @abstractmethod + def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None, **kwargs): + pass + + @abstractmethod + def batch_query(self, collection_name, vectors, expr=None, limit=10, output_fields=None, **kwargs): + pass + + def get_db_type(self, dtype: str): + return self.dbtypes.get(dtype) if hasattr(self, "dbtypes") else None diff --git a/vdb/init.py b/vdb/init.py index 07ad2b5..3d9af09 100644 --- a/vdb/init.py +++ b/vdb/init.py @@ -1,5 +1,4 @@ from traceback import format_exc -from functools import partial from ahserver.serverenv import ServerEnv from appPublic.worker import awaitify from appPublic.registerfunction import RegisterFunction @@ -7,118 +6,168 @@ from appPublic.log import debug, exception from appPublic.jsonConfig import getConfig from .milvus import MilvusManager + def ownerparting(data): - if data.get('ownerid'): - return data.get('ownerid') - return '_default' + if data.get("ownerid"): + return data.get("ownerid") + return "_default" + async def create_collection(request, params_kw, *args, **kwargs): - colname = params_kw.colname - fields = params_kw.fields - description = params_kw.description or "" - env = request._run_ns - f = awaitify(env.vdb.create_collection) - try: - r = await f(colname, fields, description) - return { - "status":"SUCCEEDED", - } - except Exception as e: - exception(f'{e}, {format_exc()}') - return { - "status":"FAILED", - "error": f"{e}" - } + colname = params_kw.colname + fields = params_kw.fields + description = params_kw.description or "" + metric = params_kw.metric + index_type = params_kw.index_type or "HNSW" + env = request._run_ns + f = awaitify(env.vdb.create_collection) + try: + await f(colname, fields, description, metric=metric, index_type=index_type) + return {"status": "SUCCEEDED"} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + async def drop_collection(request, params_kw, *args, **kwargs): - colname = params_kw.colname - env = request._run_ns - f = awaitify(env.vdb.drop_collection) - try: - r = await f(colname) - return { - "status": "SUCCEEDED" - } - except Exception as e: - exception(f"{e}, {format_exc()}") - return { - "status": "FAILED", - "error": f"{e}" - } + colname = params_kw.colname + env = request._run_ns + f = awaitify(env.vdb.drop_collection) + try: + await f(colname) + return {"status": "SUCCEEDED"} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + + +async def list_collections(request, params_kw, *args, **kwargs): + env = request._run_ns + f = awaitify(env.vdb.list_collections) + try: + result = await f() + return {"status": "SUCCEEDED", "collections": result} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + + +async def collection_stats(request, params_kw, *args, **kwargs): + colname = params_kw.colname + env = request._run_ns + f = awaitify(env.vdb.collection_stats) + try: + result = await f(colname) + if result is None: + return {"status": "FAILED", "error": f"Collection '{colname}' not found"} + return {"status": "SUCCEEDED", "data": result} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + async def upsert(request, params_kw, *args, **kwargs): - colname = params_kw.colname - data = params_kw.data - if not isinstance(data, list): - data = [data] - - env = request._run_ns - f = awaitify(env.vdb.upsert) - try: - r = await f(colname, data) - return { - "status": "SUCCEEDED" - } - except Exception as e: - exception(f"{e}, {format_exc()}") - return { - "status": "FAILED", - "error": f"{e}" - } + colname = params_kw.colname + data = params_kw.data + if not isinstance(data, list): + data = [data] + env = request._run_ns + f = awaitify(env.vdb.upsert) + try: + await f(colname, data) + return {"status": "SUCCEEDED"} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + + +async def batch_insert(request, params_kw, *args, **kwargs): + colname = params_kw.colname + data = params_kw.data + flush = params_kw.flush if params_kw.flush is not None else True + partition_name = params_kw.partition_name + if not isinstance(data, list): + return {"status": "FAILED", "error": "data must be a list"} + env = request._run_ns + f = awaitify(env.vdb.batch_insert) + try: + await f(colname, data, flush=flush, partition_name=partition_name) + return {"status": "SUCCEEDED", "count": len(data)} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + async def delete(request, params_kw, *args, **kwargs): - colname = params_kw.colname - pks = params_kw.pks + colname = params_kw.colname + pks = params_kw.pks + env = request._run_ns + f = awaitify(env.vdb.delete) + try: + await f(colname, pks) + return {"status": "SUCCEEDED"} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} - env = request._run_ns - f = awaitify(env.vdb.delete) - try: - r = await f(colname, pks) - return { - "status": "SUCCEEDED" - } - except Exception as e: - exception(f"{e}, {format_exc()}") - return { - "status": "FAILED", - "error": f"{e}" - } async def query(request, params_kw, *args, **kwargs): - colname = params_kw.colname - vector = params_kw.vector - expr = params_kw.expr - pagerows = params_kw.pagerows or 80 - page = params_kw.page or 1 - output_fields = params_kw.output_fields + colname = params_kw.colname + vector = params_kw.vector + expr = params_kw.expr + pagerows = params_kw.pagerows or 80 + page = params_kw.page or 1 + output_fields = params_kw.output_fields + metric = params_kw.metric + + env = request._run_ns + f = awaitify(env.vdb.query) + try: + result = await f(colname, vector=vector, expr=expr, pagerows=pagerows, + page=page, output_fields=output_fields, metric=metric) + return {"status": "SUCCEEDED", "data": result} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} + + +async def batch_query(request, params_kw, *args, **kwargs): + colname = params_kw.colname + vectors = params_kw.vectors + expr = params_kw.expr + limit = params_kw.limit or 10 + output_fields = params_kw.output_fields + metric = params_kw.metric + + if not isinstance(vectors, list) or not vectors: + return {"status": "FAILED", "error": "vectors must be a non-empty list"} + + env = request._run_ns + f = awaitify(env.vdb.batch_query) + try: + result = await f(colname, vectors, expr=expr, limit=limit, + output_fields=output_fields, metric=metric) + return {"status": "SUCCEEDED", "data": result} + except Exception as e: + exception(f"{e}, {format_exc()}") + return {"status": "FAILED", "error": str(e)} - env = request._run_ns - f1 = awaitify(env.vdb.drop_collection) - try: - f = partial(f1, colname, vector=vector, expr=expr, pagerows=pagerows, page=page, output_fields=output_fields) - r = await f() - return { - "status": "SUCCEEDED", - "data": r - } - except Exception as e: - exception(f"{e}, {format_exc()}") - return { - "status": "FAILED", - "error": f"{e}" - } def load_vdb(): - config = getConfig() - vdb = None - vdb_type = config.vdb_type - if vdb_type == 'milvus': - vdb = MilvusManager(partitionize=ownerparting) - env = ServerEnv() - env.vdb = vdb - rf = RegisterFunction() - rf.register('create_collection', create_collection) - rf.register('drop_collection', drop_collection) - rf.register('upsert', upsert) - rf.register('delete', delete) - rf.register('query', query) + config = getConfig() + vdb = None + vdb_type = config.vdb_type + if vdb_type == "milvus": + vdb = MilvusManager(partitionize=ownerparting) + env = ServerEnv() + env.vdb = vdb + rf = RegisterFunction() + rf.register("create_collection", create_collection) + rf.register("drop_collection", drop_collection) + rf.register("list_collections", list_collections) + rf.register("collection_stats", collection_stats) + rf.register("upsert", upsert) + rf.register("batch_insert", batch_insert) + rf.register("delete", delete) + rf.register("query", query) + rf.register("batch_query", batch_query) diff --git a/vdb/milvus.py b/vdb/milvus.py index 4cbb805..515b064 100644 --- a/vdb/milvus.py +++ b/vdb/milvus.py @@ -4,224 +4,309 @@ from appPublic.log import info, debug, exception, error from threading import Lock from pymilvus import ( - connections, FieldSchema, CollectionSchema, - DataType, Collection, utility, Partition + connections, FieldSchema, CollectionSchema, + DataType, Collection, utility ) from .basevdb import BaseVDB + class MilvusManager(BaseVDB): - _instance = None - _lock = Lock() - dbtypes = { - "str": DataType.VARCHAR, - "int": DataType.INT32, - "bool": DataType.BOOL, - "float": DataType.FLOAT, - "fvector": DataType.FLOAT_VECTOR, - "bvector": DataType.BINARY_VECTOR, - "json": DataType.JSON - } - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - cls._instance = super(MilvusManager, cls).__new__(cls) - cls._instance._initialized = False - return cls._instance + _instance = None + _lock = Lock() + dbtypes = { + "str": DataType.VARCHAR, + "int": DataType.INT32, + "int64": DataType.INT64, + "bool": DataType.BOOL, + "float": DataType.FLOAT, + "fvector": DataType.FLOAT_VECTOR, + "bvector": DataType.BINARY_VECTOR, + "json": DataType.JSON, + } + metric_map = { + "L2": "L2", "l2": "L2", + "IP": "IP", "ip": "IP", + "COSINE": "COSINE", "cosine": "COSINE", + } - def __init__(self, partitionize=None): - self.partitionize = partitionize - if self._initialized: - return - try: - config = getConfig() - self.db_path = config.milvus_db - debug(f"dbpath: {self.db_path}") - except KeyError as e: - error(f"配置文件缺少必要字段: {str(e)}") - raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") - self._initialize_connection() - self._initialized = True - info(f"MilvusDBConnection initialized with db_path: {self.db_path}") + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + cls._instance = super(MilvusManager, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance - def _initialize_connection(self): - """初始化 Milvus 连接,确保单一连接""" - try: - db_dir = os.path.dirname(self.db_path) - debug(f"db_dir: {db_dir}") - if not os.path.exists(db_dir): - os.makedirs(db_dir, exist_ok=True) - debug(f"创建 Milvus 目录: {db_dir}") - if not os.access(db_dir, os.W_OK): - raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") - debug(f"不可写") - if not connections.has_connection(self.db_path): - connections.connect(self.db_path, uri=self.db_path) - debug(f"已连接到 Milvus Lite,路径: {self.db_path}") - else: - debug("已存在 Milvus 连接,跳过重复连接") - except Exception as e: - error(f"连接 Milvus 失败: {str(e)}") - raise RuntimeError(f"连接 Milvus 失败: {str(e)}") + def __init__(self, partitionize=None): + self.partitionize = partitionize + if self._initialized: + return + try: + config = getConfig() + self.alias = "default" + self.mode = getattr(config, "milvus_mode", "lite") + if self.mode == "standalone": + self.host = config.milvus_host + self.token = getattr(config, "milvus_token", "") + self.db_name = getattr(config, "milvus_dbname", "default") + else: + self.db_path = config.milvus_db + self.default_metric = getattr(config, "milvus_metric", "COSINE") + self._initialize_connection() + self._initialized = True + info(f"MilvusManager initialized: mode={self.mode}, alias={self.alias}") + except Exception as e: + error(f"MilvusManager init failed: {e}") + raise RuntimeError(f"MilvusManager init failed: {e}") - # --- 集合管理 --- - def create_collection(self, collection_name, fields_config, description=""): - """ - 打开或创建集合。如果已存在则返回对象;不存在则根据 config 创建。 - :param fields_config: 格式: [{"name": "id", "type": DataType.INT64, "is_primary": True}, ...] - """ - if utility.has_collection(collection_name, using=self.alias): - # print(f"📦 集合 '{collection_name}' 已存在,直接加载。") - return Collection(collection_name, using=self.alias) + def _initialize_connection(self): + try: + if self.mode == "standalone": + parts = self.host.split(":") + host = parts[0] + port = parts[1] if len(parts) > 1 else "19530" + uri = f"http://{host}:{port}" + if not connections.has_connection(self.alias): + kw = {"uri": uri, "alias": self.alias} + if self.token: + kw["token"] = self.token + if self.db_name and self.db_name != "default": + kw["db_name"] = self.db_name + connections.connect(**kw) + info(f"Connected to Milvus Standalone: {uri}") + else: + debug("Milvus Standalone connection already exists") + else: + db_dir = os.path.dirname(self.db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + if not connections.has_connection(self.alias): + connections.connect(alias=self.alias, uri=self.db_path) + info(f"Connected to Milvus Lite: {self.db_path}") + else: + debug("Milvus Lite connection already exists") + except Exception as e: + error(f"Milvus connect failed: {e}") + raise RuntimeError(f"Milvus connect failed: {e}") - fields = [] - for cfg in fields_config: - field = FieldSchema( - name=cfg['name'], - dtype=self.get_db_type(cfg['type']), - is_primary=cfg.get('is_primary', False), - auto_id=cfg.get('auto_id', False), - dim=cfg.get('dim'), - max_length=cfg.get('max_length') - ) - fields.append(field) + # === Collection Management === - schema = CollectionSchema(fields, description=description) - collection = Collection(name=collection_name, schema=schema, using=self.alias) - - # 自动为向量字段创建索引 (内网推荐 HNSW) - for cfg in fields_config: - dbtype = self.get_db_type(cfg['type']) - if dbtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: - index_params = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 16, "efConstruction": 64} - } - collection.create_index(field_name=cfg['name'], index_params=index_params) - - # print(f"🚀 集合 '{collection_name}' 创建并初始化索引完成。") - return collection + def list_collections(self): + return utility.list_collections(using=self.alias) - def drop_collection(self, collection_name): - """物理删除集合""" - if utility.has_collection(collection_name, using=self.alias): - utility.drop_collection(collection_name, using=self.alias) - print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。") + def collection_stats(self, collection_name): + if not utility.has_collection(collection_name, using=self.alias): + return None + col = Collection(collection_name, using=self.alias) + stats = {"name": collection_name, "row_count": col.num_entities} + stats["fields"] = [ + {"name": f.name, "type": str(f.dtype), "is_primary": f.is_primary} + for f in col.schema.fields + ] + try: + stats["indexes"] = [ + {"field": idx.field_name, "type": idx.params.get("index_type", "?")} + for idx in col.indexes + ] + except Exception: + stats["indexes"] = [] + return stats - # --- 内存与分区优化 --- - def load_segmented(self, collection_name, partition_names=None): - """分段加载:只加载特定分区到内存以保护内网服务器内存""" - col = Collection(collection_name, using=self.alias) - col.load(partition_names=partition_names) - print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。") + def create_collection(self, collection_name, fields_config, description="", + metric=None, index_type="HNSW", index_params_extra=None): + if utility.has_collection(collection_name, using=self.alias): + return Collection(collection_name, using=self.alias) - def release_memory(self, collection_name): - """释放内存:在执行 VBench 推理等高显存任务前调用""" - col = Collection(collection_name, using=self.alias) - col.release() - print(f"♻️ 集合 '{collection_name}' 已释放内存占用。") + metric = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric - # --- 数据操作 (DML) --- - def _validate_and_format(self, col, data_dicts): - """私有工具:校验字典完整性并转换为列式数据""" - schema = col.schema - # 找出需要填充的字段名(排除自增ID) - required_fields = [f.name for f in schema.fields if not f.auto_id] - - # 提取向量字段信息用于维度校验 - vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR} + fields = [] + for cfg in fields_config: + field = FieldSchema( + name=cfg["name"], + dtype=self.get_db_type(cfg["type"]), + is_primary=cfg.get("is_primary", False), + auto_id=cfg.get("auto_id", False), + dim=cfg.get("dim"), + max_length=cfg.get("max_length"), + ) + fields.append(field) - columnar_data = {name: [] for name in required_fields} - for i, entry in enumerate(data_dicts): - for field in required_fields: - if field not in entry: - raise ValueError(f"记录 {i} 缺失必填字段: {field}") - # 向量维度校验 - if field in vec_info and len(entry[field]) != vec_info[field]: - raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}") - columnar_data[field].append(entry[field]) - - return [columnar_data[f] for f in required_fields] + schema = CollectionSchema(fields, description=description) + collection = Collection(name=collection_name, schema=schema, using=self.alias) - def upsert(self, collection_name, data_dicts): - """通用 Upsert:支持字典列表输入,自动识别主键更新""" - pks = [item['id'] for item in data_dicts] - self.delete(collection_name, pks) - col = Collection(collection_name, using=self.alias) - if self.partitionize is None: - formatted_data = self._validate_and_format(col, data_dicts) - res = col.upsert(formatted_data) - col.flush() # 内网环境强制落盘以防数据丢失 - return res - grouped_data = {} - for entry in data_dicts: - p_name = str(partition_func(entry)) - if p_name not in grouped_data: - grouped_data[p_name] = [] - grouped_data[p_name].append(entry) + for cfg in fields_config: + dbtype = self.get_db_type(cfg["type"]) + if dbtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + idx_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 64}} + if index_params_extra: + idx_params["params"].update(index_params_extra) + collection.create_index(field_name=cfg["name"], index_params=idx_params) - for p_name, p_data in grouped_data.items(): - # 自动维护分区 - if p_name != "_default" and not col.has_partition(p_name): - col.create_partition(p_name) - print(f"🚩 自动创建新分区: {p_name}") - - # 格式化数据 - formatted = self._validate_and_format(col, p_data) - - # 执行 Upsert - col.upsert(formatted, partition_name=p_name) - col.flush() + info(f"Collection '{collection_name}' created, metric={metric}, index={index_type}") + return collection - def delete(self, collection_name, pks, partition_name=None): - """删除指定主键记录""" - col = Collection(collection_name, using=self.alias) - expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}" - col.delete(expr, partition_name=partition_name) - print(f"✂️ 已删除主键为 {pks} 的记录。") + def drop_collection(self, collection_name): + if utility.has_collection(collection_name, using=self.alias): + utility.drop_collection(collection_name, using=self.alias) + info(f"Collection '{collection_name}' dropped") - # --- 高级组合检索 --- - def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None): - """ - 组合查询接口:支持向量近似搜索 + 标量过滤 + 分页 - :param vector: 目标向量,若为 None 则退化为纯标量查询 - :param expr: 过滤条件,如 "score > 0.8 and lang == 'cmn'" - :param offset: 分页偏移量 - """ - offset = (page - 1) * pagerows - limit = pagerows - col = Collection(collection_name, using=self.alias) - # 确保数据已加载 - if utility.get_query_segment_info(collection_name, using=self.alias) == []: - self.load_segmented(collection_name) + # === Memory === - if output_fields is None: - output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] + def load_segmented(self, collection_name, partition_names=None): + col = Collection(collection_name, using=self.alias) + col.load(partition_names=partition_names) + debug(f"Loaded '{collection_name}' partitions={partition_names}") - if vector is not None: - # 自动寻找向量字段名 - vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) - search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset} - - return { - "total": -1, - "page": page, - "pagerows": pagerows, - "rows": col.search( - data=[vector], - anns_field=vec_field, - param=search_params, - limit=limit, - offset=offset, - expr=expr, - output_fields=output_fields - ) - } - else: - # 纯标量查询 - return { - "total": -1, - "page": page, - "pagerows": pagerows, - "rows": col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields) - } + def release_memory(self, collection_name): + col = Collection(collection_name, using=self.alias) + col.release() + debug(f"Released '{collection_name}' from memory") + + # === DML === + + def _validate_and_format(self, col, data_dicts): + schema = col.schema + required = [f.name for f in schema.fields if not f.auto_id] + vec_info = {f.name: f.params.get("dim", 0) for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR} + columnar = {name: [] for name in required} + for i, entry in enumerate(data_dicts): + for field in required: + if field not in entry: + raise ValueError(f"Record {i} missing field: {field}") + if field in vec_info and entry[field] is not None: + if len(entry[field]) != vec_info[field]: + raise ValueError(f"Record {i} vector dim mismatch: expected {vec_info[field]}, got {len(entry[field])}") + columnar[field].append(entry[field]) + return [columnar[f] for f in required] + + def upsert(self, collection_name, data_dicts, **kwargs): + if not isinstance(data_dicts, list): + data_dicts = [data_dicts] + pks = [item.get("id") for item in data_dicts if "id" in item] + if pks: + self.delete(collection_name, pks) + col = Collection(collection_name, using=self.alias) + if self.partitionize is None: + formatted = self._validate_and_format(col, data_dicts) + res = col.upsert(formatted) + col.flush() + return res + grouped = {} + for entry in data_dicts: + p_name = str(self.partitionize(entry)) + grouped.setdefault(p_name, []).append(entry) + for p_name, p_data in grouped.items(): + if p_name != "_default" and not col.has_partition(p_name): + col.create_partition(p_name) + formatted = self._validate_and_format(col, p_data) + col.upsert(formatted, partition_name=p_name) + col.flush() + + def batch_insert(self, collection_name, data_dicts, flush=True, partition_name=None, **kwargs): + if not isinstance(data_dicts, list) or not data_dicts: + return None + col = Collection(collection_name, using=self.alias) + formatted = self._validate_and_format(col, data_dicts) + res = col.insert(formatted, partition_name=partition_name) + if flush: + col.flush() + return res + + def delete(self, collection_name, pks, partition_name=None, **kwargs): + col = Collection(collection_name, using=self.alias) + if isinstance(pks, list): + if pks and isinstance(pks[0], str): + pk_str = ", ".join(f'"{p}"' for p in pks) + expr = f"id in [{pk_str}]" + else: + expr = f"id in {pks}" + else: + expr = f'id == "{pks}"' if isinstance(pks, str) else f"id == {pks}" + col.delete(expr, partition_name=partition_name) + col.flush() + + # === Search === + + def _ensure_loaded(self, collection_name): + try: + seg_info = utility.get_query_segment_info(collection_name, using=self.alias) + if not seg_info: + self.load_segmented(collection_name) + except Exception: + try: + self.load_segmented(collection_name) + except Exception: + pass + + def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, + output_fields=None, metric=None, **kwargs): + offset = (page - 1) * pagerows + limit = pagerows + col = Collection(collection_name, using=self.alias) + self._ensure_loaded(collection_name) + + if output_fields is None: + output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] + + if vector is not None: + vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) + m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric + search_params = {"metric_type": m, "params": {"ef": 64}, "offset": offset} + hits = col.search( + data=[vector], + anns_field=vec_field, + param=search_params, + limit=limit, + expr=expr, + output_fields=output_fields, + ) + rows = [] + for hit in hits[0]: + row = {"id": hit.id, "score": hit.score} + for field in output_fields: + if field not in row: + try: + row[field] = hit.entity.get(field) + except Exception: + pass + rows.append(row) + return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows} + else: + rows = col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields) + return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows} + + def batch_query(self, collection_name, vectors, expr=None, limit=10, + output_fields=None, metric=None, **kwargs): + if not vectors: + return [] + col = Collection(collection_name, using=self.alias) + self._ensure_loaded(collection_name) + + if output_fields is None: + output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] + + vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) + m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric + search_params = {"metric_type": m, "params": {"ef": 64}} + + results = col.search( + data=vectors, + anns_field=vec_field, + param=search_params, + limit=limit, + expr=expr, + output_fields=output_fields, + ) + all_rows = [] + for hits in results: + rows = [] + for hit in hits: + row = {"id": hit.id, "score": hit.score} + for field in output_fields: + if field not in row: + try: + row[field] = hit.entity.get(field) + except Exception: + pass + rows.append(row) + all_rows.append(rows) + return all_rows \ No newline at end of file