From 5cdbef818ceb5cb1b48e7fd217d31666611e9424 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Fri, 6 Feb 2026 13:39:45 +0800 Subject: [PATCH] bugfix --- vdb/milvus.py | 316 +++++++++++++++++++++++++------------------------- 1 file changed, 158 insertions(+), 158 deletions(-) diff --git a/vdb/milvus.py b/vdb/milvus.py index 028a32f..ea0b784 100644 --- a/vdb/milvus.py +++ b/vdb/milvus.py @@ -3,13 +3,13 @@ from appPublic.jsonConfig import getConfig from appPublic.log import debug, exception, error from pymilvus import ( - connections, FieldSchema, CollectionSchema, - DataType, Collection, utility, Partition + connections, FieldSchema, CollectionSchema, + DataType, Collection, utility, Partition ) class MilvusManager(BaseVDB): - _instance = None - _lock = Lock() + _instance = None + _lock = Lock() dbtypes = { "str": DataType.VARCHAR, "int": DataType.INT32, @@ -19,130 +19,130 @@ class MilvusManager(BaseVDB): "bvector": DataType.BINARY_VECTOR, "json": DataType.JSON } - def __new__(cls): - with cls._lock: - if cls._instance is None: - cls._instance = super(MilvusDBConnection, cls).__new__(cls) - cls._instance._initialized = False - return cls._instance + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super(MilvusDBConnection, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance - def __init__(self, partitionize=None): + def __init__(self, partitionize=None): self.partitionize = partitionize - if self._initialized: - return - try: - config = getConfig() - self.db_path = config['milvus_db'] - debug(f"dbpath: {self.db_path}") - except KeyError as e: - error(f"配置文件缺少必要字段: {str(e)}") - raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") - self._initialize_connection() - self._initialized = True - info(f"MilvusDBConnection initialized with db_path: {self.db_path}") + if self._initialized: + return + try: + config = getConfig() + self.db_path = config['milvus_db'] + debug(f"dbpath: {self.db_path}") + except KeyError as e: + error(f"配置文件缺少必要字段: {str(e)}") + raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") + self._initialize_connection() + self._initialized = True + info(f"MilvusDBConnection initialized with db_path: {self.db_path}") - def _initialize_connection(self): - """初始化 Milvus 连接,确保单一连接""" - try: - db_dir = os.path.dirname(self.db_path) - debug(f"db_dir: {db_dir}") - if not os.path.exists(db_dir): - os.makedirs(db_dir, exist_ok=True) - debug(f"创建 Milvus 目录: {db_dir}") - if not os.access(db_dir, os.W_OK): - raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") - debug(f"不可写") - if not connections.has_connection(self.db_path): - connections.connect(self.db_path, uri=self.db_path) - debug(f"已连接到 Milvus Lite,路径: {self.db_path}") - else: - debug("已存在 Milvus 连接,跳过重复连接") - except Exception as e: - error(f"连接 Milvus 失败: {str(e)}") - raise RuntimeError(f"连接 Milvus 失败: {str(e)}") + def _initialize_connection(self): + """初始化 Milvus 连接,确保单一连接""" + try: + db_dir = os.path.dirname(self.db_path) + debug(f"db_dir: {db_dir}") + if not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + debug(f"创建 Milvus 目录: {db_dir}") + if not os.access(db_dir, os.W_OK): + raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") + debug(f"不可写") + if not connections.has_connection(self.db_path): + connections.connect(self.db_path, uri=self.db_path) + debug(f"已连接到 Milvus Lite,路径: {self.db_path}") + else: + debug("已存在 Milvus 连接,跳过重复连接") + except Exception as e: + error(f"连接 Milvus 失败: {str(e)}") + raise RuntimeError(f"连接 Milvus 失败: {str(e)}") - # --- 集合管理 --- - def create_collection(self, collection_name, fields_config, description=""): - """ - 打开或创建集合。如果已存在则返回对象;不存在则根据 config 创建。 - :param fields_config: 格式: [{"name": "id", "type": DataType.INT64, "is_primary": True}, ...] - """ - if utility.has_collection(collection_name, using=self.alias): - # print(f"📦 集合 '{collection_name}' 已存在,直接加载。") - return Collection(collection_name, using=self.alias) + # --- 集合管理 --- + def create_collection(self, collection_name, fields_config, description=""): + """ + 打开或创建集合。如果已存在则返回对象;不存在则根据 config 创建。 + :param fields_config: 格式: [{"name": "id", "type": DataType.INT64, "is_primary": True}, ...] + """ + if utility.has_collection(collection_name, using=self.alias): + # print(f"📦 集合 '{collection_name}' 已存在,直接加载。") + return Collection(collection_name, using=self.alias) - fields = [] - for cfg in fields_config: - field = FieldSchema( - name=cfg['name'], - dtype=self.get_db_type(cfg['type']), - is_primary=cfg.get('is_primary', False), - auto_id=cfg.get('auto_id', False), - dim=cfg.get('dim'), - max_length=cfg.get('max_length') - ) - fields.append(field) + fields = [] + for cfg in fields_config: + field = FieldSchema( + name=cfg['name'], + dtype=self.get_db_type(cfg['type']), + is_primary=cfg.get('is_primary', False), + auto_id=cfg.get('auto_id', False), + dim=cfg.get('dim'), + max_length=cfg.get('max_length') + ) + fields.append(field) - schema = CollectionSchema(fields, description=description) - collection = Collection(name=collection_name, schema=schema, using=self.alias) - - # 自动为向量字段创建索引 (内网推荐 HNSW) - for cfg in fields_config: + schema = CollectionSchema(fields, description=description) + collection = Collection(name=collection_name, schema=schema, using=self.alias) + + # 自动为向量字段创建索引 (内网推荐 HNSW) + for cfg in fields_config: dbtype = self.get_db_type(cfg['type']) - if dbtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: - index_params = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 16, "efConstruction": 64} - } - collection.create_index(field_name=cfg['name'], index_params=index_params) - - # print(f"🚀 集合 '{collection_name}' 创建并初始化索引完成。") - return collection + if dbtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: + index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 16, "efConstruction": 64} + } + collection.create_index(field_name=cfg['name'], index_params=index_params) + + # print(f"🚀 集合 '{collection_name}' 创建并初始化索引完成。") + return collection - def drop_collection(self, collection_name): - """物理删除集合""" - if utility.has_collection(collection_name, using=self.alias): - utility.drop_collection(collection_name, using=self.alias) - print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。") + def drop_collection(self, collection_name): + """物理删除集合""" + if utility.has_collection(collection_name, using=self.alias): + utility.drop_collection(collection_name, using=self.alias) + print(f"🗑️ 集合 '{collection_name}' 已从磁盘删除。") - # --- 内存与分区优化 --- - def load_segmented(self, collection_name, partition_names=None): - """分段加载:只加载特定分区到内存以保护内网服务器内存""" - col = Collection(collection_name, using=self.alias) - col.load(partition_names=partition_names) - print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。") + # --- 内存与分区优化 --- + def load_segmented(self, collection_name, partition_names=None): + """分段加载:只加载特定分区到内存以保护内网服务器内存""" + col = Collection(collection_name, using=self.alias) + col.load(partition_names=partition_names) + print(f"🧠 已加载集合数据(分区: {partition_names if partition_names else '全量'})到内存。") - def release_memory(self, collection_name): - """释放内存:在执行 VBench 推理等高显存任务前调用""" - col = Collection(collection_name, using=self.alias) - col.release() - print(f"♻️ 集合 '{collection_name}' 已释放内存占用。") + def release_memory(self, collection_name): + """释放内存:在执行 VBench 推理等高显存任务前调用""" + col = Collection(collection_name, using=self.alias) + col.release() + print(f"♻️ 集合 '{collection_name}' 已释放内存占用。") - # --- 数据操作 (DML) --- - def _validate_and_format(self, col, data_dicts): - """私有工具:校验字典完整性并转换为列式数据""" - schema = col.schema - # 找出需要填充的字段名(排除自增ID) - required_fields = [f.name for f in schema.fields if not f.auto_id] - - # 提取向量字段信息用于维度校验 - vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR} + # --- 数据操作 (DML) --- + def _validate_and_format(self, col, data_dicts): + """私有工具:校验字典完整性并转换为列式数据""" + schema = col.schema + # 找出需要填充的字段名(排除自增ID) + required_fields = [f.name for f in schema.fields if not f.auto_id] + + # 提取向量字段信息用于维度校验 + vec_info = {f.name: f.params['dim'] for f in schema.fields if f.dtype == DataType.FLOAT_VECTOR} - columnar_data = {name: [] for name in required_fields} - for i, entry in enumerate(data_dicts): - for field in required_fields: - if field not in entry: - raise ValueError(f"记录 {i} 缺失必填字段: {field}") - # 向量维度校验 - if field in vec_info and len(entry[field]) != vec_info[field]: - raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}") - columnar_data[field].append(entry[field]) - - return [columnar_data[f] for f in required_fields] + columnar_data = {name: [] for name in required_fields} + for i, entry in enumerate(data_dicts): + for field in required_fields: + if field not in entry: + raise ValueError(f"记录 {i} 缺失必填字段: {field}") + # 向量维度校验 + if field in vec_info and len(entry[field]) != vec_info[field]: + raise ValueError(f"记录 {i} 向量维度错误: 预期 {vec_info[field]}, 实际 {len(entry[field])}") + columnar_data[field].append(entry[field]) + + return [columnar_data[f] for f in required_fields] - def upsert(self, collection_name, data_dicts): - """通用 Upsert:支持字典列表输入,自动识别主键更新""" + def upsert(self, collection_name, data_dicts): + """通用 Upsert:支持字典列表输入,自动识别主键更新""" pks = [item['id'] for item in data_dicts] self.delete(collection_name, pks) col = Collection(collection_name, using=self.alias) @@ -152,56 +152,56 @@ class MilvusManager(BaseVDB): col.flush() # 内网环境强制落盘以防数据丢失 return res grouped_data = {} - for entry in data_dicts: + for entry in data_dicts: p_name = str(partition_func(entry)) - if p_name not in grouped_data: - grouped_data[p_name] = [] - grouped_data[p_name].append(entry) + if p_name not in grouped_data: + grouped_data[p_name] = [] + grouped_data[p_name].append(entry) for p_name, p_data in grouped_data.items(): - # 自动维护分区 - if p_name != "_default" and not col.has_partition(p_name): - col.create_partition(p_name) - print(f"🚩 自动创建新分区: {p_name}") - - # 格式化数据 - formatted = self._validate_and_format(col, p_data) - - # 执行 Upsert - col.upsert(formatted, partition_name=p_name) + # 自动维护分区 + if p_name != "_default" and not col.has_partition(p_name): + col.create_partition(p_name) + print(f"🚩 自动创建新分区: {p_name}") + + # 格式化数据 + formatted = self._validate_and_format(col, p_data) + + # 执行 Upsert + col.upsert(formatted, partition_name=p_name) col.flush() - def delete(self, collection_name, pks, partition_name=None): - """删除指定主键记录""" - col = Collection(collection_name, using=self.alias) - expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}" - col.delete(expr, partition_name=partition_name) - print(f"✂️ 已删除主键为 {pks} 的记录。") + def delete(self, collection_name, pks, partition_name=None): + """删除指定主键记录""" + col = Collection(collection_name, using=self.alias) + expr = f"id in {pks}" if isinstance(pks, list) else f"id == {pks}" + col.delete(expr, partition_name=partition_name) + print(f"✂️ 已删除主键为 {pks} 的记录。") - # --- 高级组合检索 --- - def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None): - """ - 组合查询接口:支持向量近似搜索 + 标量过滤 + 分页 - :param vector: 目标向量,若为 None 则退化为纯标量查询 - :param expr: 过滤条件,如 "score > 0.8 and lang == 'cmn'" - :param offset: 分页偏移量 - """ + # --- 高级组合检索 --- + def query(self, collection_name, vector=None, expr=None, pagerows=80, page=1, output_fields=None): + """ + 组合查询接口:支持向量近似搜索 + 标量过滤 + 分页 + :param vector: 目标向量,若为 None 则退化为纯标量查询 + :param expr: 过滤条件,如 "score > 0.8 and lang == 'cmn'" + :param offset: 分页偏移量 + """ offset = (page - 1) * pagerows limit = pagerows - col = Collection(collection_name, using=self.alias) - # 确保数据已加载 - if utility.get_query_segment_info(collection_name, using=self.alias) == []: - self.load_segmented(collection_name) + col = Collection(collection_name, using=self.alias) + # 确保数据已加载 + if utility.get_query_segment_info(collection_name, using=self.alias) == []: + self.load_segmented(collection_name) - if output_fields is None: - output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] + if output_fields is None: + output_fields = [f.name for f in col.schema.fields if f.dtype != DataType.FLOAT_VECTOR] - if vector is not None: - # 自动寻找向量字段名 - vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) - search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset} - - return { + if vector is not None: + # 自动寻找向量字段名 + vec_field = next(f.name for f in col.schema.fields if f.dtype == DataType.FLOAT_VECTOR) + search_params = {"metric_type": "L2", "params": {"ef": 64}, "offset": offset} + + return { "total": -1, "page": page, "pagerows": pagerows, @@ -215,9 +215,9 @@ class MilvusManager(BaseVDB): output_fields=output_fields ) } - else: - # 纯标量查询 - return { + else: + # 纯标量查询 + return { "total": -1, "page": page, "pagerows": pagerows,