feat: v2.0 - standalone support, COSINE metric, batch ops, fix query bug, new APIs
This commit is contained in:
parent
cd64579881
commit
1d9e2f3efc
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
ah.pid
|
||||
nohup.out
|
||||
__pycache__/
|
||||
*.pyc
|
||||
db/
|
||||
logs/
|
||||
files/
|
||||
wwwroot/
|
||||
327
README.md
327
README.md
@ -1,208 +1,137 @@
|
||||
# vdb
|
||||
# VDB - Vector Database Service
|
||||
|
||||
## 安装
|
||||
基于 ahserver + pymilvus 的向量数据库 HTTP 服务,支持 Milvus Lite 和 Milvus Standalone 两种模式。
|
||||
|
||||
执行
|
||||
## 特性
|
||||
|
||||
- 支持 Milvus Lite(本地文件)和 Milvus Standalone(Docker 部署,千万级)
|
||||
- COSINE / IP / L2 三种距离度量
|
||||
- HNSW / IVF_FLAT / IVF_PQ 索引类型可配
|
||||
- 批量插入(batch_insert)和批量搜索(batch_query)
|
||||
- 向量搜索 + 标量过滤组合查询
|
||||
- 集合管理(list/stats/drop)
|
||||
|
||||
## 部署
|
||||
|
||||
```bash
|
||||
cd /data/ymq/vdb
|
||||
bash build.sh deploy # 启动
|
||||
bash build.sh stop # 停止
|
||||
bash build.sh status # 状态
|
||||
```
|
||||
git clone https://git.opencomputing.cn/yumoqing/vdb.git
|
||||
cd vdb
|
||||
bash ./build.sh
|
||||
```
|
||||
执行完成后:
|
||||
1 添加了一个vdb.service, 重启后会自动启动向量数据库服务
|
||||
2 使用sudo systemctl XXX vdb.service开启动,停止, 重启服务
|
||||
3 app/vdbapp.py 正在执行
|
||||
|
||||
### 切换到 Milvus Standalone(千万级)
|
||||
|
||||
编辑 `conf/config.json`:
|
||||
```json
|
||||
{
|
||||
"milvus_mode": "standalone",
|
||||
"milvus_host": "127.0.0.1:19530"
|
||||
}
|
||||
```
|
||||
|
||||
部署 Milvus Standalone(Docker):
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
services:
|
||||
etcd:
|
||||
image: quay.io/coreos/etcd:v3.5.18
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
volumes: ["etcd:/etcd"]
|
||||
command: etcd --advertise-client-urls=http://127.0.0.1:2379 --listen-client-urls=http://0.0.0.0:2379 --data-dir=/etcd
|
||||
|
||||
minio:
|
||||
image: minio/minio:latest
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes: ["minio:/minio_data"]
|
||||
command: minio server /minio_data --console-address ":9001"
|
||||
|
||||
milvus:
|
||||
image: milvusdb/milvus:v2.4-latest
|
||||
depends_on: [etcd, minio]
|
||||
ports: ["19530:19530", "9091:9091"]
|
||||
volumes: ["milvus:/var/lib/milvus"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
```
|
||||
|
||||
## API 接口
|
||||
|
||||
### GET /v1/listcollections
|
||||
列出所有集合。
|
||||
|
||||
### POST /v1/createcollection
|
||||
创建集合。
|
||||
```json
|
||||
{
|
||||
"colname": "entities",
|
||||
"fields": [
|
||||
{"name": "id", "type": "str", "max_length": 32, "is_primary": true},
|
||||
{"name": "embedding", "type": "fvector", "dim": 1024}
|
||||
],
|
||||
"metric": "COSINE",
|
||||
"index_type": "HNSW"
|
||||
}
|
||||
```
|
||||
|
||||
### POST /v1/collectionstats
|
||||
获取集合统计信息(行数、字段、索引)。
|
||||
|
||||
### POST /v1/upsert
|
||||
插入或更新(自动删除旧记录再插入,适合少量)。
|
||||
|
||||
### POST /v1/batchinsert
|
||||
批量插入(不逐条flush,适合大批量入库)。
|
||||
```json
|
||||
{"colname": "entities", "data": [...], "flush": true}
|
||||
```
|
||||
|
||||
### POST /v1/query
|
||||
向量搜索 + 标量过滤 + 分页。
|
||||
```json
|
||||
{
|
||||
"colname": "entities",
|
||||
"vector": [0.1, 0.2, ...],
|
||||
"expr": "type == \"animal\"",
|
||||
"pagerows": 20,
|
||||
"page": 1,
|
||||
"metric": "COSINE"
|
||||
}
|
||||
```
|
||||
|
||||
### POST /v1/batchquery
|
||||
批量向量搜索(多个查询向量同时搜索)。
|
||||
|
||||
### POST /v1/delete
|
||||
按主键删除。
|
||||
|
||||
### POST /v1/dropcollection
|
||||
删除整个集合。
|
||||
|
||||
## 字段类型
|
||||
以下字段类型可用在集合的话数据类型中
|
||||
|
||||
* "str": 可变长字符串, 需要max_length属性
|
||||
* "int": 整数类型
|
||||
* "bool": 逻辑值类型
|
||||
* "float": 浮点数
|
||||
* "fvector": 浮点数向量,常规向量都用此类型,
|
||||
* "bvector": 二进制向量 二分向量使用,
|
||||
* "json": json格式的数据
|
||||
| 类型 | 说明 | 备注 |
|
||||
|------|------|------|
|
||||
| str | 字符串 | 需要 max_length |
|
||||
| int | 32位整数 | |
|
||||
| int64 | 64位整数 | |
|
||||
| bool | 布尔值 | |
|
||||
| float | 浮点数 | |
|
||||
| fvector | 浮点向量 | 需要 dim |
|
||||
| bvector | 二进制向量 | 需要 dim |
|
||||
| json | JSON | 灵活结构 |
|
||||
|
||||
一些例子
|
||||
```
|
||||
[
|
||||
# 1. 主键字段 (必选): str 类型,非自增(手动指定ID以便与业务系统关联)
|
||||
{
|
||||
"name": "id",
|
||||
"type": "str",
|
||||
"max_length": 32
|
||||
"is_primary": True,
|
||||
"auto_id": False
|
||||
},
|
||||
## 配置 (conf/config.json)
|
||||
|
||||
# 2. 向量字段 (核心): 存储 CLIP 提取的特征,768 维
|
||||
{
|
||||
"name": "video_embedding",
|
||||
"type": "fvector",
|
||||
"dim": 768
|
||||
},
|
||||
|
||||
# 3. 变长字符串: 存储视频存储路径,需指定最大长度
|
||||
{
|
||||
"name": "file_path",
|
||||
"type": "str",
|
||||
"max_length": 512
|
||||
},
|
||||
|
||||
# 4. 浮点数: 存储评估得分 (如 VBench 的综合分数)
|
||||
{
|
||||
"name": "quality_score",
|
||||
"type": "float"
|
||||
},
|
||||
|
||||
# 5. 布尔值: 标记是否已完成人工复核
|
||||
{
|
||||
"name": "is_reviewed",
|
||||
"type": "bool"
|
||||
},
|
||||
|
||||
# 6. 整数: 存储视频的时长(秒)
|
||||
{
|
||||
"name": "duration_sec",
|
||||
"type": "int"
|
||||
},
|
||||
|
||||
# 7. JSON 字段: 存储非结构化的元数据(如 VBench 的 16 个子维度细节)
|
||||
# Milvus 2.4+ 支持 JSON 动态解析查询
|
||||
{
|
||||
"name": "meta_data",
|
||||
"type": "json"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## API
|
||||
对外提供http的接口,用nginx做反向代理,提供安全,并在nginx.conf中提供客户端ip过滤,只有登记过的客户端才能使用
|
||||
|
||||
### 创建集合
|
||||
|
||||
path:/v1/createcollection
|
||||
method:POST
|
||||
headers:{
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:{
|
||||
colname: 集合名字,必须提供
|
||||
fields: 集合字段集,请参照字段类型提供
|
||||
description:可选, 集合描述
|
||||
}
|
||||
|
||||
成功返回
|
||||
{
|
||||
"status":"SUCCEEDED"
|
||||
}
|
||||
|
||||
失败返回
|
||||
{
|
||||
"status":"FAILED",
|
||||
"error": 错误信息
|
||||
}
|
||||
|
||||
### 删除集合
|
||||
|
||||
path:/v1/dropcollection
|
||||
method:POST
|
||||
headers:{
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:{
|
||||
colname: 集合名字,必须提供
|
||||
}
|
||||
|
||||
成功返回
|
||||
{
|
||||
"status":"SUCCEEDED"
|
||||
}
|
||||
|
||||
失败返回
|
||||
{
|
||||
"status":"FAILED",
|
||||
"error": 错误信息
|
||||
}
|
||||
|
||||
### 向集合插入一到多条记录
|
||||
|
||||
path:/v1/upsert
|
||||
method:POST
|
||||
headers:{
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:{
|
||||
colname: 集合名字,必须提供
|
||||
data: 数据字典或数据字典数组
|
||||
}
|
||||
|
||||
成功返回
|
||||
{
|
||||
"status":"SUCCEEDED"
|
||||
}
|
||||
|
||||
失败返回
|
||||
{
|
||||
"status":"FAILED",
|
||||
"error": 错误信息
|
||||
}
|
||||
|
||||
### 删除集合一到多条记录
|
||||
|
||||
path:/v1/delete
|
||||
method:POST
|
||||
headers:{
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:{
|
||||
colname: 集合名字,必须提供
|
||||
pks: 主键或主键数组
|
||||
description:可选, 集合描述
|
||||
}
|
||||
|
||||
成功返回
|
||||
{
|
||||
"status":"SUCCEEDED"
|
||||
}
|
||||
|
||||
失败返回
|
||||
{
|
||||
"status":"FAILED",
|
||||
"error": 错误信息
|
||||
}
|
||||
|
||||
### 查询集合数据
|
||||
|
||||
path:/v1/createcollection
|
||||
method:POST
|
||||
headers:{
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:{
|
||||
colname: 集合名字,必须提供
|
||||
fields: 集合字段集,请参照字段类型提供
|
||||
description:可选, 集合描述
|
||||
}
|
||||
|
||||
成功返回
|
||||
{
|
||||
"status":"SUCCEEDED"
|
||||
"data": 返回数据
|
||||
}
|
||||
|
||||
失败返回
|
||||
{
|
||||
"status":"FAILED",
|
||||
"error": 错误信息
|
||||
}
|
||||
|
||||
返回数据有如下结构
|
||||
{
|
||||
"total": -1, # 不知道总共多少条符合条件的数据
|
||||
"page": 当前页(起始1)
|
||||
"pagerows": 每页记录数
|
||||
"rows": 记录数据
|
||||
}
|
||||
| 字段 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| milvus_mode | lite | lite 或 standalone |
|
||||
| milvus_db | $[workdir]$/db/milvus.db | lite模式文件路径 |
|
||||
| milvus_host | 127.0.0.1:19530 | standalone模式地址 |
|
||||
| milvus_metric | COSINE | 默认距离度量 |
|
||||
| client_max_size | 100MB | 请求体大小上限 |
|
||||
|
||||
107
build.sh
107
build.sh
@ -1,45 +1,68 @@
|
||||
#!/usr/bin/env bash
|
||||
# clone from git@git.opencomputing.cn/yumoqing/vdb
|
||||
# git clone https://git.opencomputing.cn/yumoqing/vdb
|
||||
cdir=$(pwd)
|
||||
uname=$(id -un)
|
||||
gname=$(id -gn)
|
||||
sudo apt install redis-server
|
||||
python3 -m venv py3
|
||||
source py3/bin/activate
|
||||
pip install .
|
||||
mkdir $cdir/logs
|
||||
cd $cdir
|
||||
cat > $cdir/vdb.service <<EOF
|
||||
[Unit]
|
||||
Wants=systemd-networkd.service
|
||||
# vdb - Vector Database Service
|
||||
set -e
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
[Service]
|
||||
User=$uname
|
||||
Group=$gname
|
||||
Type=forking
|
||||
WorkingDirectory=$cdir
|
||||
ExecStart=$cdir/start.sh
|
||||
ExecStop=$cdir/stop.sh
|
||||
StandardOutput=append:/var/log/vdb/vdb.log
|
||||
StandardError=append:/var/log/vdb/vdb.log
|
||||
SyslogIdentifier=vdb
|
||||
SERVICE_NAME="vdb"
|
||||
PORT=8886
|
||||
PY=/data/ymq/wan22-service/py3/bin/python
|
||||
action="${1:-status}"
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
cat > $cdir/start.sh <<EOF
|
||||
#!/usr/bin/bash
|
||||
cd $cdir
|
||||
$cdir/py3/bin/python $cdir/app/vdbapp.py -p 8887 -w $cdir &
|
||||
exit 0
|
||||
EOF
|
||||
cat > $cdir/stop.sh <<EOF
|
||||
PID=\$(lsof -t -i:8887)
|
||||
kill -9 \$PID
|
||||
EOF
|
||||
chmod +x $cdir/start.sh stop.sh
|
||||
sudo mkdir /var/log/vdb
|
||||
sudo cp vdb.service /etc/systemd/system
|
||||
sudo systemctl enable vdb
|
||||
sudo systemctl restart vdb
|
||||
case "$action" in
|
||||
deploy|update)
|
||||
echo "=== $SERVICE_NAME Deploy ==="
|
||||
if [ -f ah.pid ] && kill -0 $(cat ah.pid) 2>/dev/null; then
|
||||
kill $(cat ah.pid) 2>/dev/null || true
|
||||
sleep 2
|
||||
fi
|
||||
if [ -d .git ]; then
|
||||
git pull origin master 2>/dev/null || git pull origin main 2>/dev/null || true
|
||||
fi
|
||||
mkdir -p logs db
|
||||
export PYTHONPATH="$(pwd)"
|
||||
nohup $PY app/vdbapp.py > nohup.out 2>&1 &
|
||||
echo $! > ah.pid
|
||||
echo "Started PID $(cat ah.pid) on port $PORT"
|
||||
sleep 3
|
||||
if curl -s http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then
|
||||
echo "Service healthy"
|
||||
else
|
||||
echo "WARNING: not responding yet, check nohup.out"
|
||||
tail -20 nohup.out
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if [ -f ah.pid ]; then
|
||||
kill $(cat ah.pid) 2>/dev/null || true
|
||||
rm -f ah.pid
|
||||
echo "Stopped"
|
||||
else
|
||||
echo "Not running"
|
||||
fi
|
||||
;;
|
||||
start)
|
||||
mkdir -p logs db
|
||||
export PYTHONPATH="$(pwd)"
|
||||
nohup $PY app/vdbapp.py > nohup.out 2>&1 &
|
||||
echo $! > ah.pid
|
||||
echo "Started PID $(cat ah.pid)"
|
||||
;;
|
||||
status)
|
||||
echo "=== $SERVICE_NAME Status ==="
|
||||
if [ -f ah.pid ] && kill -0 $(cat ah.pid) 2>/dev/null; then
|
||||
echo "Process: running (PID $(cat ah.pid))"
|
||||
else
|
||||
echo "Process: not running"
|
||||
fi
|
||||
echo "Port: $PORT"
|
||||
if curl -s --max-time 3 http://localhost:$PORT/v1/listcollections > /dev/null 2>&1; then
|
||||
echo "HTTP: OK"
|
||||
else
|
||||
echo "HTTP: not responding"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {deploy|update|stop|start|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
{
|
||||
"vdb_type": "milvus",
|
||||
"filesroot": "$[workdir]$/files",
|
||||
"vdb_type": "milvus",
|
||||
"milvus_mode": "lite",
|
||||
"milvus_db": "$[workdir]$/db/milvus.db",
|
||||
"milvus_host": "127.0.0.1:19530",
|
||||
"milvus_token": "",
|
||||
"milvus_dbname": "default",
|
||||
"milvus_metric": "COSINE",
|
||||
"filesroot": "$[workdir]$/files",
|
||||
"logger": {
|
||||
"name": "vdb",
|
||||
"levelname": "info",
|
||||
@ -11,13 +16,13 @@
|
||||
"paths": [
|
||||
["$[workdir]$/wwwroot", ""]
|
||||
],
|
||||
"client_max_size": 10000,
|
||||
"client_max_size": 104857600,
|
||||
"host": "0.0.0.0",
|
||||
"port": 8886,
|
||||
"coding": "utf-8",
|
||||
"indexes": [
|
||||
"index.html",
|
||||
"index.dspy",
|
||||
"index.dspy",
|
||||
"index.ui"
|
||||
],
|
||||
"startswiths": [
|
||||
@ -34,16 +39,32 @@
|
||||
"registerfunction": "drop_collection"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/iupsert",
|
||||
"leading": "/v1/listcollections",
|
||||
"registerfunction": "list_collections"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/collectionstats",
|
||||
"registerfunction": "collection_stats"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/upsert",
|
||||
"registerfunction": "upsert"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/batchinsert",
|
||||
"registerfunction": "batch_insert"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/delete",
|
||||
"registerfunction": "delete"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/query",
|
||||
"registerfunction": "query"
|
||||
"leading": "/v1/query",
|
||||
"registerfunction": "query"
|
||||
},
|
||||
{
|
||||
"leading": "/v1/batchquery",
|
||||
"registerfunction": "batch_query"
|
||||
},
|
||||
{
|
||||
"leading": "/docs",
|
||||
@ -57,5 +78,5 @@
|
||||
[".dspy", "dspy"],
|
||||
[".md", "md"]
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,18 +4,18 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vdb"
|
||||
version = "0.1.0"
|
||||
description = "向量数据库服务"
|
||||
version = "0.2.0"
|
||||
description = "Vector Database Service (Milvus Lite + Standalone)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"apppublic",
|
||||
"sqlor",
|
||||
"ahserver",
|
||||
"milvus-lite",
|
||||
"pymilvus"
|
||||
"milvus-lite",
|
||||
"pymilvus"
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."] # 声明在哪个目录下查找包,默认是当前目录
|
||||
include = ["vdb"] # 包含哪些包
|
||||
where = ["."]
|
||||
include = ["vdb"]
|
||||
|
||||
@ -1,24 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseVDB:
|
||||
def query(self, collection_name,
|
||||
vector=None,
|
||||
expr=None,
|
||||
pagerows=80,
|
||||
page=1,
|
||||
output_fields=None):
|
||||
pass
|
||||
|
||||
def get_db_type(dtype: str):
|
||||
return self.dbtypes.get(dtype)
|
||||
class BaseVDB(ABC):
|
||||
"""向量数据库抽象基类"""
|
||||
|
||||
def upsert(self, collection_name, data_dicts, partition_name=None):
|
||||
pass
|
||||
@abstractmethod
|
||||
def create_collection(self, collection_name, fields_config, description="", **kwargs):
|
||||
pass
|
||||
|
||||
def delete(self, collection_name, pks, partition_name=None):
|
||||
pass
|
||||
|
||||
def drop_collection(self, collection_name):
|
||||
pass
|
||||
|
||||
def create_collection(self, collection_name, fields_config, description=""):
|
||||
pass
|
||||
@abstractmethod
|
||||
def drop_collection(self, collection_name):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_collections(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collection_stats(self, collection_name):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def upsert(self, collection_name, data_dicts, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_insert(self, collection_name, data_dicts, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, collection_name, pks, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_query(self, collection_name, vectors, expr=None, limit=10, output_fields=None, **kwargs):
|
||||
pass
|
||||
|
||||
def get_db_type(self, dtype: str):
|
||||
return self.dbtypes.get(dtype) if hasattr(self, "dbtypes") else None
|
||||
|
||||
251
vdb/init.py
251
vdb/init.py
@ -1,5 +1,4 @@
|
||||
from traceback import format_exc
|
||||
from functools import partial
|
||||
from ahserver.serverenv import ServerEnv
|
||||
from appPublic.worker import awaitify
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
@ -7,118 +6,168 @@ from appPublic.log import debug, exception
|
||||
from appPublic.jsonConfig import getConfig
|
||||
from .milvus import MilvusManager
|
||||
|
||||
|
||||
def ownerparting(data):
|
||||
if data.get('ownerid'):
|
||||
return data.get('ownerid')
|
||||
return '_default'
|
||||
if data.get("ownerid"):
|
||||
return data.get("ownerid")
|
||||
return "_default"
|
||||
|
||||
|
||||
async def create_collection(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
fields = params_kw.fields
|
||||
description = params_kw.description or ""
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.create_collection)
|
||||
try:
|
||||
r = await f(colname, fields, description)
|
||||
return {
|
||||
"status":"SUCCEEDED",
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f'{e}, {format_exc()}')
|
||||
return {
|
||||
"status":"FAILED",
|
||||
"error": f"{e}"
|
||||
}
|
||||
colname = params_kw.colname
|
||||
fields = params_kw.fields
|
||||
description = params_kw.description or ""
|
||||
metric = params_kw.metric
|
||||
index_type = params_kw.index_type or "HNSW"
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.create_collection)
|
||||
try:
|
||||
await f(colname, fields, description, metric=metric, index_type=index_type)
|
||||
return {"status": "SUCCEEDED"}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def drop_collection(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.drop_collection)
|
||||
try:
|
||||
r = await f(colname)
|
||||
return {
|
||||
"status": "SUCCEEDED"
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {
|
||||
"status": "FAILED",
|
||||
"error": f"{e}"
|
||||
}
|
||||
colname = params_kw.colname
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.drop_collection)
|
||||
try:
|
||||
await f(colname)
|
||||
return {"status": "SUCCEEDED"}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def list_collections(request, params_kw, *args, **kwargs):
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.list_collections)
|
||||
try:
|
||||
result = await f()
|
||||
return {"status": "SUCCEEDED", "collections": result}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def collection_stats(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.collection_stats)
|
||||
try:
|
||||
result = await f(colname)
|
||||
if result is None:
|
||||
return {"status": "FAILED", "error": f"Collection '{colname}' not found"}
|
||||
return {"status": "SUCCEEDED", "data": result}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def upsert(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
data = params_kw.data
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.upsert)
|
||||
try:
|
||||
r = await f(colname, data)
|
||||
return {
|
||||
"status": "SUCCEEDED"
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {
|
||||
"status": "FAILED",
|
||||
"error": f"{e}"
|
||||
}
|
||||
colname = params_kw.colname
|
||||
data = params_kw.data
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.upsert)
|
||||
try:
|
||||
await f(colname, data)
|
||||
return {"status": "SUCCEEDED"}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def batch_insert(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
data = params_kw.data
|
||||
flush = params_kw.flush if params_kw.flush is not None else True
|
||||
partition_name = params_kw.partition_name
|
||||
if not isinstance(data, list):
|
||||
return {"status": "FAILED", "error": "data must be a list"}
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.batch_insert)
|
||||
try:
|
||||
await f(colname, data, flush=flush, partition_name=partition_name)
|
||||
return {"status": "SUCCEEDED", "count": len(data)}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def delete(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
pks = params_kw.pks
|
||||
colname = params_kw.colname
|
||||
pks = params_kw.pks
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.delete)
|
||||
try:
|
||||
await f(colname, pks)
|
||||
return {"status": "SUCCEEDED"}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.delete)
|
||||
try:
|
||||
r = await f(colname, pks)
|
||||
return {
|
||||
"status": "SUCCEEDED"
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {
|
||||
"status": "FAILED",
|
||||
"error": f"{e}"
|
||||
}
|
||||
|
||||
async def query(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
vector = params_kw.vector
|
||||
expr = params_kw.expr
|
||||
pagerows = params_kw.pagerows or 80
|
||||
page = params_kw.page or 1
|
||||
output_fields = params_kw.output_fields
|
||||
colname = params_kw.colname
|
||||
vector = params_kw.vector
|
||||
expr = params_kw.expr
|
||||
pagerows = params_kw.pagerows or 80
|
||||
page = params_kw.page or 1
|
||||
output_fields = params_kw.output_fields
|
||||
metric = params_kw.metric
|
||||
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.query)
|
||||
try:
|
||||
result = await f(colname, vector=vector, expr=expr, pagerows=pagerows,
|
||||
page=page, output_fields=output_fields, metric=metric)
|
||||
return {"status": "SUCCEEDED", "data": result}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
|
||||
async def batch_query(request, params_kw, *args, **kwargs):
|
||||
colname = params_kw.colname
|
||||
vectors = params_kw.vectors
|
||||
expr = params_kw.expr
|
||||
limit = params_kw.limit or 10
|
||||
output_fields = params_kw.output_fields
|
||||
metric = params_kw.metric
|
||||
|
||||
if not isinstance(vectors, list) or not vectors:
|
||||
return {"status": "FAILED", "error": "vectors must be a non-empty list"}
|
||||
|
||||
env = request._run_ns
|
||||
f = awaitify(env.vdb.batch_query)
|
||||
try:
|
||||
result = await f(colname, vectors, expr=expr, limit=limit,
|
||||
output_fields=output_fields, metric=metric)
|
||||
return {"status": "SUCCEEDED", "data": result}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {"status": "FAILED", "error": str(e)}
|
||||
|
||||
env = request._run_ns
|
||||
f1 = awaitify(env.vdb.drop_collection)
|
||||
try:
|
||||
f = partial(f1, colname, vector=vector, expr=expr, pagerows=pagerows, page=page, output_fields=output_fields)
|
||||
r = await f()
|
||||
return {
|
||||
"status": "SUCCEEDED",
|
||||
"data": r
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f"{e}, {format_exc()}")
|
||||
return {
|
||||
"status": "FAILED",
|
||||
"error": f"{e}"
|
||||
}
|
||||
|
||||
def load_vdb():
|
||||
config = getConfig()
|
||||
vdb = None
|
||||
vdb_type = config.vdb_type
|
||||
if vdb_type == 'milvus':
|
||||
vdb = MilvusManager(partitionize=ownerparting)
|
||||
env = ServerEnv()
|
||||
env.vdb = vdb
|
||||
rf = RegisterFunction()
|
||||
rf.register('create_collection', create_collection)
|
||||
rf.register('drop_collection', drop_collection)
|
||||
rf.register('upsert', upsert)
|
||||
rf.register('delete', delete)
|
||||
rf.register('query', query)
|
||||
config = getConfig()
|
||||
vdb = None
|
||||
vdb_type = config.vdb_type
|
||||
if vdb_type == "milvus":
|
||||
vdb = MilvusManager(partitionize=ownerparting)
|
||||
env = ServerEnv()
|
||||
env.vdb = vdb
|
||||
rf = RegisterFunction()
|
||||
rf.register("create_collection", create_collection)
|
||||
rf.register("drop_collection", drop_collection)
|
||||
rf.register("list_collections", list_collections)
|
||||
rf.register("collection_stats", collection_stats)
|
||||
rf.register("upsert", upsert)
|
||||
rf.register("batch_insert", batch_insert)
|
||||
rf.register("delete", delete)
|
||||
rf.register("query", query)
|
||||
rf.register("batch_query", batch_query)
|
||||
|
||||
487
vdb/milvus.py
487
vdb/milvus.py
@ -4,224 +4,309 @@ from appPublic.log import info, debug, exception, error
|
||||
from threading import Lock
|
||||
|
||||
from pymilvus import (
|
||||
connections, FieldSchema, CollectionSchema,
|
||||
DataType, Collection, utility, Partition
|
||||
connections, FieldSchema, CollectionSchema,
|
||||
DataType, Collection, utility
|
||||
)
|
||||
from .basevdb import BaseVDB
|
||||
|
||||
|
||||
class MilvusManager(BaseVDB):
|
||||
_instance = None
|
||||
_lock = Lock()
|
||||
dbtypes = {
|
||||
"str": DataType.VARCHAR,
|
||||
"int": DataType.INT32,
|
||||
"bool": DataType.BOOL,
|
||||
"float": DataType.FLOAT,
|
||||
"fvector": DataType.FLOAT_VECTOR,
|
||||
"bvector": DataType.BINARY_VECTOR,
|
||||
"json": DataType.JSON
|
||||
}
|
||||
def __new__(cls, *args, **kwargs):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(MilvusManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
_instance = None
|
||||
_lock = Lock()
|
||||
dbtypes = {
|
||||
"str": DataType.VARCHAR,
|
||||
"int": DataType.INT32,
|
||||
"int64": DataType.INT64,
|
||||
"bool": DataType.BOOL,
|
||||
"float": DataType.FLOAT,
|
||||
"fvector": DataType.FLOAT_VECTOR,
|
||||
"bvector": DataType.BINARY_VECTOR,
|
||||
"json": DataType.JSON,
|
||||
}
|
||||
metric_map = {
|
||||
"L2": "L2", "l2": "L2",
|
||||
"IP": "IP", "ip": "IP",
|
||||
"COSINE": "COSINE", "cosine": "COSINE",
|
||||
}
|
||||
|
||||
def __init__(self, partitionize=None):
|
||||
self.partitionize = partitionize
|
||||
if self._initialized:
|
||||
return
|
||||
try:
|
||||
config = getConfig()
|
||||
self.db_path = config.milvus_db
|
||||
debug(f"dbpath: {self.db_path}")
|
||||
except KeyError as e:
|
||||
error(f"配置文件缺少必要字段: {str(e)}")
|
||||
raise RuntimeError(f"配置文件缺少必要字段: {str(e)}")
|
||||
self._initialize_connection()
|
||||
self._initialized = True
|
||||
info(f"MilvusDBConnection initialized with db_path: {self.db_path}")
|
||||
def __new__(cls, *args, **kwargs):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(MilvusManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def _initialize_connection(self):
|
||||
"""初始化 Milvus 连接,确保单一连接"""
|
||||
try:
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
debug(f"db_dir: {db_dir}")
|
||||
if not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
debug(f"创建 Milvus 目录: {db_dir}")
|
||||
if not os.access(db_dir, os.W_OK):
|
||||
raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
|
||||
debug(f"不可写")
|
||||
if not connections.has_connection(self.db_path):
|
||||
connections.connect(self.db_path, uri=self.db_path)
|
||||
debug(f"已连接到 Milvus Lite,路径: {self.db_path}")
|
||||
else:
|
||||
debug("已存在 Milvus 连接,跳过重复连接")
|
||||
except Exception as e:
|
||||
error(f"连接 Milvus 失败: {str(e)}")
|
||||
raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
|
||||
def __init__(self, partitionize=None):
|
||||
self.partitionize = partitionize
|
||||
if self._initialized:
|
||||
return
|
||||
try:
|
||||
config = getConfig()
|
||||
self.alias = "default"
|
||||
self.mode = getattr(config, "milvus_mode", "lite")
|
||||
if self.mode == "standalone":
|
||||
self.host = config.milvus_host
|
||||
self.token = getattr(config, "milvus_token", "")
|
||||
self.db_name = getattr(config, "milvus_dbname", "default")
|
||||
else:
|
||||
self.db_path = config.milvus_db
|
||||
self.default_metric = getattr(config, "milvus_metric", "COSINE")
|
||||
self._initialize_connection()
|
||||
self._initialized = True
|
||||
info(f"MilvusManager initialized: mode={self.mode}, alias={self.alias}")
|
||||
except Exception as e:
|
||||
error(f"MilvusManager init failed: {e}")
|
||||
raise RuntimeError(f"MilvusManager init failed: {e}")
|
||||
|
||||
# --- 集合管理 ---
|
||||
def create_collection(self, collection_name, fields_config, description=""):
|
||||
"""
|
||||
打开或创建集合。如果已存在则返回对象;不存在则根据 config 创建。
|
||||
:param fields_config: 格式: [{"name": "id", "type": DataType.INT64, "is_primary": True}, ...]
|
||||
"""
|
||||
if utility.has_collection(collection_name, using=self.alias):
|
||||
# print(f"📦 集合 '{collection_name}' 已存在,直接加载。")
|
||||
return Collection(collection_name, using=self.alias)
|
||||
def _initialize_connection(self):
|
||||
try:
|
||||
if self.mode == "standalone":
|
||||
parts = self.host.split(":")
|
||||
host = parts[0]
|
||||
port = parts[1] if len(parts) > 1 else "19530"
|
||||
uri = f"http://{host}:{port}"
|
||||
if not connections.has_connection(self.alias):
|
||||
kw = {"uri": uri, "alias": self.alias}
|
||||
if self.token:
|
||||
kw["token"] = self.token
|
||||
if self.db_name and self.db_name != "default":
|
||||
kw["db_name"] = self.db_name
|
||||
connections.connect(**kw)
|
||||
info(f"Connected to Milvus Standalone: {uri}")
|
||||
else:
|
||||
debug("Milvus Standalone connection already exists")
|
||||
else:
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
if not connections.has_connection(self.alias):
|
||||
connections.connect(alias=self.alias, uri=self.db_path)
|
||||
info(f"Connected to Milvus Lite: {self.db_path}")
|
||||
else:
|
||||
debug("Milvus Lite connection already exists")
|
||||
except Exception as e:
|
||||
error(f"Milvus connect failed: {e}")
|
||||
raise RuntimeError(f"Milvus connect failed: {e}")
|
||||
|
||||
fields = []
|
||||
for cfg in fields_config:
|
||||
field = FieldSchema(
|
||||
name=cfg['name'],
|
||||
dtype=self.get_db_type(cfg['type']),
|
||||
is_primary=cfg.get('is_primary', False),
|
||||
auto_id=cfg.get('auto_id', False),
|
||||
dim=cfg.get('dim'),
|
||||
max_length=cfg.get('max_length')
|
||||
)
|
||||
fields.append(field)
|
||||
# === Collection Management ===
|
||||
|
||||
schema = CollectionSchema(fields, description=description)
|
||||
collection = Collection(name=collection_name, schema=schema, using=self.alias)
|
||||
|
||||
# 自动为向量字段创建索引 (内网推荐 HNSW)
|
||||
for cfg in fields_config:
|
||||
dbtype = self.get_db_type(cfg['type'])
|
||||
if dbtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
|
||||
index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 16, "efConstruction": 64}
|
||||
}
|
||||
collection.create_index(field_name=cfg['name'], index_params=index_params)
|
||||
|
||||
# print(f"🚀 集合 '{collection_name}' 创建并初始化索引完成。")
|
||||
return collection
|
||||
def list_collections(self):
|
||||
return utility.list_collections(using=self.alias)
|
||||
|
||||
def drop_collection(self, collection_name):
|
||||
"""物理删除集合"""
|
||||
if utility.has_collection(collection_name, using=self.alias):
|
||||
utility.drop_collection(collection_name, using=self.alias)
|
||||
print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。")
|
||||
def collection_stats(self, collection_name):
|
||||
if not utility.has_collection(collection_name, using=self.alias):
|
||||
return None
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
stats = {"name": collection_name, "row_count": col.num_entities}
|
||||
stats["fields"] = [
|
||||
{"name": f.name, "type": str(f.dtype), "is_primary": f.is_primary}
|
||||
for f in col.schema.fields
|
||||
]
|
||||
try:
|
||||
stats["indexes"] = [
|
||||
{"field": idx.field_name, "type": idx.params.get("index_type", "?")}
|
||||
for idx in col.indexes
|
||||
]
|
||||
except Exception:
|
||||
stats["indexes"] = []
|
||||
return stats
|
||||
|
||||
# --- 内存与分区优化 ---
|
||||
def load_segmented(self, collection_name, partition_names=None):
|
||||
"""分段加载:只加载特定分区到内存以保护内网服务器内存"""
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
col.load(partition_names=partition_names)
|
||||
print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。")
|
||||
def create_collection(self, collection_name, fields_config, description="",
|
||||
metric=None, index_type="HNSW", index_params_extra=None):
|
||||
if utility.has_collection(collection_name, using=self.alias):
|
||||
return Collection(collection_name, using=self.alias)
|
||||
|
||||
def release_memory(self, collection_name):
|
||||
"""释放内存:在执行 VBench 推理等高显存任务前调用"""
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
col.release()
|
||||
print(f"♻️ 集合 '{collection_name}' 已释放内存占用。")
|
||||
metric = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
|
||||
|
||||
# --- 数据操作 (DML) ---
|
||||
def _validate_and_format(self, col, data_dicts):
|
||||
"""私有工具:校验字典完整性并转换为列式数据"""
|
||||
schema = col.schema
|
||||
# 找出需要填充的字段名(排除自增ID)
|
||||
required_fields = [f.name for f in schema.fields if not f.auto_id]
|
||||
|
||||
# 提取向量字段信息用于维度校验
|
||||
vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR}
|
||||
fields = []
|
||||
for cfg in fields_config:
|
||||
field = FieldSchema(
|
||||
name=cfg["name"],
|
||||
dtype=self.get_db_type(cfg["type"]),
|
||||
is_primary=cfg.get("is_primary", False),
|
||||
auto_id=cfg.get("auto_id", False),
|
||||
dim=cfg.get("dim"),
|
||||
max_length=cfg.get("max_length"),
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
columnar_data = {name: [] for name in required_fields}
|
||||
for i, entry in enumerate(data_dicts):
|
||||
for field in required_fields:
|
||||
if field not in entry:
|
||||
raise ValueError(f"记录 {i} 缺失必填字段: {field}")
|
||||
# 向量维度校验
|
||||
if field in vec_info and len(entry[field]) != vec_info[field]:
|
||||
raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}")
|
||||
columnar_data[field].append(entry[field])
|
||||
|
||||
return [columnar_data[f] for f in required_fields]
|
||||
schema = CollectionSchema(fields, description=description)
|
||||
collection = Collection(name=collection_name, schema=schema, using=self.alias)
|
||||
|
||||
def upsert(self, collection_name, data_dicts):
|
||||
"""通用 Upsert:支持字典列表输入,自动识别主键更新"""
|
||||
pks = [item['id'] for item in data_dicts]
|
||||
self.delete(collection_name, pks)
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
if self.partitionize is None:
|
||||
formatted_data = self._validate_and_format(col, data_dicts)
|
||||
res = col.upsert(formatted_data)
|
||||
col.flush() # 内网环境强制落盘以防数据丢失
|
||||
return res
|
||||
grouped_data = {}
|
||||
for entry in data_dicts:
|
||||
p_name = str(partition_func(entry))
|
||||
if p_name not in grouped_data:
|
||||
grouped_data[p_name] = []
|
||||
grouped_data[p_name].append(entry)
|
||||
for cfg in fields_config:
|
||||
dbtype = self.get_db_type(cfg["type"])
|
||||
if dbtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
|
||||
idx_params = {"index_type": index_type, "metric_type": metric, "params": {"M": 16, "efConstruction": 64}}
|
||||
if index_params_extra:
|
||||
idx_params["params"].update(index_params_extra)
|
||||
collection.create_index(field_name=cfg["name"], index_params=idx_params)
|
||||
|
||||
for p_name, p_data in grouped_data.items():
|
||||
# 自动维护分区
|
||||
if p_name != "_default" and not col.has_partition(p_name):
|
||||
col.create_partition(p_name)
|
||||
print(f"🚩 自动创建新分区: {p_name}")
|
||||
|
||||
# 格式化数据
|
||||
formatted = self._validate_and_format(col, p_data)
|
||||
|
||||
# 执行 Upsert
|
||||
col.upsert(formatted, partition_name=p_name)
|
||||
col.flush()
|
||||
info(f"Collection '{collection_name}' created, metric={metric}, index={index_type}")
|
||||
return collection
|
||||
|
||||
def delete(self, collection_name, pks, partition_name=None):
|
||||
"""删除指定主键记录"""
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}"
|
||||
col.delete(expr, partition_name=partition_name)
|
||||
print(f"✂️ 已删除主键为 {pks} 的记录。")
|
||||
def drop_collection(self, collection_name):
|
||||
if utility.has_collection(collection_name, using=self.alias):
|
||||
utility.drop_collection(collection_name, using=self.alias)
|
||||
info(f"Collection '{collection_name}' dropped")
|
||||
|
||||
# --- 高级组合检索 ---
|
||||
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None):
|
||||
"""
|
||||
组合查询接口:支持向量近似搜索 + 标量过滤 + 分页
|
||||
:param vector: 目标向量,若为 None 则退化为纯标量查询
|
||||
:param expr: 过滤条件,如 "score > 0.8 and lang == 'cmn'"
|
||||
:param offset: 分页偏移量
|
||||
"""
|
||||
offset = (page - 1) * pagerows
|
||||
limit = pagerows
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
# 确保数据已加载
|
||||
if utility.get_query_segment_info(collection_name, using=self.alias) == []:
|
||||
self.load_segmented(collection_name)
|
||||
# === Memory ===
|
||||
|
||||
if output_fields is None:
|
||||
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
|
||||
def load_segmented(self, collection_name, partition_names=None):
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
col.load(partition_names=partition_names)
|
||||
debug(f"Loaded '{collection_name}' partitions={partition_names}")
|
||||
|
||||
if vector is not None:
|
||||
# 自动寻找向量字段名
|
||||
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
|
||||
search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset}
|
||||
|
||||
return {
|
||||
"total": -1,
|
||||
"page": page,
|
||||
"pagerows": pagerows,
|
||||
"rows": col.search(
|
||||
data=[vector],
|
||||
anns_field=vec_field,
|
||||
param=search_params,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
expr=expr,
|
||||
output_fields=output_fields
|
||||
)
|
||||
}
|
||||
else:
|
||||
# 纯标量查询
|
||||
return {
|
||||
"total": -1,
|
||||
"page": page,
|
||||
"pagerows": pagerows,
|
||||
"rows": col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields)
|
||||
}
|
||||
def release_memory(self, collection_name):
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
col.release()
|
||||
debug(f"Released '{collection_name}' from memory")
|
||||
|
||||
# === DML ===
|
||||
|
||||
def _validate_and_format(self, col, data_dicts):
|
||||
schema = col.schema
|
||||
required = [f.name for f in schema.fields if not f.auto_id]
|
||||
vec_info = {f.name: f.params.get("dim", 0) for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR}
|
||||
columnar = {name: [] for name in required}
|
||||
for i, entry in enumerate(data_dicts):
|
||||
for field in required:
|
||||
if field not in entry:
|
||||
raise ValueError(f"Record {i} missing field: {field}")
|
||||
if field in vec_info and entry[field] is not None:
|
||||
if len(entry[field]) != vec_info[field]:
|
||||
raise ValueError(f"Record {i} vector dim mismatch: expected {vec_info[field]}, got {len(entry[field])}")
|
||||
columnar[field].append(entry[field])
|
||||
return [columnar[f] for f in required]
|
||||
|
||||
def upsert(self, collection_name, data_dicts, **kwargs):
|
||||
if not isinstance(data_dicts, list):
|
||||
data_dicts = [data_dicts]
|
||||
pks = [item.get("id") for item in data_dicts if "id" in item]
|
||||
if pks:
|
||||
self.delete(collection_name, pks)
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
if self.partitionize is None:
|
||||
formatted = self._validate_and_format(col, data_dicts)
|
||||
res = col.upsert(formatted)
|
||||
col.flush()
|
||||
return res
|
||||
grouped = {}
|
||||
for entry in data_dicts:
|
||||
p_name = str(self.partitionize(entry))
|
||||
grouped.setdefault(p_name, []).append(entry)
|
||||
for p_name, p_data in grouped.items():
|
||||
if p_name != "_default" and not col.has_partition(p_name):
|
||||
col.create_partition(p_name)
|
||||
formatted = self._validate_and_format(col, p_data)
|
||||
col.upsert(formatted, partition_name=p_name)
|
||||
col.flush()
|
||||
|
||||
def batch_insert(self, collection_name, data_dicts, flush=True, partition_name=None, **kwargs):
|
||||
if not isinstance(data_dicts, list) or not data_dicts:
|
||||
return None
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
formatted = self._validate_and_format(col, data_dicts)
|
||||
res = col.insert(formatted, partition_name=partition_name)
|
||||
if flush:
|
||||
col.flush()
|
||||
return res
|
||||
|
||||
def delete(self, collection_name, pks, partition_name=None, **kwargs):
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
if isinstance(pks, list):
|
||||
if pks and isinstance(pks[0], str):
|
||||
pk_str = ", ".join(f'"{p}"' for p in pks)
|
||||
expr = f"id in [{pk_str}]"
|
||||
else:
|
||||
expr = f"id in {pks}"
|
||||
else:
|
||||
expr = f'id == "{pks}"' if isinstance(pks, str) else f"id == {pks}"
|
||||
col.delete(expr, partition_name=partition_name)
|
||||
col.flush()
|
||||
|
||||
# === Search ===
|
||||
|
||||
def _ensure_loaded(self, collection_name):
|
||||
try:
|
||||
seg_info = utility.get_query_segment_info(collection_name, using=self.alias)
|
||||
if not seg_info:
|
||||
self.load_segmented(collection_name)
|
||||
except Exception:
|
||||
try:
|
||||
self.load_segmented(collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1,
|
||||
output_fields=None, metric=None, **kwargs):
|
||||
offset = (page - 1) * pagerows
|
||||
limit = pagerows
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
self._ensure_loaded(collection_name)
|
||||
|
||||
if output_fields is None:
|
||||
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
|
||||
|
||||
if vector is not None:
|
||||
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
|
||||
m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
|
||||
search_params = {"metric_type": m, "params": {"ef": 64}, "offset": offset}
|
||||
hits = col.search(
|
||||
data=[vector],
|
||||
anns_field=vec_field,
|
||||
param=search_params,
|
||||
limit=limit,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
)
|
||||
rows = []
|
||||
for hit in hits[0]:
|
||||
row = {"id": hit.id, "score": hit.score}
|
||||
for field in output_fields:
|
||||
if field not in row:
|
||||
try:
|
||||
row[field] = hit.entity.get(field)
|
||||
except Exception:
|
||||
pass
|
||||
rows.append(row)
|
||||
return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows}
|
||||
else:
|
||||
rows = col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields)
|
||||
return {"total": -1, "page": page, "pagerows": pagerows, "rows": rows}
|
||||
|
||||
def batch_query(self, collection_name, vectors, expr=None, limit=10,
|
||||
output_fields=None, metric=None, **kwargs):
|
||||
if not vectors:
|
||||
return []
|
||||
col = Collection(collection_name, using=self.alias)
|
||||
self._ensure_loaded(collection_name)
|
||||
|
||||
if output_fields is None:
|
||||
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
|
||||
|
||||
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
|
||||
m = self.metric_map.get(metric, self.default_metric) if metric else self.default_metric
|
||||
search_params = {"metric_type": m, "params": {"ef": 64}}
|
||||
|
||||
results = col.search(
|
||||
data=vectors,
|
||||
anns_field=vec_field,
|
||||
param=search_params,
|
||||
limit=limit,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
)
|
||||
all_rows = []
|
||||
for hits in results:
|
||||
rows = []
|
||||
for hit in hits:
|
||||
row = {"id": hit.id, "score": hit.score}
|
||||
for field in output_fields:
|
||||
if field not in row:
|
||||
try:
|
||||
row[field] = hit.entity.get(field)
|
||||
except Exception:
|
||||
pass
|
||||
rows.append(row)
|
||||
all_rows.append(rows)
|
||||
return all_rows
|
||||
Loading…
x
Reference in New Issue
Block a user