This commit is contained in:
yumoqing 2026-02-06 13:39:45 +08:00
parent a94965d7cb
commit 5cdbef818c

View File

@ -3,13 +3,13 @@ from appPublic.jsonConfig import getConfig
from appPublic.log import debug, exception, error from appPublic.log import debug, exception, error
from pymilvus import ( from pymilvus import (
connections, FieldSchema, CollectionSchema, connections, FieldSchema, CollectionSchema,
DataType, Collection, utility, Partition DataType, Collection, utility, Partition
) )
class MilvusManager(BaseVDB): class MilvusManager(BaseVDB):
_instance = None _instance = None
_lock = Lock() _lock = Lock()
dbtypes = { dbtypes = {
"str": DataType.VARCHAR, "str": DataType.VARCHAR,
"int": DataType.INT32, "int": DataType.INT32,
@ -19,130 +19,130 @@ class MilvusManager(BaseVDB):
"bvector": DataType.BINARY_VECTOR, "bvector": DataType.BINARY_VECTOR,
"json": DataType.JSON "json": DataType.JSON
} }
def __new__(cls, *args, **kwargs):
    """Return the shared singleton instance, creating it on first use.

    Python forwards the constructor arguments to ``__new__`` as well as to
    ``__init__``, so this method must accept (and ignore) them; otherwise
    ``MilvusManager(partitionize=...)`` fails with a TypeError before
    ``__init__`` ever runs.

    :return: the single instance stored on ``cls._instance``
    """
    with cls._lock:
        if cls._instance is None:
            # Zero-argument super() avoids naming the class explicitly; the
            # previous code referenced the stale name MilvusDBConnection.
            cls._instance = super().__new__(cls)
            # __init__ checks this flag to skip re-initialisation.
            cls._instance._initialized = False
    return cls._instance
def __init__(self, partitionize=None):
    """Read the Milvus path from the config and open the connection once.

    ``partitionize`` is stored on every call; everything else runs only the
    first time, guarded by the ``_initialized`` flag set in ``__new__``.

    :param partitionize: stored as ``self.partitionize``; its use is not
        visible in this block
    :raises RuntimeError: if the config has no ``milvus_db`` key, or the
        connection cannot be established
    """
    self.partitionize = partitionize
    if self._initialized:
        return
    try:
        config = getConfig()
        self.db_path = config['milvus_db']
    except KeyError as e:
        error(f"配置文件缺少必要字段: {str(e)}")
        raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") from e
    debug(f"dbpath: {self.db_path}")
    # The connection is registered under the alias `self.db_path`, and the
    # other methods pass `using=self.alias`. Nothing visible here assigns
    # `alias`, so default it to the connection name unless a base class
    # already provides one.  NOTE(review): confirm against BaseVDB.
    if getattr(self, 'alias', None) is None:
        self.alias = self.db_path
    self._initialize_connection()
    self._initialized = True
    # `info` is not among the names imported from appPublic.log in this
    # file, so log through `debug`, which is.
    debug(f"MilvusDBConnection initialized with db_path: {self.db_path}")
def _initialize_connection(self):
    """Initialise the Milvus connection, reusing an existing one if present.

    Ensures the directory that holds ``self.db_path`` exists and is
    writable, then registers a pymilvus connection whose alias and URI are
    both ``self.db_path``.

    :raises RuntimeError: wrapping any failure (directory creation,
        permission check, or the connect call itself)
    """
    try:
        # dirname() is '' for a bare filename; fall back to the current
        # directory so exists()/makedirs()/access() get a usable path.
        db_dir = os.path.dirname(self.db_path) or os.curdir
        debug(f"db_dir: {db_dir}")
        if not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
            debug(f"创建 Milvus 目录: {db_dir}")
        if not os.access(db_dir, os.W_OK):
            # Caught and re-wrapped by the handler below.
            raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
        if not connections.has_connection(self.db_path):
            connections.connect(self.db_path, uri=self.db_path)
            debug(f"已连接到 Milvus Lite路径: {self.db_path}")
        else:
            debug("已存在 Milvus 连接,跳过重复连接")
    except Exception as e:
        error(f"连接 Milvus 失败: {str(e)}")
        raise RuntimeError(f"连接 Milvus 失败: {str(e)}") from e
