bugfix
This commit is contained in:
parent
bdc36e32d5
commit
f9ee939b0f
21
vdb/basevdb.py
Normal file
21
vdb/basevdb.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
|
||||||
|
class BaseVDB:
    """Abstract base class defining the vector-database (VDB) interface.

    Concrete backends are expected to subclass this and override every
    method.  Each stub below is a deliberate no-op returning ``None``.
    """

    def query(self, collection_name,
              vector=None,
              expr=None,
              pagerows=80,
              page=1,
              output_fields=None):
        """Combined vector/scalar search over one collection with paging."""
        pass

    def upsert(self, collection_name, data_dicts, partition_name=None):
        """Insert-or-update a list of record dicts into a collection."""
        pass

    def delete(self, collection_name, pks, partition_name=None):
        """Delete records by primary key(s)."""
        pass

    def drop_collection(self, collection_name):
        """Remove a collection entirely."""
        pass

    def create_collection(self, collection_name, fields_config, description=""):
        """Create (or open) a collection from a field-schema config."""
        pass
|
||||||
205
vdb/milvus.py
Normal file
205
vdb/milvus.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
import os
from threading import Lock

from pymilvus import (
    connections, FieldSchema, CollectionSchema,
    DataType, Collection, utility, Partition
)

from appPublic.jsonConfig import getConfig
from appPublic.log import debug, exception, error, info

