feat: v2.0 - standalone support, COSINE metric, batch ops, fix query bug, new APIs

yumoqing 2026-06-14 20:14:59 +08:00
parent cd64579881
commit 1d9e2f3efc
8 changed files with 712 additions and 577 deletions

8
.gitignore vendored Normal file

@ -0,0 +1,8 @@
ah.pid
nohup.out
__pycache__/
*.pyc
db/
logs/
files/
wwwroot/

327
README.md

@ -1,208 +1,137 @@
# vdb # VDB - Vector Database Service
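## Installation An HTTP vector database service built on ahserver + pymilvus, supporting two modes: Milvus Lite and Milvus Standalone.
Run ## Features
- Supports Milvus Lite (local file) and Milvus Standalone (Docker deployment, tens of millions of vectors)
- Three distance metrics: COSINE / IP / L2
- Configurable index type: HNSW / IVF_FLAT / IVF_PQ
- Batch insert (batch_insert) and batch search (batch_query)
- Combined vector search + scalar filtering queries
- Collection management (list/stats/drop)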
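For readers unfamiliar with the three metrics named in the feature list, the following is a minimal sketch of their textbook definitions in plain Python. The function names are hypothetical and not part of this repository, and Milvus's own scoring may differ in detail (for example in scaling), so treat this as a definition aid only.

```python
import math

def inner_product(a, b):
    """IP: sum of element-wise products; larger means more similar."""
    return sum(x * y for x, y in zip(a, b))

def l2_distance(a, b):
    """L2: Euclidean distance; smaller means more similar."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a, b):
    """COSINE: inner product of the two vectors after length normalization."""
    norm_a = math.sqrt(inner_product(a, a))
    norm_b = math.sqrt(inner_product(b, b))
    return inner_product(a, b) / (norm_a * norm_b)

print(inner_product([1, 2], [3, 4]))
print(l2_distance([0, 0], [3, 4]))
print(cosine_similarity([1, 0], [0, 1]))
```

COSINE ignores vector length, which is why it is a common default for text embeddings that are not normalized beforehand.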
## Deployment
```bash
cd /data/ymq/vdb
bash build.sh deploy # start
bash build.sh stop # stop
bash build.sh status # status
``` ```
git clone https://git.opencomputing.cn/yumoqing/vdb.git
cd vdb
bash ./build.sh
```
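After it finishes:
1 A vdb.service unit is added, so the vector database service starts automatically after a reboot
2 Use sudo systemctl XXX vdb.service to start, stop, or restart the service
3 app/vdbapp.py is running
### Switching to Milvus Standalone (tens of millions of vectors)
Edit `conf/config.json`: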
```json
{
"milvus_mode": "standalone",
"milvus_host": "127.0.0.1:19530"
}
```
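As a rough illustration of how the `milvus_host` value might be consumed, the sketch below splits a `host` or `host:port` string and falls back to port 19530 when no port is given, mirroring the connection logic this commit adds to `_initialize_connection`. The function name `milvus_uri` is hypothetical.

```python
def milvus_uri(milvus_host: str, default_port: str = "19530") -> str:
    """Build an http:// URI from a "host" or "host:port" string."""
    parts = milvus_host.split(":")
    host = parts[0]
    # Use the default Milvus port when the config omits one.
    port = parts[1] if len(parts) > 1 else default_port
    return f"http://{host}:{port}"

print(milvus_uri("127.0.0.1:19530"))
print(milvus_uri("milvus"))
```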
Deploy Milvus Standalone (Docker):
```yaml
# docker-compose.yml
services:
etcd:
image: quay.io/coreos/etcd:v3.5.18
environment:
- ETCD_AUTO_COMPACTION_MODE=revision
- ETCD_AUTO_COMPACTION_RETENTION=1000
volumes: ["etcd:/etcd"]
command: etcd --advertise-client-urls=http://127.0.0.1:2379 --listen-client-urls=http://0.0.0.0:2379 --data-dir=/etcd
minio:
image: minio/minio:latest
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
volumes: ["minio:/minio_data"]
command: minio server /minio_data --console-address ":9001"
milvus:
image: milvusdb/milvus:v2.4-latest
depends_on: [etcd, minio]
ports: ["19530:19530", "9091:9091"]
volumes: ["milvus:/var/lib/milvus"]
environment:
ETCD_ENDPOINTS: etcd:2379
MINIO_ADDRESS: minio:9000
```
## API Endpoints
### GET /v1/listcollections
Lists all collections.
### POST /v1/createcollection
Creates a collection.
```json
{
"colname": "entities",
"fields": [
{"name": "id", "type": "str", "max_length": 32, "is_primary": true},
{"name": "embedding", "type": "fvector", "dim": 1024}
],
"metric": "COSINE",
"index_type": "HNSW"
}
```
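A client could assemble the request body above programmatically. The sketch below is one possible way to do it; the helper `build_create_collection` and its client-side metric check are assumptions for illustration, not part of the service.

```python
import json

VALID_METRICS = {"COSINE", "IP", "L2"}

def build_create_collection(colname, dim, metric="COSINE", index_type="HNSW"):
    """Return a /v1/createcollection body with a str primary key and one float vector."""
    metric = metric.upper()
    if metric not in VALID_METRICS:
        raise ValueError(f"unsupported metric: {metric}")
    return {
        "colname": colname,
        "fields": [
            {"name": "id", "type": "str", "max_length": 32, "is_primary": True},
            {"name": "embedding", "type": "fvector", "dim": dim},
        ],
        "metric": metric,
        "index_type": index_type,
    }

payload = build_create_collection("entities", 1024, metric="cosine")
body = json.dumps(payload)  # send this as the POST body
print(body)
```

Validating the metric on the client is a design choice made here because, in this commit, the server falls back to its default metric when it does not recognize the name, so a typo would otherwise go unnoticed.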
### POST /v1/collectionstats
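Returns collection statistics (row count, fields, indexes).
### POST /v1/upsert
Inserts or updates records (existing records are deleted automatically before insertion; suited to small batches).
### POST /v1/batchinsert
Bulk insert (no per-record flush; suited to large-volume ingestion).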
```json
{"colname": "entities", "data": [...], "flush": true}
```
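One possible way to use the `flush` flag for a large load is to split the records into chunks and request a flush only on the final chunk. The sketch below builds the request bodies without sending them; `batch_payloads`, the batch size, and the toy 4-dimensional records are all assumptions for illustration.

```python
def batch_payloads(colname, records, batch_size=1000):
    """Yield /v1/batchinsert bodies, setting flush=True only on the last chunk."""
    total = len(records)
    for start in range(0, total, batch_size):
        chunk = records[start:start + batch_size]
        is_last = start + batch_size >= total
        yield {"colname": colname, "data": chunk, "flush": is_last}

# Toy data: real records must match the collection schema and vector dimension.
records = [{"id": str(i), "embedding": [0.0] * 4} for i in range(5)]
payloads = list(batch_payloads("entities", records, batch_size=2))
print([(len(p["data"]), p["flush"]) for p in payloads])
```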
### POST /v1/query
Vector search + scalar filtering + pagination.
```json
{
"colname": "entities",
"vector": [0.1, 0.2, ...],
"expr": "type == \"animal\"",
"pagerows": 20,
"page": 1,
"metric": "COSINE"
}
```
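The `page` and `pagerows` parameters map to an offset and a limit. The sketch below shows that arithmetic, which matches the `(page - 1) * pagerows` computation in this commit's `query` method; the helper name `page_to_offset` and its input validation are hypothetical.

```python
def page_to_offset(page: int, pagerows: int) -> tuple[int, int]:
    """Return (offset, limit) for a 1-based page number."""
    if page < 1 or pagerows < 1:
        raise ValueError("page and pagerows must be >= 1")
    return (page - 1) * pagerows, pagerows

print(page_to_offset(1, 20))
print(page_to_offset(3, 20))
```

Because the response reports `total` as -1, a client cannot compute the number of pages up front; one option is to keep requesting pages until a page comes back with fewer than `pagerows` rows.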
### POST /v1/batchquery
Batch vector search (several query vectors searched in one request).
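Judging from the handler in this commit, the request takes `colname`, `vectors`, and optionally `expr`, `limit`, `output_fields`, and `metric`, and the response `data` holds one list of rows per query vector. The sketch below picks the first row for each query from such a response; `top_hits` is a hypothetical helper, the sample response is hand-written for illustration, and the code assumes rows arrive best match first.

```python
def top_hits(response: dict) -> list:
    """Return the first row for each query vector, or None when a query had no hits."""
    if response.get("status") != "SUCCEEDED":
        raise RuntimeError(response.get("error", "batch query failed"))
    return [rows[0] if rows else None for rows in response["data"]]

# Hand-written sample shaped like a /v1/batchquery response for two query vectors.
sample = {
    "status": "SUCCEEDED",
    "data": [
        [{"id": "a1", "score": 0.93}, {"id": "a2", "score": 0.71}],
        [],
    ],
}
best = top_hits(sample)
print(best)
```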
### POST /v1/delete
Deletes by primary key.
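The request body carries `colname` and `pks` (a single key or a list). As a sketch of how those keys turn into a filter expression, the function below restates the logic of this commit's `delete` method as a standalone helper; the name `pk_filter` is hypothetical, and like the original it assumes the primary key field is called `id`.

```python
def pk_filter(pks):
    """Build a filter expression for one primary key or a list of keys."""
    if isinstance(pks, list):
        if pks and isinstance(pks[0], str):
            # String keys must be double-quoted inside the expression.
            quoted = ", ".join(f'"{p}"' for p in pks)
            return f"id in [{quoted}]"
        return f"id in {pks}"
    return f'id == "{pks}"' if isinstance(pks, str) else f"id == {pks}"

print(pk_filter(["a", "b"]))
print(pk_filter(7))
```

Quoting string keys explicitly avoids relying on Python's list `repr`, which would emit single quotes.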
### POST /v1/dropcollection
Drops the entire collection.
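## Field Types ## Field Types
The following field types can be used as data types in a collection
* "str": variable-length string, requires the max_length attribute | Type | Description | Notes |
* "int": integer type |------|------|------|
* "bool": boolean type | str | string | requires max_length |
* "float": floating-point number | int | 32-bit integer | |
* "fvector": float vector, used for ordinary vectors | int64 | 64-bit integer | |
* "bvector": binary vector, used for binary embeddings | bool | boolean | |
* "json": JSON-formatted data | float | floating-point number | |
| fvector | float vector | requires dim |
| bvector | binary vector | requires dim |
| json | JSON | flexible structure |
Some examples ## Configuration (conf/config.json)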
```
[
# 1. Primary key field (required): str type, not auto-generated; the ID is assigned manually so it can be linked to the business system
{
"name": "id",
"type": "str",
"max_length": 32
"is_primary": True,
"auto_id": False
},
# 2. Vector field (core): stores features extracted by CLIP (768 dimensions) | Field | Default | Description |
{ |------|--------|------|
"name": "video_embedding", | milvus_mode | lite | lite or standalone |
"type": "fvector", | milvus_db | $[workdir]$/db/milvus.db | file path in lite mode |
"dim": 768 | milvus_host | 127.0.0.1:19530 | address in standalone mode |
}, | milvus_metric | COSINE | default distance metric |
| client_max_size | 100MB | maximum request body size |
# 3. Variable-length string: stores the video's storage path; a maximum length must be specified
{
"name": "file_path",
"type": "str",
"max_length": 512
},
# 4. Float: stores an evaluation score (e.g. the VBench overall score)
{
"name": "quality_score",
"type": "float"
},
# 5. Boolean: marks whether manual review has been completed
{
"name": "is_reviewed",
"type": "bool"
},
# 6. Integer: stores the video duration (seconds)
{
"name": "duration_sec",
"type": "int"
},
# 7. JSON field: stores unstructured metadata (e.g. details of VBench's 16 sub-dimensions)
# Milvus 2.4+ supports dynamic parsing and querying of JSON
{
"name": "meta_data",
"type": "json"
}
]
```
## API
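The service exposes an HTTP interface. nginx is used as a reverse proxy for security, and nginx.conf filters client IPs so that only registered clients can use the service.
### Create a collection
path: /v1/createcollection
method: POST
headers: {
"Content-Type": "application/json"
}
data: {
colname: collection name, required
fields: collection fields; see the field types above
description: optional, collection description
}
On success, returns
{
"status":"SUCCEEDED"
}
On failure, returns
{
"status":"FAILED",
"error": error message
}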
### Drop a collection
path: /v1/dropcollection
method: POST
headers: {
"Content-Type": "application/json"
}
data: {
colname: collection name, required
}
On success, returns
{
"status":"SUCCEEDED"
}
On failure, returns
{
"status":"FAILED",
"error": error message
}
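### Insert one or more records into a collection
path: /v1/upsert
method: POST
headers: {
"Content-Type": "application/json"
}
data: {
colname: collection name, required
data: a data dict or an array of data dicts
}
On success, returns
{
"status":"SUCCEEDED"
}
On failure, returns
{
"status":"FAILED",
"error": error message
}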
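### Delete one or more records from a collection
path: /v1/delete
method: POST
headers: {
"Content-Type": "application/json"
}
data: {
colname: collection name, required
pks: a primary key or an array of primary keys
description: optional, collection description
}
On success, returns
{
"status":"SUCCEEDED"
}
On failure, returns
{
"status":"FAILED",
"error": error message
}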
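### Query collection data
path: /v1/createcollection
method: POST
headers: {
"Content-Type": "application/json"
}
data: {
colname: collection name, required
fields: collection fields; see the field types above
description: optional, collection description
}
On success, returns
{
"status":"SUCCEEDED"
"data": returned data
}
On failure, returns
{
"status":"FAILED",
"error": error message
}
The returned data has the following structure
{
"total": -1, # the total number of matching records is unknown
"page": current page, starting from 1
"pagerows": number of records per page
"rows": record data
}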

107
build.sh

@ -1,45 +1,68 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# clone from git@git.opencomputing.cn/yumoqing/vdb # vdb - Vector Database Service
# git clone https://git.opencomputing.cn/yumoqing/vdb set -e
cdir=$(pwd) cd "$(dirname "$0")"
uname=$(id -un)
gname=$(id -gn)
sudo apt install redis-server
python3 -m venv py3
source py3/bin/activate
pip install .
mkdir $cdir/logs
cd $cdir
cat > $cdir/vdb.service <<EOF
[Unit]
Wants=systemd-networkd.service
[Service] SERVICE_NAME="vdb"
User=$uname PORT=8886
Group=$gname PY=/data/ymq/wan22-service/py3/bin/python
Type=forking action="${1:-status}"
WorkingDirectory=$cdir
ExecStart=$cdir/start.sh
ExecStop=$cdir/stop.sh
StandardOutput=append:/var/log/vdb/vdb.log
StandardError=append:/var/log/vdb/vdb.log
SyslogIdentifier=vdb
[Install] case "$action" in
WantedBy=multi-user.target deploy|update)
EOF echo "=== $SERVICE_NAME Deploy ==="
cat > $cdir/start.sh <<EOF if [ -f ah.pid ] && kill -0 $(cat ah.pid) 2>/dev/null; then
#!/usr/bin/bash kill $(cat ah.pid) 2>/dev/null || true
cd $cdir sleep 2
$cdir/py3/bin/python $cdir/app/vdbapp.py -p 8887 -w $cdir & fi
exit 0 if [ -d .git ]; then
EOF git pull origin master 2>/dev/null || git pull origin main 2>/dev/null || true
cat > $cdir/stop.sh <<EOF fi
PID=\$(lsof -t -i:8887) mkdir -p logs db
kill -9 \$PID export PYTHONPATH="$(pwd)"
EOF nohup $PY app/vdbapp.py > nohup.out 2>&1 &
chmod +x $cdir/start.sh stop.sh echo $! > ah.pid
sudo mkdir /var/log/vdb echo "Started PID $(cat ah.pid) on port $PORT"
sudo cp vdb.service /etc/systemd/system sleep 3
sudo systemctl enable vdb if curl -s http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then
sudo systemctl restart vdb echo "Service healthy"
else
echo "WARNING: not responding yet, check nohup.out"
tail -20 nohup.out
fi
;;
stop)
if [ -f ah.pid ]; then
kill $(cat ah.pid) 2>/dev/null || true
rm -f ah.pid
echo "Stopped"
else
echo "Not running"
fi
;;
start)
mkdir -p logs db
export PYTHONPATH="$(pwd)"
nohup $PY app/vdbapp.py > nohup.out 2>&1 &
echo $! > ah.pid
echo "Started PID $(cat ah.pid)"
;;
status)
echo "=== $SERVICE_NAME Status ==="
if [ -f ah.pid ] && kill -0 $(cat ah.pid) 2>/dev/null; then
echo "Process: running (PID $(cat ah.pid))"
else
echo "Process: not running"
fi
echo "Port: $PORT"
if curl -s --max-time 3 http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then
echo "HTTP: OK"
else
echo "HTTP: not responding"
fi
;;
*)
echo "Usage: $0 {deploy|update|stop|start|status}"
exit 1
;;
esac


@ -1,7 +1,12 @@
{ {
"vdb_type": "milvus", "vdb_type": "milvus",
"filesroot": "$[workdir]$/files", "milvus_mode": "lite",
"milvus_db": "$[workdir]$/db/milvus.db", "milvus_db": "$[workdir]$/db/milvus.db",
"milvus_host": "127.0.0.1:19530",
"milvus_token": "",
"milvus_dbname": "default",
"milvus_metric": "COSINE",
"filesroot": "$[workdir]$/files",
"logger": { "logger": {
"name": "vdb", "name": "vdb",
"levelname": "info", "levelname": "info",
@ -11,7 +16,7 @@
"paths": [ "paths": [
["$[workdir]$/wwwroot", ""] ["$[workdir]$/wwwroot", ""]
], ],
"client_max_size": 10000, "client_max_size": 104857600,
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 8886, "port": 8886,
"coding": "utf-8", "coding": "utf-8",
@ -34,9 +39,21 @@
"registerfunction": "drop_collection" "registerfunction": "drop_collection"
}, },
{ {
"leading": "/v1/iupsert", "leading": "/v1/listcollections",
"registerfunction": "list_collections"
},
{
"leading": "/v1/collectionstats",
"registerfunction": "collection_stats"
},
{
"leading": "/v1/upsert",
"registerfunction": "upsert" "registerfunction": "upsert"
}, },
{
"leading": "/v1/batchinsert",
"registerfunction": "batch_insert"
},
{ {
"leading": "/v1/delete", "leading": "/v1/delete",
"registerfunction": "delete" "registerfunction": "delete"
@ -45,6 +62,10 @@
"leading": "/v1/query", "leading": "/v1/query",
"registerfunction": "query" "registerfunction": "query"
}, },
{
"leading": "/v1/batchquery",
"registerfunction": "batch_query"
},
{ {
"leading": "/docs", "leading": "/docs",
"registerfunction": "docs" "registerfunction": "docs"


@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "vdb" name = "vdb"
version = "0.1.0" version = "0.2.0"
description = "向量数据库服务" description = "Vector Database Service (Milvus Lite + Standalone)"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
@ -17,5 +17,5 @@ dependencies = [
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["."] # directory in which to look for packages; defaults to the current directory where = ["."]
include = ["vdb"] # which packages to include include = ["vdb"]


@ -1,24 +1,44 @@
from abc import ABC, abstractmethod
class BaseVDB:
def query(self, collection_name, class BaseVDB(ABC):
vector=None, """Abstract base class for vector databases"""
expr=None,
pagerows=80, @abstractmethod
page=1, def create_collection(self, collection_name, fields_config, description="", **kwargs):
output_fields=None):
pass
def get_db_type(dtype: str):
return self.dbtypes.get(dtype)
def upsert(self, collection_name, data_dicts, partition_name=None):
pass
def delete(self, collection_name, pks, partition_name=None):
pass pass
@abstractmethod
def drop_collection(self, collection_name): def drop_collection(self, collection_name):
pass pass
def create_collection(self, collection_name, fields_config, description=""): @abstractmethod
def list_collections(self):
pass pass
@abstractmethod
def collection_stats(self, collection_name):
pass
@abstractmethod
def upsert(self, collection_name, data_dicts, **kwargs):
pass
@abstractmethod
def batch_insert(self, collection_name, data_dicts, **kwargs):
pass
@abstractmethod
def delete(self, collection_name, pks, **kwargs):
pass
@abstractmethod
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None, **kwargs):
pass
@abstractmethod
def batch_query(self, collection_name, vectors, expr=None, limit=10, output_fields=None, **kwargs):
pass
def get_db_type(self, dtype: str):
return self.dbtypes.get(dtype) if hasattr(self, "dbtypes") else None


@ -1,5 +1,4 @@
from traceback import format_exc from traceback import format_exc
from functools import partial
from ahserver.serverenv import ServerEnv from ahserver.serverenv import ServerEnv
from appPublic.worker import awaitify from appPublic.worker import awaitify
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
@ -7,82 +6,110 @@ from appPublic.log import debug, exception
from appPublic.jsonConfig import getConfig from appPublic.jsonConfig import getConfig
from .milvus import MilvusManager from .milvus import MilvusManager
def ownerparting(data): def ownerparting(data):
if data.get('ownerid'): if data.get("ownerid"):
return data.get('ownerid') return data.get("ownerid")
return '_default' return "_default"
async def create_collection(request, params_kw, *args, **kwargs): async def create_collection(request, params_kw, *args, **kwargs):
colname = params_kw.colname colname = params_kw.colname
fields = params_kw.fields fields = params_kw.fields
description = params_kw.description or "" description = params_kw.description or ""
metric = params_kw.metric
index_type = params_kw.index_type or "HNSW"
env = request._run_ns env = request._run_ns
f = awaitify(env.vdb.create_collection) f = awaitify(env.vdb.create_collection)
try: try:
r = await f(colname, fields, description) await f(colname, fields, description, metric=metric, index_type=index_type)
return { return {"status": "SUCCEEDED"}
"status":"SUCCEEDED",
}
except Exception as e: except Exception as e:
exception(f'{e}, {format_exc()}') exception(f"{e}, {format_exc()}")
return { return {"status": "FAILED", "error": str(e)}
"status":"FAILED",
"error": f"{e}"
}
async def drop_collection(request, params_kw, *args, **kwargs): async def drop_collection(request, params_kw, *args, **kwargs):
colname = params_kw.colname colname = params_kw.colname
env = request._run_ns env = request._run_ns
f = awaitify(env.vdb.drop_collection) f = awaitify(env.vdb.drop_collection)
try: try:
r = await f(colname) await f(colname)
return { return {"status": "SUCCEEDED"}
"status": "SUCCEEDED"
}
except Exception as e: except Exception as e:
exception(f"{e}, {format_exc()}") exception(f"{e}, {format_exc()}")
return { return {"status": "FAILED", "error": str(e)}
"status": "FAILED",
"error": f"{e}"
} async def list_collections(request, params_kw, *args, **kwargs):
env = request._run_ns
f = awaitify(env.vdb.list_collections)
try:
result = await f()
return {"status": "SUCCEEDED", "collections": result}
except Exception as e:
exception(f"{e}, {format_exc()}")
return {"status": "FAILED", "error": str(e)}
async def collection_stats(request, params_kw, *args, **kwargs):
colname = params_kw.colname
env = request._run_ns
f = awaitify(env.vdb.collection_stats)
try:
result = await f(colname)
if result is None:
return {"status": "FAILED", "error": f"Collection '{colname}' not found"}
return {"status": "SUCCEEDED", "data": result}
except Exception as e:
exception(f"{e}, {format_exc()}")
return {"status": "FAILED", "error": str(e)}
async def upsert(request, params_kw, *args, **kwargs): async def upsert(request, params_kw, *args, **kwargs):
colname = params_kw.colname colname = params_kw.colname
data = params_kw.data data = params_kw.data
if not isinstance(data, list): if not isinstance(data, list):
data = [data] data = [data]
env = request._run_ns env = request._run_ns
f = awaitify(env.vdb.upsert) f = awaitify(env.vdb.upsert)
try: try:
r = await f(colname, data) await f(colname, data)
return { return {"status": "SUCCEEDED"}
"status": "SUCCEEDED"
}
except Exception as e: except Exception as e:
exception(f"{e}, {format_exc()}") exception(f"{e}, {format_exc()}")
return { return {"status": "FAILED", "error": str(e)}
"status": "FAILED",
"error": f"{e}"
} async def batch_insert(request, params_kw, *args, **kwargs):
colname = params_kw.colname
data = params_kw.data
flush = params_kw.flush if params_kw.flush is not None else True
partition_name = params_kw.partition_name
if not isinstance(data, list):
return {"status": "FAILED", "error": "data must be a list"}
env = request._run_ns
f = awaitify(env.vdb.batch_insert)
try:
await f(colname, data, flush=flush, partition_name=partition_name)
return {"status": "SUCCEEDED", "count": len(data)}
except Exception as e:
exception(f"{e}, {format_exc()}")
return {"status": "FAILED", "error": str(e)}
async def delete(request, params_kw, *args, **kwargs): async def delete(request, params_kw, *args, **kwargs):
colname = params_kw.colname colname = params_kw.colname
pks = params_kw.pks pks = params_kw.pks
env = request._run_ns env = request._run_ns
f = awaitify(env.vdb.delete) f = awaitify(env.vdb.delete)
try: try:
r = await f(colname, pks) await f(colname, pks)
return { return {"status": "SUCCEEDED"}
"status": "SUCCEEDED"
}
except Exception as e: except Exception as e:
exception(f"{e}, {format_exc()}") exception(f"{e}, {format_exc()}")
return { return {"status": "FAILED", "error": str(e)}
"status": "FAILED",
"error": f"{e}"
}
async def query(request, params_kw, *args, **kwargs): async def query(request, params_kw, *args, **kwargs):
colname = params_kw.colname colname = params_kw.colname
@ -91,34 +118,56 @@ async def query(request, params_kw, *args, **kwargs):
pagerows = params_kw.pagerows or 80 pagerows = params_kw.pagerows or 80
page = params_kw.page or 1 page = params_kw.page or 1
output_fields = params_kw.output_fields output_fields = params_kw.output_fields
metric = params_kw.metric
env = request._run_ns env = request._run_ns
f1 = awaitify(env.vdb.drop_collection) f = awaitify(env.vdb.query)
try: try:
f = partial(f1, colname, vector=vector, expr=expr, pagerows=pagerows, page=page, output_fields=output_fields) result = await f(colname, vector=vector, expr=expr, pagerows=pagerows,
r = await f() page=page, output_fields=output_fields, metric=metric)
return { return {"status": "SUCCEEDED", "data": result}
"status": "SUCCEEDED",
"data": r
}
except Exception as e: except Exception as e:
exception(f"{e}, {format_exc()}") exception(f"{e}, {format_exc()}")
return { return {"status": "FAILED", "error": str(e)}
"status": "FAILED",
"error": f"{e}"
} async def batch_query(request, params_kw, *args, **kwargs):
colname = params_kw.colname
vectors = params_kw.vectors
expr = params_kw.expr
limit = params_kw.limit or 10
output_fields = params_kw.output_fields
metric = params_kw.metric
if not isinstance(vectors, list) or not vectors:
return {"status": "FAILED", "error": "vectors must be a non-empty list"}
env = request._run_ns
f = awaitify(env.vdb.batch_query)
try:
result = await f(colname, vectors, expr=expr, limit=limit,
output_fields=output_fields, metric=metric)
return {"status": "SUCCEEDED", "data": result}
except Exception as e:
exception(f"{e}, {format_exc()}")
return {"status": "FAILED", "error": str(e)}
def load_vdb(): def load_vdb():
config = getConfig() config = getConfig()
vdb = None vdb = None
vdb_type = config.vdb_type vdb_type = config.vdb_type
if vdb_type == 'milvus': if vdb_type == "milvus":
vdb = MilvusManager(partitionize=ownerparting) vdb = MilvusManager(partitionize=ownerparting)
env = ServerEnv() env = ServerEnv()
env.vdb = vdb env.vdb = vdb
rf = RegisterFunction() rf = RegisterFunction()
rf.register('create_collection', create_collection) rf.register("create_collection", create_collection)
rf.register('drop_collection', drop_collection) rf.register("drop_collection", drop_collection)
rf.register('upsert', upsert) rf.register("list_collections", list_collections)
rf.register('delete', delete) rf.register("collection_stats", collection_stats)
rf.register('query', query) rf.register("upsert", upsert)
rf.register("batch_insert", batch_insert)
rf.register("delete", delete)
rf.register("query", query)
rf.register("batch_query", batch_query)


@ -5,22 +5,30 @@ from threading import Lock
from pymilvus import ( from pymilvus import (
connections, FieldSchema, CollectionSchema, connections, FieldSchema, CollectionSchema,
DataType, Collection, utility, Partition DataType, Collection, utility
) )
from .basevdb import BaseVDB from .basevdb import BaseVDB
class MilvusManager(BaseVDB): class MilvusManager(BaseVDB):
_instance = None _instance = None
_lock = Lock() _lock = Lock()
dbtypes = { dbtypes = {
"str": DataType.VARCHAR, "str": DataType.VARCHAR,
"int": DataType.INT32, "int": DataType.INT32,
"int64": DataType.INT64,
"bool": DataType.BOOL, "bool": DataType.BOOL,
"float": DataType.FLOAT, "float": DataType.FLOAT,
"fvector": DataType.FLOAT_VECTOR, "fvector": DataType.FLOAT_VECTOR,
"bvector": DataType.BINARY_VECTOR, "bvector": DataType.BINARY_VECTOR,
"json": DataType.JSON "json": DataType.JSON,
} }
metric_map = {
"L2": "L2", "l2": "L2",
"IP": "IP", "ip": "IP",
"COSINE": "COSINE", "cosine": "COSINE",
}
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
@ -34,194 +42,271 @@ class MilvusManager(BaseVDB):
return return
try: try:
config = getConfig() config = getConfig()
self.alias = "default"
self.mode = getattr(config, "milvus_mode", "lite")
if self.mode == "standalone":
self.host = config.milvus_host
self.token = getattr(config, "milvus_token", "")
self.db_name = getattr(config, "milvus_dbname", "default")
else:
self.db_path = config.milvus_db self.db_path = config.milvus_db
debug(f"dbpath: {self.db_path}") self.default_metric = getattr(config, "milvus_metric", "COSINE")
except KeyError as e:
error(f"配置文件缺少必要字段: {str(e)}")
raise RuntimeError(f"配置文件缺少必要字段: {str(e)}")
self._initialize_connection() self._initialize_connection()
self._initialized = True self._initialized = True
info(f"MilvusDBConnection initialized with db_path: {self.db_path}") info(f"MilvusManager initialized: mode={self.mode}, alias={self.alias}")
except Exception as e:
error(f"MilvusManager init failed: {e}")
raise RuntimeError(f"MilvusManager init failed: {e}")
def _initialize_connection(self): def _initialize_connection(self):
"""Initialize the Milvus connection, ensuring a single connection"""
try: try:
db_dir = os.path.dirname(self.db_path) if self.mode == "standalone":
debug(f"db_dir: {db_dir}") parts = self.host.split(":")
if not os.path.exists(db_dir): host = parts[0]
os.makedirs(db_dir, exist_ok=True) port = parts[1] if len(parts) > 1 else "19530"
debug(f"创建 Milvus 目录: {db_dir}") uri = f"http://{host}:{port}"
if not os.access(db_dir, os.W_OK): if not connections.has_connection(self.alias):
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") kw = {"uri": uri, "alias": self.alias}
debug(f"不可写") if self.token:
if not connections.has_connection(self.db_path): kw["token"] = self.token
connections.connect(self.db_path, uri=self.db_path) if self.db_name and self.db_name != "default":
debug(f"已连接到 Milvus Lite路径: {self.db_path}") kw["db_name"] = self.db_name
connections.connect(**kw)
info(f"Connected to Milvus Standalone: {uri}")
else: else:
debug("已存在 Milvus 连接,跳过重复连接") debug("Milvus Standalone connection already exists")
else:
db_dir = os.path.dirname(self.db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
if not connections.has_connection(self.alias):
connections.connect(alias=self.alias, uri=self.db_path)
info(f"Connected to Milvus Lite: {self.db_path}")
else:
debug("Milvus Lite connection already exists")
except Exception as e: except Exception as e:
error(f"连接 Milvus 失败: {str(e)}") error(f"Milvus connect failed: {e}")
raise RuntimeError(f"连接 Milvus 失败: {str(e)}") raise RuntimeError(f"Milvus connect failed: {e}")
# --- Collection management --- # === Collection Management ===
def create_collection(self, collection_name, fields_config, description=""):
""" def list_collections(self):
Open or create a collection: return the existing object if it exists, otherwise create it from the config return utility.list_collections(using=self.alias)
:param fields_config: format: [{"name": "id", "type": DataType.INT64, "is_primary": True}, ...]
""" def collection_stats(self, collection_name):
if not utility.has_collection(collection_name, using=self.alias):
return None
col = Collection(collection_name, using=self.alias)
stats = {"name": collection_name, "row_count": col.num_entities}
stats["fields"] = [
{"name": f.name, "type": str(f.dtype), "is_primary": f.is_primary}
for f in col.schema.fields
]
try:
stats["indexes"] = [
{"field": idx.field_name, "type": idx.params.get("index_type", "?")}
for idx in col.indexes
]
except Exception:
stats["indexes"] = []
return stats
def create_collection(self, collection_name, fields_config, description="",
metric=None, index_type="HNSW", index_params_extra=None):
if utility.has_collection(collection_name, using=self.alias): if utility.has_collection(collection_name, using=self.alias):
# print(f"📦 集合 '{collection_name}' 已存在,直接加载。")
return Collection(collection_name, using=self.alias) return Collection(collection_name, using=self.alias)
metric = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
fields = [] fields = []
for cfg in fields_config: for cfg in fields_config:
field = FieldSchema( field = FieldSchema(
name=cfg['name'], name=cfg["name"],
dtype=self.get_db_type(cfg['type']), dtype=self.get_db_type(cfg["type"]),
is_primary=cfg.get('is_primary', False), is_primary=cfg.get("is_primary", False),
auto_id=cfg.get('auto_id', False), auto_id=cfg.get("auto_id", False),
dim=cfg.get('dim'), dim=cfg.get("dim"),
max_length=cfg.get('max_length') max_length=cfg.get("max_length"),
) )
fields.append(field) fields.append(field)
schema = CollectionSchema(fields, description=description) schema = CollectionSchema(fields, description=description)
collection = Collection(name=collection_name, schema=schema, using=self.alias) collection = Collection(name=collection_name, schema=schema, using=self.alias)
# Automatically create an index for vector fields (HNSW recommended for intranet deployments)
for cfg in fields_config: for cfg in fields_config:
dbtype = self.get_db_type(cfg['type']) dbtype = self.get_db_type(cfg["type"])
if dbtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: if dbtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
index_params = { idx_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 64}}
"metric_type": "L2", if index_params_extra:
"index_type": "HNSW", idx_params["params"].update(index_params_extra)
"params": {"M": 16, "efConstruction": 64} collection.create_index(field_name=cfg["name"], index_params=idx_params)
}
collection.create_index(field_name=cfg['name'], index_params=index_params)
# print(f"🚀 集合 '{collection_name}' 创建并初始化索引完成。") info(f"Collection '{collection_name}' created, metric={metric}, index={index_type}")
return collection return collection
def drop_collection(self, collection_name): def drop_collection(self, collection_name):
"""Physically delete the collection"""
if utility.has_collection(collection_name, using=self.alias): if utility.has_collection(collection_name, using=self.alias):
utility.drop_collection(collection_name, using=self.alias) utility.drop_collection(collection_name, using=self.alias)
print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。") info(f"Collection '{collection_name}' dropped")
# === Memory ===
# --- Memory and partition optimization ---
def load_segmented(self, collection_name, partition_names=None): def load_segmented(self, collection_name, partition_names=None):
"""Segmented loading: load only specific partitions into memory to protect the intranet server's memory"""
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
col.load(partition_names=partition_names) col.load(partition_names=partition_names)
print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。") debug(f"Loaded '{collection_name}' partitions={partition_names}")
def release_memory(self, collection_name): def release_memory(self, collection_name):
"""Release memory: call before GPU-memory-intensive tasks such as VBench inference"""
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
col.release() col.release()
print(f"♻️ 集合 '{collection_name}' 已释放内存占用。") debug(f"Released '{collection_name}' from memory")
# === DML ===
# --- Data operations (DML) ---
def _validate_and_format(self, col, data_dicts): def _validate_and_format(self, col, data_dicts):
"""Private helper: validate dict completeness and convert to columnar data"""
schema = col.schema schema = col.schema
# Find the field names that must be filled (excluding auto-generated IDs) required = [f.name for f in schema.fields if not f.auto_id]
required_fields = [f.name for f in schema.fields if not f.auto_id] vec_info = {f.name: f.params.get("dim", 0) for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR}
columnar = {name: [] for name in required}
# Extract vector field info for dimension validation
vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR}
columnar_data = {name: [] for name in required_fields}
for i, entry in enumerate(data_dicts): for i, entry in enumerate(data_dicts):
for field in required_fields: for field in required:
if field not in entry: if field not in entry:
raise ValueError(f"记录 {i} 缺失必填字段: {field}") raise ValueError(f"Record {i} missing field: {field}")
# Vector dimension check if field in vec_info and entry[field] is not None:
if field in vec_info and len(entry[field]) != vec_info[field]: if len(entry[field]) != vec_info[field]:
raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}") raise ValueError(f"Record {i} vector dim mismatch: expected {vec_info[field]}, got {len(entry[field])}")
columnar_data[field].append(entry[field]) columnar[field].append(entry[field])
return [columnar[f] for f in required]
return [columnar_data[f] for f in required_fields] def upsert(self, collection_name, data_dicts, **kwargs):
if not isinstance(data_dicts, list):
def upsert(self, collection_name, data_dicts): data_dicts = [data_dicts]
"""Generic upsert: accepts a list of dicts and detects primary keys to update automatically""" pks = [item.get("id") for item in data_dicts if "id" in item]
pks = [item['id'] for item in data_dicts] if pks:
self.delete(collection_name, pks) self.delete(collection_name, pks)
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
if self.partitionize is None: if self.partitionize is None:
formatted_data = self._validate_and_format(col, data_dicts) formatted = self._validate_and_format(col, data_dicts)
res = col.upsert(formatted_data) res = col.upsert(formatted)
col.flush() # force a flush to disk in the intranet environment to prevent data loss col.flush()
return res return res
grouped_data = {} grouped = {}
for entry in data_dicts: for entry in data_dicts:
p_name = str(partition_func(entry)) p_name = str(self.partitionize(entry))
if p_name not in grouped_data: grouped.setdefault(p_name, []).append(entry)
grouped_data[p_name] = [] for p_name, p_data in grouped.items():
grouped_data[p_name].append(entry)
for p_name, p_data in grouped_data.items():
# Maintain partitions automatically
if p_name != "_default" and not col.has_partition(p_name): if p_name != "_default" and not col.has_partition(p_name):
col.create_partition(p_name) col.create_partition(p_name)
print(f"🚩 自动创建新分区: {p_name}")
# Format the data
formatted = self._validate_and_format(col, p_data) formatted = self._validate_and_format(col, p_data)
# Perform the upsert
col.upsert(formatted, partition_name=p_name) col.upsert(formatted, partition_name=p_name)
col.flush() col.flush()
def delete(self, collection_name, pks, partition_name=None): def batch_insert(self, collection_name, data_dicts, flush=True, partition_name=None, **kwargs):
"""Delete records with the given primary keys""" if not isinstance(data_dicts, list) or not data_dicts:
return None
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}" formatted = self._validate_and_format(col, data_dicts)
col.delete(expr, partition_name=partition_name) res = col.insert(formatted, partition_name=partition_name)
print(f"✂️ 已删除主键为 {pks} 的记录。") if flush:
col.flush()
return res
# --- Advanced combined retrieval --- def delete(self, collection_name, pks, partition_name=None, **kwargs):
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None): col = Collection(collection_name, using=self.alias)
""" if isinstance(pks, list):
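Combined query interface: supports approximate vector search + scalar filtering + pagination if pks and isinstance(pks[0], str):
:param vector: target vector; if None, falls back to a pure scalar query pk_str = ", ".join(f'"{p}"' for p in pks)
:param expr: filter condition, e.g. "score > 0.8 and lang == 'cmn'" expr = f"id in [{pk_str}]"
:param offset: pagination offset else: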
""" expr = f"id in {pks}"
else:
expr = f'id == "{pks}"' if isinstance(pks, str) else f"id == {pks}"
col.delete(expr, partition_name=partition_name)
col.flush()
# === Search ===
def _ensure_loaded(self, collection_name):
try:
seg_info = utility.get_query_segment_info(collection_name, using=self.alias)
if not seg_info:
self.load_segmented(collection_name)
except Exception:
try:
self.load_segmented(collection_name)
except Exception:
pass
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1,
output_fields=None, metric=None, **kwargs):
offset = (page - 1) * pagerows offset = (page - 1) * pagerows
limit = pagerows limit = pagerows
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
# Ensure the data is loaded self._ensure_loaded(collection_name)
if utility.get_query_segment_info(collection_name, using=self.alias) == []:
self.load_segmented(collection_name)
if output_fields is None: if output_fields is None:
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
if vector is not None: if vector is not None:
# Find the vector field name automatically
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset} m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
search_params = {"metric_type": m, "params": {"ef": 64}, "offset": offset}
return { hits = col.search(
"total": -1,
"page": page,
"pagerows": pagerows,
"rows": col.search(
data=[vector], data=[vector],
anns_field=vec_field, anns_field=vec_field,
param=search_params, param=search_params,
limit=limit, limit=limit,
offset=offset,
expr=expr, expr=expr,
output_fields=output_fields output_fields=output_fields,
) )
} rows = []
for hit in hits[0]:
row = {"id": hit.id, "score": hit.score}
for field in output_fields:
if field not in row:
try:
row[field] = hit.entity.get(field)
except Exception:
pass
rows.append(row)
return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows}
else: else:
# Pure scalar query rows = col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields)
return { return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows}
"total": -1,
"page": page, def batch_query(self, collection_name, vectors, expr=None, limit=10,
"pagerows": pagerows, output_fields=None, metric=None, **kwargs):
"rows": col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields) if not vectors:
} return []
col = Collection(collection_name, using=self.alias)
self._ensure_loaded(collection_name)
if output_fields is None:
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
search_params = {"metric_type": m, "params": {"ef": 64}}
results = col.search(
data=vectors,
anns_field=vec_field,
param=search_params,
limit=limit,
expr=expr,
output_fields=output_fields,
)
all_rows = []
for hits in results:
rows = []
for hit in hits:
row = {"id": hit.id, "score": hit.score}
for field in output_fields:
if field not in row:
try:
row[field] = hit.entity.get(field)
except Exception:
pass
rows.append(row)
all_rows.append(rows)
return all_rows