# --- 集合管理 --- # --- 集合管理 ---
def create_collection(self, collection_name, fields_config, description=""):
    """Open the collection if it exists, otherwise create it and index it.

    :param collection_name: name of the collection
    :param fields_config: list of field dicts. Each needs ``name`` and
        ``type``; ``is_primary``, ``auto_id``, ``dim`` and ``max_length``
        are optional. ``type`` is passed through ``self.get_db_type``
        (presumably a key of ``dbtypes`` such as ``"str"`` -- that helper is
        not visible here).
    :param description: free-text description stored on the schema
    :return: the existing or newly created ``Collection``
    """
    if utility.has_collection(collection_name, using=self.alias):
        return Collection(collection_name, using=self.alias)

    fields = []
    vector_fields = []
    for cfg in fields_config:
        dtype = self.get_db_type(cfg['type'])
        # Forward only the type parameters that were actually configured,
        # rather than handing dim=None / max_length=None to every field.
        type_params = {
            key: cfg[key]
            for key in ('dim', 'max_length')
            if cfg.get(key) is not None
        }
        fields.append(FieldSchema(
            name=cfg['name'],
            dtype=dtype,
            is_primary=cfg.get('is_primary', False),
            auto_id=cfg.get('auto_id', False),
            **type_params,
        ))
        if dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
            vector_fields.append((cfg['name'], dtype))

    schema = CollectionSchema(fields, description=description)
    collection = Collection(name=collection_name, schema=schema, using=self.alias)

    # Create an index on every vector field.
    for field_name, dtype in vector_fields:
        if dtype == DataType.BINARY_VECTOR:
            # Binary vectors cannot use L2/HNSW; they need a BIN_* index
            # with a binary metric such as HAMMING.
            index_params = {
                "metric_type": "HAMMING",
                "index_type": "BIN_FLAT",
                "params": {},
            }
        else:
            index_params = {
                "metric_type": "L2",
                "index_type": "HNSW",
                "params": {"M": 16, "efConstruction": 64},
            }
        collection.create_index(field_name=field_name, index_params=index_params)
    return collection
def drop_collection(self, collection_name):
    """Drop the named collection if it exists; do nothing otherwise."""
    alias = self.alias
    if not utility.has_collection(collection_name, using=alias):
        return
    utility.drop_collection(collection_name, using=alias)
    print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。")
# --- 内存与分区优化 --- # --- 内存与分区优化 ---
def load_segmented(self, collection_name, partition_names=None):
    """Load a collection into memory, optionally restricted to some partitions.

    :param collection_name: collection to load
    :param partition_names: partitions to load; a falsy value is reported
        as a full load in the printed message
    """
    collection = Collection(collection_name, using=self.alias)
    collection.load(partition_names=partition_names)
    loaded = partition_names or '全量'
    print(f"🧠 已加载集合数据(分区: {loaded})到内存。")
def release_memory(self, collection_name):
    """Release the named collection from memory and report it on stdout."""
    Collection(collection_name, using=self.alias).release()
    message = f"♻️ 集合 '{collection_name}' 已释放内存占用。"
    print(message)
# --- 数据操作 (DML) --- # --- 数据操作 (DML) ---
def _validate_and_format(self, col, data_dicts):
    """Check each record against the schema and pivot rows into columns.

    :param col: collection whose ``schema.fields`` drive the validation
    :param data_dicts: iterable of row dicts keyed by field name
    :return: one list per non-auto-id field, in schema order
    :raises ValueError: if a record lacks a required field, or a
        FLOAT_VECTOR value's length differs from the field's ``dim``
    """
    schema_fields = col.schema.fields
    # Auto-id fields are not supplied by the caller.
    names = [fld.name for fld in schema_fields if not fld.auto_id]
    # Expected dimension for each float-vector field.
    expected_dims = {
        fld.name: fld.params['dim']
        for fld in schema_fields
        if fld.dtype == DataType.FLOAT_VECTOR
    }

    columns = [[] for _ in names]
    for row_no, record in enumerate(data_dicts):
        for name, column in zip(names, columns):
            if name not in record:
                raise ValueError(f"记录 {row_no} 缺失必填字段: {name}")
            value = record[name]
            if name in expected_dims:
                actual = len(value)
                if actual != expected_dims[name]:
                    raise ValueError(f"记录 {row_no} 向量维度错误: 预期 {expected_dims[name]}, 实际 {actual}")
            column.append(value)
    return columns