from vdb.basevdb import BaseVDB
|
||||||
|
|
||||||
|
class MilvusManager(BaseVDB):
    """Singleton manager wrapping a Milvus Lite connection.

    Reads the database path from the app config key ``'milvus_db'``,
    connects once, and exposes collection management, DML and combined
    query helpers.  ``partitionize``, when given, is a callable mapping a
    record dict to its partition name.
    """

    _instance = None
    _lock = Lock()

    def __new__(cls, *args, **kwargs):
        # BUG FIX: accept *args/**kwargs so __init__'s `partitionize`
        # keyword does not make instantiation fail (the original
        # `__new__(cls)` rejected any argument).
        with cls._lock:
            if cls._instance is None:
                # BUG FIX: original called super(MilvusDBConnection, cls)
                # with an undefined name; zero-argument super() is correct.
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self, partitionize=None):
        """
        :param partitionize: optional callable(record_dict) -> partition
                             name; None disables automatic partitioning.
        :raises RuntimeError: when the config lacks 'milvus_db' or the
                              connection cannot be established.
        """
        # NOTE: re-running __init__ on the already-initialized singleton
        # still refreshes `partitionize` (original behavior), then returns.
        self.partitionize = partitionize
        if self._initialized:
            return
        try:
            config = getConfig()
            self.db_path = config['milvus_db']
            debug(f"dbpath: {self.db_path}")
        except KeyError as e:
            error(f"配置文件缺少必要字段: {str(e)}")
            raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") from e
        # BUG FIX: `self.alias` was used by every method but never assigned.
        # _initialize_connection() connects with db_path as the connection
        # alias, so the alias is the db_path itself.
        self.alias = self.db_path
        self._initialize_connection()
        self._initialized = True
        info(f"MilvusDBConnection initialized with db_path: {self.db_path}")

    def _initialize_connection(self):
        """Initialize the Milvus connection, guaranteeing a single one."""
        try:
            db_dir = os.path.dirname(self.db_path)
            debug(f"db_dir: {db_dir}")
            if not os.path.exists(db_dir):
                os.makedirs(db_dir, exist_ok=True)
                debug(f"创建 Milvus 目录: {db_dir}")
            if not os.access(db_dir, os.W_OK):
                # BUG FIX: this debug call sat after the raise and was
                # unreachable; moved before it.
                debug(f"不可写")
                raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
            if not connections.has_connection(self.db_path):
                connections.connect(self.db_path, uri=self.db_path)
                debug(f"已连接到 Milvus Lite,路径: {self.db_path}")
            else:
                debug("已存在 Milvus 连接,跳过重复连接")
        except Exception as e:
            # Uniformly wrap any failure (including the writability check
            # above) as a connection error, as the original did.
            error(f"连接 Milvus 失败: {str(e)}")
            raise RuntimeError(f"连接 Milvus 失败: {str(e)}") from e

    # --- Collection management ---
    def create_collection(self, collection_name, fields_config, description=""):
        """
        Open or create a collection: return the existing one if present,
        otherwise build it from `fields_config`.

        :param fields_config: e.g. [{"name": "id", "type": DataType.INT64,
                              "is_primary": True}, ...]; optional keys:
                              auto_id, dim (vector fields), max_length
                              (varchar fields).
        """
        if utility.has_collection(collection_name, using=self.alias):
            return Collection(collection_name, using=self.alias)

        fields = []
        for cfg in fields_config:
            # BUG FIX: only forward dim/max_length when actually configured —
            # the original passed dim=None / max_length=None for every scalar
            # field, which FieldSchema may reject.
            extra = {}
            if cfg.get('dim') is not None:
                extra['dim'] = cfg['dim']
            if cfg.get('max_length') is not None:
                extra['max_length'] = cfg['max_length']
            fields.append(FieldSchema(
                name=cfg['name'],
                dtype=cfg['type'],
                is_primary=cfg.get('is_primary', False),
                auto_id=cfg.get('auto_id', False),
                **extra
            ))

        schema = CollectionSchema(fields, description=description)
        collection = Collection(name=collection_name, schema=schema, using=self.alias)

        # Auto-create an index for each vector field (HNSW suits
        # intranet/on-prem deployments).
        for cfg in fields_config:
            if cfg['type'] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
                index_params = {
                    "metric_type": "L2",
                    "index_type": "HNSW",
                    "params": {"M": 16, "efConstruction": 64}
                }
                collection.create_index(field_name=cfg['name'], index_params=index_params)

        return collection

    def drop_collection(self, collection_name):
        """Physically delete a collection if it exists."""
        if utility.has_collection(collection_name, using=self.alias):
            utility.drop_collection(collection_name, using=self.alias)
            print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。")

    # --- Memory and partition management ---
    def load_segmented(self, collection_name, partition_names=None):
        """Load only the given partitions into memory (None = everything)."""
        col = Collection(collection_name, using=self.alias)
        col.load(partition_names=partition_names)
        print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。")

    def release_memory(self, collection_name):
        """Release the collection's in-memory data (call before memory-heavy jobs)."""
        col = Collection(collection_name, using=self.alias)
        col.release()
        print(f"♻️ 集合 '{collection_name}' 已释放内存占用。")

    # --- DML ---
    def _validate_and_format(self, col, data_dicts):
        """Validate record dicts against the schema; return columnar data.

        Skips auto-id fields, checks presence of every required field and
        the dimension of FLOAT_VECTOR values.

        :raises ValueError: on a missing field or a dimension mismatch.
        """
        schema = col.schema
        # Fields callers must supply (auto-id primary keys excluded).
        required_fields = [f.name for f in schema.fields if not f.auto_id]
        # Dimensions of float-vector fields, for per-record checks.
        vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR}

        columnar_data = {name: [] for name in required_fields}
        for i, entry in enumerate(data_dicts):
            for field in required_fields:
                if field not in entry:
                    raise ValueError(f"记录 {i} 缺失必填字段: {field}")
                if field in vec_info and len(entry[field]) != vec_info[field]:
                    raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}")
                columnar_data[field].append(entry[field])

        return [columnar_data[f] for f in required_fields]

    def upsert(self, collection_name, data_dicts, partition_name=None):
        """Upsert a list of record dicts; deletes matching ids first.

        When `self.partitionize` is set, records are grouped by the
        partition name it yields and missing partitions are created on the
        fly; otherwise everything goes to `partition_name`.
        """
        # assumes every record carries an 'id' primary key — TODO confirm
        pks = [item['id'] for item in data_dicts]
        self.delete(collection_name, pks)
        col = Collection(collection_name, using=self.alias)
        if self.partitionize is None:
            formatted_data = self._validate_and_format(col, data_dicts)
            res = col.upsert(formatted_data, partition_name=partition_name)
            col.flush()  # force to disk to guard against data loss on-prem
            return res

        grouped_data = {}
        for entry in data_dicts:
            # BUG FIX: original called the undefined name `partition_func`;
            # `self.partitionize` is the configured partition-key callable.
            p_name = str(self.partitionize(entry))
            grouped_data.setdefault(p_name, []).append(entry)

        for p_name, p_data in grouped_data.items():
            # Create the partition on demand.
            if p_name != "_default" and not col.has_partition(p_name):
                col.create_partition(p_name)
                print(f"🚩 自动创建新分区: {p_name}")
            formatted = self._validate_and_format(col, p_data)
            col.upsert(formatted, partition_name=p_name)
            col.flush()

    def delete(self, collection_name, pks, partition_name=None):
        """Delete records by primary key (single value or list)."""
        col = Collection(collection_name, using=self.alias)
        expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}"
        col.delete(expr, partition_name=partition_name)
        print(f"✂️ 已删除主键为 {pks} 的记录。")

    # --- Combined retrieval ---
    def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None):
        """
        Combined query: vector ANN search + scalar filtering + paging.

        :param vector: target vector; None degrades to a pure scalar query
        :param expr: filter expression, e.g. "score > 0.8 and lang == 'cmn'"
        :param pagerows: page size
        :param page: 1-based page number
        :param output_fields: fields to return; defaults to all non-vector fields
        """
        offset = (page - 1) * pagerows
        limit = pagerows
        col = Collection(collection_name, using=self.alias)
        # Lazily load the collection when no segment is in memory yet.
        if utility.get_query_segment_info(collection_name, using=self.alias) == []:
            self.load_segmented(collection_name)

        if output_fields is None:
            output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]

        if vector is not None:
            # Locate the (first) float-vector field to search on.
            vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
            search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset}
            return col.search(
                data=[vector],
                anns_field=vec_field,
                param=search_params,
                limit=limit,
                expr=expr,
                output_fields=output_fields
            )
        else:
            # Pure scalar query.
            return col.query(expr=expr, limit=limit, offset=offset, output_fields=output_fields)
|
||||||
Loading…
x
Reference in New Issue
Block a user