def upsert(self, collection_name, data_dicts): def upsert(self, collection_name, data_dicts):
"""通用 Upsert支持字典列表输入自动识别主键更新""" """通用 Upsert支持字典列表输入自动识别主键更新"""
pks = [item['id'] for item in data_dicts] pks = [item['id'] for item in data_dicts]
self.delete(collection_name, pks) self.delete(collection_name, pks)
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
@ -152,56 +152,56 @@ class MilvusManager(BaseVDB):
col.flush() # 内网环境强制落盘以防数据丢失 col.flush() # 内网环境强制落盘以防数据丢失
return res return res
grouped_data = {} grouped_data = {}
for entry in data_dicts: for entry in data_dicts:
p_name = str(partition_func(entry)) p_name = str(partition_func(entry))
if p_name not in grouped_data: if p_name not in grouped_data:
grouped_data[p_name] = [] grouped_data[p_name] = []
grouped_data[p_name].append(entry) grouped_data[p_name].append(entry)
for p_name, p_data in grouped_data.items(): for p_name, p_data in grouped_data.items():
# 自动维护分区 # 自动维护分区
if p_name != "_default" and not col.has_partition(p_name): if p_name != "_default" and not col.has_partition(p_name):
col.create_partition(p_name) col.create_partition(p_name)
print(f"🚩 自动创建新分区: {p_name}") print(f"🚩 自动创建新分区: {p_name}")
# 格式化数据 # 格式化数据
formatted = self._validate_and_format(col, p_data) formatted = self._validate_and_format(col, p_data)
# 执行 Upsert # 执行 Upsert
col.upsert(formatted, partition_name=p_name) col.upsert(formatted, partition_name=p_name)
col.flush() col.flush()
def delete(self, collection_name, pks, partition_name=None):
    """Delete the records whose ``id`` matches the given primary key(s).

    :param collection_name: collection to delete from
    :param pks: a single primary key, or a list/tuple/set of them
    :param partition_name: forwarded to ``Collection.delete``
    """
    col = Collection(collection_name, using=self.alias)
    if isinstance(pks, (list, tuple, set)):
        # A list's repr already quotes string elements, e.g. ['a', 'b'].
        expr = f"id in {list(pks)}"
    elif isinstance(pks, str):
        # A scalar string key must be quoted, or the expression parser
        # reads it as an identifier rather than a literal.
        expr = f"id == {pks!r}"
    else:
        expr = f"id == {pks}"
    col.delete(expr, partition_name=partition_name)
    print(f"✂️ 已删除主键为 {pks} 的记录。")
# --- 高级组合检索 --- # --- 高级组合检索 ---
def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None): def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None):
""" """
组合查询接口支持向量近似搜索 + 标量过滤 + 分页 组合查询接口支持向量近似搜索 + 标量过滤 + 分页
:param vector: 目标向量若为 None 则退化为纯标量查询 :param vector: 目标向量若为 None 则退化为纯标量查询
:param expr: 过滤条件 "score > 0.8 and lang == 'cmn'" :param expr: 过滤条件 "score > 0.8 and lang == 'cmn'"
:param offset: 分页偏移量 :param offset: 分页偏移量
""" """
offset = (page - 1) * pagerows offset = (page - 1) * pagerows
limit = pagerows limit = pagerows
col = Collection(collection_name, using=self.alias) col = Collection(collection_name, using=self.alias)
# 确保数据已加载 # 确保数据已加载
if utility.get_query_segment_info(collection_name, using=self.alias) == []: if utility.get_query_segment_info(collection_name, using=self.alias) == []:
self.load_segmented(collection_name) self.load_segmented(collection_name)
if output_fields is None: if output_fields is None:
output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR]
if vector is not None: if vector is not None:
# 自动寻找向量字段名 # 自动寻找向量字段名
vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR)
search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset} search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset}
return { return {
"total": -1, "total": -1,
"page": page, "page": page,
"pagerows": pagerows, "pagerows": pagerows,
@ -215,9 +215,9 @@ class MilvusManager(BaseVDB):
output_fields=output_fields output_fields=output_fields
) )
} }
else: else:
# 纯标量查询 # 纯标量查询
return { return {
"total": -1, "total": -1,
"page": page, "page": page,
"pagerows": pagerows, "pagerows": pagerows,