diff --git a/llmengine/base_m2m.py b/llmengine/base_m2m.py index a1f5654..c4f5ff8 100644 --- a/llmengine/base_m2m.py +++ b/llmengine/base_m2m.py @@ -22,5 +22,5 @@ class BaseM2M: def __init__(self, model_id, **kw): self.model_id = model_id - def m2m(self, texts: str, src_lang: str, tgt_lang: str) -> str: + def translate(self, texts: str, src_lang: str, tgt_lang: str) -> str: raise NotImplementedError \ No newline at end of file diff --git a/llmengine/fanyi_m2m.py b/llmengine/fanyi_m2m.py index 3873b09..0ce80e3 100644 --- a/llmengine/fanyi_m2m.py +++ b/llmengine/fanyi_m2m.py @@ -15,7 +15,7 @@ class M2M100Translator(BaseM2M): self.model.eval() self.model_name = model_id.split('/')[-1] - def m2m(self, text: str, src_lang: str, tgt_lang: str) -> str: + def translate(self, text: str, src_lang: str, tgt_lang: str) -> str: """翻译一段话""" self.tokenizer.src_lang = src_lang encoded = self.tokenizer( diff --git a/llmengine/m2m.py b/llmengine/m2m.py index 157105d..d776066 100644 --- a/llmengine/m2m.py +++ b/llmengine/m2m.py @@ -13,7 +13,7 @@ from .base_m2m import get_llm_class helptext = """M2M100 翻译 API: -POST /v1/m2m +POST /v1/translate Headers: Content-Type: application/json @@ -39,7 +39,7 @@ Response: def init(): rf = RegisterFunction() - rf.register('m2m', m2m) + rf.register('translate', translate) rf.register('docs', docs) @@ -47,7 +47,7 @@ async def docs(request, params_kw, *params, **kw): return helptext -async def m2m(request, params_kw, *params, **kw): +async def translate(request, params_kw, *params, **kw): debug(f'{params_kw=}') se = ServerEnv() engine = se.engine @@ -60,7 +60,7 @@ async def m2m(request, params_kw, *params, **kw): if not text or not isinstance(text, str): raise Exception("`text` must be a non-empty string") - f = awaitify(engine.m2m) + f = awaitify(engine.translate) translation = await f(text, src_lang, tgt_lang) ret = { @@ -88,6 +88,19 @@ def main(): se = ServerEnv() se.engine = Klass(args.model_path) + if torch.cuda.is_available(): + logical_id = 
torch.cuda.current_device() + physical_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]) + gpu_name = torch.cuda.get_device_name(logical_id) + mem_used = torch.cuda.memory_allocated(logical_id) / 1024 ** 3 + mem_total = torch.cuda.get_device_properties(logical_id).total_memory / 1024 ** 3 + + debug(f"\n我正在使用 物理 GPU {physical_id} → 逻辑 GPU {logical_id}") + debug(f"显卡型号: {gpu_name}") + debug(f"显存占用: {mem_used:.1f} GB / {mem_total:.1f} GB\n") + else: + debug("\n我在 CPU 上跑\n") + debug(f"Starting M2M100 service on port {args.port}") webserver(init, args.workdir, args.port) diff --git a/llmengine/milvus_db.py b/llmengine/milvus_db.py index 6ab0436..04c132d 100755 --- a/llmengine/milvus_db.py +++ b/llmengine/milvus_db.py @@ -9,6 +9,7 @@ from datetime import datetime from llmengine.base_db import connection_register, BaseDBConnection import time import traceback +import hashlib class MilvusDBConnection(BaseDBConnection): _instance = None @@ -62,6 +63,7 @@ class MilvusDBConnection(BaseDBConnection): if not params: params = {} db_type = params.get("db_type", "") + fields = params.get("fields", None) # 获取 fields 参数 collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if db_type and "_" in db_type: return {"status": "error", "message": "db_type 不能包含下划线", "collection_name": collection_name, @@ -72,10 +74,8 @@ class MilvusDBConnection(BaseDBConnection): if action == "initialize": return {"status": "success", "message": f"Milvus 连接已初始化,路径: {self.db_path}"} - elif action == "get_params": - return {"status": "success", "params": {"uri": self.db_path}} elif action == "create_collection": - return await self._create_collection(db_type) + return await self._create_collection(db_type, fields) elif action == "delete_collection": return await self._delete_collection(db_type) elif action == "insert_document": @@ -86,14 +86,13 @@ class MilvusDBConnection(BaseDBConnection): file_path = params.get("file_path", "") knowledge_base_id = 
params.get("knowledge_base_id", "") document_id = params.get("document_id", "") - db_type = params.get("db_type", "") - if not userid or not file_path or not knowledge_base_id or not document_id: - return {"status": "error", "message": "userid、file_path document_id和 knowledge_base_id 不能为空1", + if not userid or not knowledge_base_id or not document_id: + return {"status": "error", "message": "userid、document_id 和 knowledge_base_id 不能为空", "collection_name": collection_name, "document_id": "", "status_code": 400} if len(userid) > 100 or len(file_path) > 255 or len(knowledge_base_id) > 100: return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制", "collection_name": collection_name, "document_id": "", "status_code": 400} - return await self._delete_document(userid, file_path, knowledge_base_id, document_id, db_type) + return await self._delete_document(userid, knowledge_base_id, document_id, file_path, db_type) elif action == "delete_knowledge_base": userid = params.get("userid", "") knowledge_base_id = params.get("knowledge_base_id", "") @@ -110,9 +109,8 @@ class MilvusDBConnection(BaseDBConnection): knowledge_base_ids = params.get("knowledge_base_ids", []) limit = params.get("limit", 5) offset = params.get("offset", 0) - db_type = params.get("db_type", "") - if not query_vector or not userid or not knowledge_base_ids: - return {"status": "error", "message": "query_vector、userid 或 knowledge_base_ids 不能为空", + if not query_vector or not userid: + return {"status": "error", "message": "query_vector或userid不能为空", "collection_name": collection_name, "document_id": "", "status_code": 400} if limit < 1 or limit > 16384: return {"status": "error", "message": "limit 必须在 1 到 16384 之间", @@ -139,9 +137,10 @@ class MilvusDBConnection(BaseDBConnection): "status_code": 400 } - async def _create_collection(self, db_type: str = "") -> Dict: - """创建 Milvus 集合""" + async def _create_collection(self, db_type: str = "", fields: List[Any] = None) -> Dict: + """创建 
Milvus 集合,支持动态 fields(JSON 或 FieldSchema)""" try: + # 确定集合名称 collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if len(collection_name) > 255: raise ValueError(f"集合名称 {collection_name} 超过 255 个字符") @@ -149,90 +148,151 @@ class MilvusDBConnection(BaseDBConnection): raise ValueError("db_type 不能包含下划线") if db_type and len(db_type) > 100: raise ValueError("db_type 的长度应小于 100") - debug(f"集合名称: {collection_name}") - - fields = [ - FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True), - FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100), - FieldSchema(name="knowledge_base_id", dtype=DataType.VARCHAR, max_length=100), - FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024), - FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255), - FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024), - FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64), - FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64), - ] - schema = CollectionSchema( - fields=fields, - description="统一数据集合,包含用户ID、知识库ID、document_id 和元数据字段", - auto_id=True, - primary_field="pk", - ) + debug(f"初始集合名称: {collection_name}") + # 检查集合是否已存在 if utility.has_collection(collection_name): - try: - collection = Collection(collection_name) - existing_schema = collection.schema - expected_fields = {f.name for f in fields} - actual_fields = {f.name for f in existing_schema.fields} - vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None) - - schema_compatible = False - if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR: - dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None - schema_compatible = 
dim == 1024 - debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, " - f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, " - f"dim={dim if dim is not None else '未定义'}") - if not schema_compatible: - debug(f"集合 {collection_name} 的 schema 不兼容,原因: " - f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, " - f"vector_field: {vector_field is not None}, " - f"dtype: {vector_field.dtype if vector_field else '无'}, " - f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}") - utility.drop_collection(collection_name) - else: - collection.load() - debug(f"集合 {collection_name} 已存在并加载成功") - return { - "status": "success", - "collection_name": collection_name, - "message": f"集合 {collection_name} 已存在" - } - except Exception as e: - error(f"加载集合 {collection_name} 失败: {str(e)}") - return { - "status": "error", - "collection_name": collection_name, - "message": str(e) - } - - try: - collection = Collection(collection_name, schema) - collection.create_index( - field_name="vector", - index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"} - ) - for field in ["userid", "knowledge_base_id", "document_id", "filename", "file_path", "upload_time", "file_type"]: - collection.create_index( - field_name=field, - index_params={"index_type": "INVERTED"} - ) + collection = Collection(collection_name) collection.load() - debug(f"成功创建并加载集合: {collection_name}") + debug(f"集合 {collection_name} 已存在,直接加载") return { "status": "success", "collection_name": collection_name, - "message": f"集合 {collection_name} 创建成功" - } - except Exception as e: - error(f"创建集合 {collection_name} 失败: {str(e)}") - return { - "status": "error", - "collection_name": collection_name, - "message": str(e) + "message": f"集合 {collection_name} 已存在,直接使用现有 schema" } + + # 如果集合不存在,处理 fields + if db_type == "": + # db_type 为空,使用默认 fields + fields = [ + 
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True), + FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100), + FieldSchema(name="knowledge_base_id", dtype=DataType.VARCHAR, max_length=100), + FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024), + FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255), + FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024), + FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64), + FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64), + ] + debug("db_type 为空,使用默认 fields") + else: + # db_type 不为空,集合不存在时必须提供 fields + if fields is None: + debug(f"集合 {collection_name} 不存在,且未提供 fields") + raise ValueError(f"集合 {collection_name} 不存在,必须提供 fields 参数") + + # 定义 dtype 映射 + dtype_map = { + "VARCHAR": DataType.VARCHAR, + "FLOAT_VECTOR": DataType.FLOAT_VECTOR, + "BINARY_VECTOR": DataType.BINARY_VECTOR, + "INT64": DataType.INT64, + "INT32": DataType.INT32, + "FLOAT": DataType.FLOAT, + "DOUBLE": DataType.DOUBLE, + "BOOL": DataType.BOOL, + "JSON": DataType.JSON + } + + # 如果 fields 是 JSON 格式,转换为 FieldSchema 对象 + if all(isinstance(f, dict) for f in fields): + try: + converted_fields = [] + for f in fields: + if f["dtype"] not in dtype_map: + debug(f"无效的 dtype: {f['dtype']}") + raise ValueError(f"无效的 dtype: {f['dtype']}") + dtype = dtype_map[f["dtype"]] + params = { + "name": f["name"], + "dtype": dtype, + "is_primary": f.get("is_primary", False), + "auto_id": f.get("auto_id", False) + } + if dtype == DataType.VARCHAR: + if "max_length" not in f or not isinstance(f["max_length"], int) or f["max_length"] <= 0: + debug(f"VARCHAR 字段 {f['name']} 必须指定有效的 max_length") + raise ValueError(f"VARCHAR 字段 {f['name']} 必须指定有效的 max_length") + params["max_length"] = f["max_length"] + if dtype 
in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: + if "dim" not in f or not isinstance(f["dim"], int) or f["dim"] <= 0: + debug(f"vector 字段 {f['name']} 必须指定有效的 dim") + raise ValueError(f"vector 字段 {f['name']} 必须指定有效的 dim") + params["dim"] = f["dim"] + converted_fields.append(FieldSchema(**params)) + fields = converted_fields + debug("成功将 JSON fields 转换为 FieldSchema 对象") + except KeyError as e: + debug(f"fields 中的 JSON 格式缺少必要字段: {str(e)}") + raise ValueError(f"fields 中的 JSON 格式缺少必要字段: {str(e)}") + + # 验证 fields + if not isinstance(fields, list) or not all(isinstance(f, FieldSchema) for f in fields): + raise ValueError("fields 必须是 FieldSchema 对象的列表") + + # 检查 primary key + primary_keys = [f for f in fields if f.is_primary] + if len(primary_keys) != 1: + raise ValueError("fields 必须正好包含一个 primary key 字段") + primary_key = primary_keys[0] + if primary_key.dtype not in [DataType.INT64, DataType.VARCHAR]: + raise ValueError("primary key 字段类型必须是 INT64 或 VARCHAR") + if primary_key.dtype == DataType.VARCHAR and not hasattr(primary_key, 'max_length'): + raise ValueError("VARCHAR primary key 必须指定 max_length") + + # 检查 vector 字段 + vector_fields = [f for f in fields if f.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]] + if not vector_fields: + raise ValueError("fields 必须包含至少一个 vector 字段") + for vf in vector_fields: + if 'dim' not in vf.params or not isinstance(vf.params['dim'], int) or vf.params['dim'] <= 0: + raise ValueError(f"vector 字段 {vf.name} 必须指定有效的 dim 参数") + + # 检查字段名称唯一性 + field_names = [f.name for f in fields] + if len(field_names) != len(set(field_names)): + raise ValueError("fields 中的 name 必须唯一") + + # 检查其他字段约束 + for f in fields: + if f.dtype == DataType.VARCHAR and not hasattr(f, 'max_length'): + raise ValueError(f"VARCHAR 字段 {f.name} 必须指定 max_length") + if f.name == "" or len(f.name) > 255: + raise ValueError(f"字段名称 {f.name} 无效(不能为空或超过 255 字符)") + + debug(f"fields 验证通过,字段数量: {len(fields)}") + + # 创建集合 + schema = CollectionSchema( + fields=fields, + 
description="统一数据集合,包含用户ID、知识库ID、document_id 和元数据字段", + auto_id=any(f.auto_id for f in primary_keys), + primary_field=primary_key.name, + ) + + collection = Collection(collection_name, schema) + # 创建 vector 索引 + for f in vector_fields: + collection.create_index( + field_name=f.name, + index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"} + ) + # 创建 scalar 索引(可选) + for f in fields: + if f.dtype not in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, DataType.JSON]: + collection.create_index( + field_name=f.name, + index_params={"index_type": "INVERTED"} + ) + collection.load() + debug(f"成功创建并加载集合: {collection_name}") + return { + "status": "success", + "collection_name": collection_name, + "message": f"集合 {collection_name} 创建成功" + } except Exception as e: error(f"创建集合失败: {str(e)}") return { @@ -285,14 +345,18 @@ class MilvusDBConnection(BaseDBConnection): } async def _insert_document(self, chunks: List[Dict], db_type: str = "") -> Dict[str, Any]: - """插入文档到 Milvus""" + """插入文档到 Milvus,支持动态 schema""" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" - document_id = chunks[0]["document_id"] if chunks else "" + document_id = chunks[0].get("pk", chunks[0].get("document_id", "")) if chunks else "" try: - # 检查集合是否存在 + debug(f"_insert_document called with db_type={db_type}, chunks_count={len(chunks) if chunks else 0}") + + # 检查集合是否存在 - 不传递 fields,依赖 _create_collection 的逻辑 create_result = await self._create_collection(db_type) if create_result["status"] == "error": raise RuntimeError(f"集合创建失败: {create_result['message']}") + collection_name = create_result["collection_name"] + debug(f"Collection ready: {collection_name}") # 检查输入数据 if not chunks: @@ -300,87 +364,55 @@ class MilvusDBConnection(BaseDBConnection): for chunk in chunks: if not isinstance(chunk, dict): raise ValueError("每个 chunk 必须是一个字典") - required_fields = ["text", "vector", "document_id", "filename", "file_path", "upload_time", "file_type", - "userid", "knowledge_base_id"] - if not 
all(k in chunk for k in required_fields): - raise ValueError(f"chunk 缺少必要字段: {', '.join(set(required_fields) - set(chunk.keys()))}") - if not isinstance(chunk["vector"], list) or len(chunk["vector"]) != 1024: - raise ValueError("vector 必须是长度为 1024 的浮点数列表") - # 验证 userid 和 knowledge_base_id 一致性 - if len(set(chunk["userid"] for chunk in chunks)) > 1: - raise ValueError("所有 chunk 的 userid 必须一致") - if len(set(chunk["knowledge_base_id"] for chunk in chunks)) > 1: - raise ValueError("所有 chunk 的 knowledge_base_id 必须一致") - if len(set(chunk["filename"] for chunk in chunks)) > 1: - raise ValueError("所有 chunk 的 filename 必须一致") + # # 字段映射 + # for chunk in chunks: + # if 'orgid' in chunk and 'userid' not in chunk: + # chunk['userid'] = chunk['orgid'] + # if 'fiid' in chunk and 'knowledge_base_id' not in chunk: + # chunk['knowledge_base_id'] = chunk['fiid'] + # if 'embedding' in chunk and 'vector' not in chunk: + # chunk['vector'] = chunk['embedding'] + # if 'pk' not in chunk: + # chunk['pk'] = chunk.get('document_id', str(uuid.uuid4())) + # debug(f"Sample chunk after mapping: {chunks[0] if chunks else {} }") - # 检查是否已存在相同的 userid、knowledge_base_id 和 filename + # 加载集合并获取 schema collection = Collection(collection_name) collection.load() - # expr = f"userid == '{chunks[0]['userid']}' and knowledge_base_id == '{chunks[0]['knowledge_base_id']}' and filename == '{chunks[0]['filename']}'" - # debug(f"检查重复文档: {expr}") - # results = collection.query(expr=expr, output_fields=["document_id"], limit=1) - # if results: - # debug( - # f"找到重复文档: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}") - # return { - # "status": "error", - # "document_id": document_id, - # "collection_name": collection_name, - # "message": f"文档已存在: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}", - # "status_code": 400 - # } + schema_fields = [field.name for field in 
collection.schema.fields if not field.auto_id] + vector_fields = [field.name for field in collection.schema.fields if + field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]] + debug(f"Schema fields: {schema_fields}, Vector fields: {vector_fields}") - # 提取数据 - userids = [chunk["userid"] for chunk in chunks] - knowledge_base_ids = [chunk["knowledge_base_id"] for chunk in chunks] - texts = [chunk["text"] for chunk in chunks] - embeddings = [chunk["vector"] for chunk in chunks] - document_ids = [chunk["document_id"] for chunk in chunks] - filenames = [chunk["filename"] for chunk in chunks] - file_paths = [chunk["file_path"] for chunk in chunks] - upload_times = [chunk["upload_time"] for chunk in chunks] - file_types = [chunk["file_type"] for chunk in chunks] + # 验证 chunks 是否包含所有必要字段 + for chunk in chunks: + missing_fields = [f for f in schema_fields if f not in chunk] + if missing_fields: + raise ValueError(f"chunk 缺少必要字段: {', '.join(missing_fields)}") + for vector_field in vector_fields: + vector_dim = next( + field.params["dim"] for field in collection.schema.fields if field.name == vector_field) + if not isinstance(chunk[vector_field], list) or len(chunk[vector_field]) != vector_dim: + raise ValueError(f"{vector_field} 必须是长度为 {vector_dim} 的浮点数列表") # 构造插入数据 - data = { - "userid": userids, - "knowledge_base_id": knowledge_base_ids, - "document_id": document_ids, - "text": texts, - "vector": embeddings, - "filename": filenames, - "file_path": file_paths, - "upload_time": upload_times, - "file_type": file_types, - } + data = {field: [] for field in schema_fields} + for chunk in chunks: + for field in schema_fields: + data[field].append(chunk[field]) - schema_fields = [field.name for field in collection.schema.fields if field.name != "pk"] - debug(f"Schema fields: {schema_fields}") - debug(f"Data keys: {list(data.keys())}") - if list(data.keys()) != schema_fields: - raise ValueError(f"数据字段顺序不匹配,期望: {schema_fields}, 实际: {list(data.keys())}") - - 
collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"]) + # 插入数据 + collection.insert([data[field] for field in schema_fields]) collection.flush() - debug(f"成功插入 {len(texts)} 个文档到集合 {collection_name}") + debug(f"成功插入 {len(chunks)} 个文档到集合 {collection_name}") return { "status": "success", "document_id": document_id, "collection_name": collection_name, - "message": f"成功插入 {len(texts)} 个文档到 {collection_name}", + "message": f"成功插入 {len(chunks)} 个文档到 {collection_name}", "status_code": 200 } - # except MilvusException as e: - # error(f"Milvus 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}") - # return { - # "status": "error", - # "document_id": document_id, - # "collection_name": collection_name, - # "message": f"Milvus 插入失败: {str(e)}", - # "status_code": 400 - # } except Exception as e: error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") return { @@ -391,10 +423,10 @@ class MilvusDBConnection(BaseDBConnection): "status_code": 400 } - async def _delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[ + async def _delete_document(self, userid: str, knowledge_base_id: str, document_id:str, file_path: str = "", db_type: str = "") -> Dict[ str, Any]: """删除用户指定文件数据,仅处理 Milvus 记录""" - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}" try: if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") return { @@ -464,8 +496,18 @@ class MilvusDBConnection(BaseDBConnection): async def _delete_knowledge_base(self, db_type: str, userid: str, knowledge_base_id: str) -> Dict[str, Any]: """删除用户的整个知识库,仅处理 Milvus 记录""" - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}" try: + if not userid: + raise ValueError("userid 不能为空") + if not knowledge_base_id: + raise
ValueError("knowledge_base_id 不能为空") + if len(userid) > 100 or len(knowledge_base_id) > 100 or (db_type and len(db_type) > 100): + raise ValueError("userid, knowledge_base_id 或 db_type 的长度超出限制") + + debug( + f"删除知识库参数: userid={userid}, knowledge_base_id={knowledge_base_id}, db_type={db_type}, collection_name={collection_name}") + if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") return { @@ -478,6 +520,7 @@ class MilvusDBConnection(BaseDBConnection): try: collection = Collection(collection_name) + collection.load() debug(f"加载集合: {collection_name}") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}") @@ -489,29 +532,29 @@ class MilvusDBConnection(BaseDBConnection): "status_code": 400 } + # 检查集合 schema + schema_fields = [field.name for field in collection.schema.fields] + debug(f"集合 {collection_name} schema: {schema_fields}") + deleted_files = [] - try: - expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'" - debug(f"查询表达式: {expr}") - results = collection.query( - expr=expr, - output_fields=["file_path"], - limit=1000 - ) - if results: - deleted_files = list(set(result["file_path"] for result in results if "file_path" in result)) - debug(f"找到 {len(deleted_files)} 个唯一文件: {deleted_files}") - else: - debug(f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录") - except Exception as e: - error(f"查询 file_path 失败: {str(e)}") - return { - "status": "error", - "collection_name": collection_name, - "deleted_files": [], - "message": f"查询 file_path 失败: {str(e)}", - "status_code": 400 - } + if collection_name == "ragdb" and "file_path" in schema_fields: + try: + expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'" + debug(f"查询表达式: {expr}") + results = collection.query( + expr=expr, + output_fields=["file_path"], + limit=1000 + ) + if results: + deleted_files = list(set(result["file_path"] for result in results if result["file_path"])) + debug(f"找到 {len(deleted_files)} 
个唯一文件: {deleted_files}") + else: + debug(f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录") + except Exception as e: + error(f"查询 file_path 失败: {str(e)}") + # 继续执行删除操作,不因查询失败而中断 + deleted_files = [] total_deleted = 0 try: @@ -561,18 +604,16 @@ class MilvusDBConnection(BaseDBConnection): } async def _search_query(self, query_vector: List[float], userid: str, - knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]: - """基于向量搜索 Milvus 集合""" - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]: + """基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)""" timing_stats = {} start_time = time.time() try: + # ---------- 参数校验 ---------- if not query_vector or not isinstance(query_vector, list) or len(query_vector) != 1024: raise ValueError("query_vector 必须是长度为 1024 的浮点数列表") if not userid: raise ValueError("userid 不能为空") - if not knowledge_base_ids: - raise ValueError("knowledge_base_ids 不能为空") if len(userid) > 100: raise ValueError("userid 的长度超出限制") if limit <= 0 or limit > 16384: @@ -581,80 +622,117 @@ class MilvusDBConnection(BaseDBConnection): raise ValueError("offset 不能为负数") if limit + offset > 16384: raise ValueError("limit + offset 不能超过 16384") - for kb_id in knowledge_base_ids: - if not isinstance(kb_id, str): - raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}") - if len(kb_id) > 100: - raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}") + if knowledge_base_ids is not None: + for kb_id in knowledge_base_ids: + if not isinstance(kb_id, str) or len(kb_id) > 100: + raise ValueError(f"knowledge_base_id 必须是字符串且长度 ≤ 100: {kb_id}") + + # ---------- 集合名称 ---------- + collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}" + debug(f"目标集合: {collection_name}") if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") return {"results": 
[], "timing": timing_stats} + # ---------- 加载集合 ---------- try: collection = Collection(collection_name) collection.load() - debug(f"加载集合: {collection_name}") timing_stats["collection_load"] = time.time() - start_time - debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}") return {"results": [], "timing": timing_stats} + # ---------- 动态 schema ---------- + schema_fields = [f.name for f in collection.schema.fields] + is_ragdb = collection_name == "ragdb" + output_fields = ( + ["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type", + "knowledge_base_id"] + if is_ragdb else + ["document_id", "userid", "knowledge_base_id", "text"] + ) + output_fields = [f for f in output_fields if f in schema_fields] + + # ---------- 动态构建表达式 ---------- + if is_ragdb: + expr_parts = [f"userid == '{userid}'"] + if knowledge_base_ids: + kb_expr = " or ".join([f"knowledge_base_id == '{kb}'" for kb in knowledge_base_ids]) + expr_parts.append(f"({kb_expr})") + expr = " and ".join(expr_parts) + else: + expr_parts = [f"userid == '{userid}'"] + if knowledge_base_ids: + kb_expr = " or ".join([f"knowledge_base_id == '{kb}'" for kb in knowledge_base_ids]) + expr_parts.append(f"({kb_expr})") + expr = " and ".join(expr_parts) + + debug(f"搜索表达式: {expr}") + + # ---------- 保留你指定的 search 调用 ---------- search_start = time.time() search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) - expr = f"userid == '{userid}' and ({kb_id_expr})" - debug(f"搜索表达式: {expr}") try: results = collection.search( data=[query_vector], anns_field="vector", param=search_params, - limit=100, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", - "file_type"], - offset=offset + limit=100, # 保留 + expr=expr, # 动态 + output_fields=output_fields, # 动态 + offset=offset # 保留 ) 
except Exception as e: error(f"搜索失败: {str(e)}") return {"results": [], "timing": timing_stats} - timing_stats["vector_search"] = time.time() - search_start - debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") - search_results = [] + timing_stats["vector_search"] = time.time() - search_start + + # ---------- 统一结果结构 + 去重 ---------- + raw_results = [] for hits in results: for hit in hits: - metadata = { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "source": "vector_query", - "metadata": metadata - } - search_results.append(result) - debug( - f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}") + if is_ragdb: + metadata = { + "userid": hit.entity.get("userid"), + "document_id": hit.entity.get("document_id"), + "filename": hit.entity.get("filename"), + "file_path": hit.entity.get("file_path"), + "upload_time": hit.entity.get("upload_time"), + "file_type": hit.entity.get("file_type"), + "knowledge_base_id": hit.entity.get("knowledge_base_id") + } + raw_results.append({ + "text": hit.entity.get("text"), + "distance": 1.0 - hit.distance, + "metadata": metadata + }) + else: + metadata = { + "document_id": hit.entity.get("document_id"), + "userid": hit.entity.get("userid"), + "knowledge_base_id": hit.entity.get("knowledge_base_id"), + } + raw_results.append({ + "text": hit.entity.get("text"), + "distance": 1.0 - hit.distance, + "metadata": metadata + }) + # ---------- 去重+ limit ---------- dedup_start = time.time() unique_results = [] seen_texts = set() - for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): + for result in sorted(raw_results, key=lambda x: x['distance'], reverse=True): if result['text'] not 
in seen_texts: unique_results.append(result) seen_texts.add(result['text']) timing_stats["deduplication"] = time.time() - dedup_start debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒") - info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") + info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(raw_results)})") timing_stats["total_time"] = time.time() - start_time info(f"向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") @@ -664,163 +742,316 @@ class MilvusDBConnection(BaseDBConnection): error(f"向量搜索失败: {str(e)}") return {"results": [], "timing": timing_stats} - async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]: - """列出用户的所有知识库及其文件,按 knowledge_base_id 分组""" - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + # async def _search_query(self, query_vector: List[float], userid: str, + # knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]: + # """基于向量搜索 Milvus 集合""" + # collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + # timing_stats = {} + # start_time = time.time() + # try: + # if not query_vector or not isinstance(query_vector, list) or len(query_vector) != 1024: + # raise ValueError("query_vector 必须是长度为 1024 的浮点数列表") + # if not userid: + # raise ValueError("userid 不能为空") + # if not knowledge_base_ids: + # raise ValueError("knowledge_base_ids 不能为空") + # if len(userid) > 100: + # raise ValueError("userid 的长度超出限制") + # if limit <= 0 or limit > 16384: + # raise ValueError("limit 必须在 1 到 16384 之间") + # if offset < 0: + # raise ValueError("offset 不能为负数") + # if limit + offset > 16384: + # raise ValueError("limit + offset 不能超过 16384") + # for kb_id in knowledge_base_ids: + # if not isinstance(kb_id, str): + # raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}") + # if len(kb_id) > 100: + # raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}") + # + # if not utility.has_collection(collection_name): + # 
debug(f"集合 {collection_name} 不存在") + # return {"results": [], "timing": timing_stats} + # + # try: + # collection = Collection(collection_name) + # collection.load() + # debug(f"加载集合: {collection_name}") + # timing_stats["collection_load"] = time.time() - start_time + # debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒") + # except Exception as e: + # error(f"加载集合 {collection_name} 失败: {str(e)}") + # return {"results": [], "timing": timing_stats} + # + # search_start = time.time() + # search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} + # kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) + # expr = f"userid == '{userid}' and ({kb_id_expr})" + # debug(f"搜索表达式: {expr}") + # + # try: + # results = collection.search( + # data=[query_vector], + # anns_field="vector", + # param=search_params, + # limit=100, + # expr=expr, + # output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", + # "file_type"], + # offset=offset + # ) + # except Exception as e: + # error(f"搜索失败: {str(e)}") + # return {"results": [], "timing": timing_stats} + # timing_stats["vector_search"] = time.time() - search_start + # debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") + # + # search_results = [] + # for hits in results: + # for hit in hits: + # metadata = { + # "userid": hit.entity.get("userid"), + # "document_id": hit.entity.get("document_id"), + # "filename": hit.entity.get("filename"), + # "file_path": hit.entity.get("file_path"), + # "upload_time": hit.entity.get("upload_time"), + # "file_type": hit.entity.get("file_type") + # } + # result = { + # "text": hit.entity.get("text"), + # "distance": hit.distance, + # "source": "vector_query", + # "metadata": metadata + # } + # search_results.append(result) + # debug( + # f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}") + # + # dedup_start = time.time() + # unique_results = [] + # seen_texts = set() 
+ # for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): + # if result['text'] not in seen_texts: + # unique_results.append(result) + # seen_texts.add(result['text']) + # timing_stats["deduplication"] = time.time() - dedup_start + # debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒") + # info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") + # + # timing_stats["total_time"] = time.time() - start_time + # info(f"向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + # return {"results": unique_results[:limit], "timing": timing_stats} + # + # except Exception as e: + # error(f"向量搜索失败: {str(e)}") + # return {"results": [], "timing": timing_stats} + + async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, Any]: + """列出用户的所有知识库及其文件,按 knowledge_base_id 分组(仅 ragdb)""" try: - info(f"列出用户文件: userid={userid}, db_type={db_type}") + debug(f"列出用户文件: userid={userid}, db_type={db_type}") if not userid: raise ValueError("userid 不能为空") if (db_type and len(db_type) > 100) or len(userid) > 100: raise ValueError("userid 或 db_type 的长度超出限制") - if not utility.has_collection(collection_name): - debug(f"集合 {collection_name} 不存在") - return {} - - try: - collection = Collection(collection_name) - collection.load() - debug(f"加载集合: {collection_name}") - except Exception as e: - error(f"加载集合 {collection_name} 失败: {str(e)}") - return {} - - expr = f"userid == '{userid}'" - debug(f"查询表达式: {expr}") - - try: - results = collection.query( - expr=expr, - output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"], - limit=1000 - ) - except Exception as e: - error(f"查询用户文件失败: {str(e)}") - return {} - - files_by_kb = {} - seen_document_ids = set() - for result in results: - document_id = result.get("document_id") - kb_id = result.get("knowledge_base_id") - if document_id not in seen_document_ids: - seen_document_ids.add(document_id) - file_info = { - "document_id": 
document_id, - "filename": result.get("filename"), - "file_path": result.get("file_path"), - "upload_time": result.get("upload_time"), - "file_type": result.get("file_type"), - "knowledge_base_id": kb_id - } - if kb_id not in files_by_kb: - files_by_kb[kb_id] = [] - files_by_kb[kb_id].append(file_info) - debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}") - - info(f"找到 {len(seen_document_ids)} 个文件,userid={userid}, 知识库数量={len(files_by_kb)}") - return files_by_kb - - except Exception as e: - error(f"列出用户文件失败: {str(e)}") - return {} - - async def _list_all_knowledge_bases(self, db_type: str = "") -> Dict[str, Any]: - """列出数据库中所有用户的知识库及其文件,按用户分组""" - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" - try: - info(f"列出所有用户的知识库: db_type={db_type}") - - if db_type and "_" in db_type: - raise ValueError("db_type 不能包含下划线") - if db_type and len(db_type) > 100: - raise ValueError("db_type 的长度应小于 100") + collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}" + debug(f"目标集合: {collection_name}") if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") return { "status": "success", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"集合 {collection_name} 不存在", - "status_code": 200 + "files_by_knowledge_base": {} if collection_name == "ragdb" else [], + "collection_name": collection_name } try: collection = Collection(collection_name) collection.load() - debug(f"加载集合: {collection_name}") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}") - return { - "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"加载集合失败: {str(e)}", - "status_code": 400 - } + return {"status": "error", "message": str(e)} - expr = "userid != ''" - debug(f"查询表达式: {expr}") - try: - results = collection.query( - expr=expr, - output_fields=["userid", "knowledge_base_id", "document_id", 
"filename", "file_path", "upload_time", - "file_type"], - limit=10000 - ) - except Exception as e: - error(f"查询所有用户文件失败: {str(e)}") - return { - "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"查询失败: {str(e)}", - "status_code": 400 - } + schema_fields = [f.name for f in collection.schema.fields] + debug(f"集合 {collection_name} schema: {schema_fields}") + + is_ragdb = collection_name == "ragdb" + output_fields = ( + ["userid", "document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"] + if is_ragdb else + ["document_id", "userid", "knowledge_base_id", "text"] + ) + output_fields = [f for f in output_fields if f in schema_fields] + + # 动态构建表达式 + if "userid" in schema_fields: + expr = f"userid == '{userid}'" + debug(f"使用顶级字段过滤: {expr}") + else: + expr = f'metadata["user_id"] == "{userid}"' + debug(f"使用 JSON 嵌套字段过滤: {expr}") + + results = collection.query(expr=expr, output_fields=output_fields, limit=1000) + + # 正确初始化(不嵌套) + if is_ragdb: + files_by_knowledge_base = {} # {kb_id: [file,...]} + else: + files_by_knowledge_base = [] # [file1, file2, ...] 
- users_knowledge_bases = {} seen_document_ids = set() + for result in results: - userid = result.get("userid") + document_id = result.get("document_id") or result.get("id") + if not document_id or document_id in seen_document_ids: + continue + seen_document_ids.add(document_id) + kb_id = result.get("knowledge_base_id") - document_id = result.get("document_id") - if document_id not in seen_document_ids: - seen_document_ids.add(document_id) + if is_ragdb and not kb_id: + continue + + if is_ragdb: file_info = { + "userid": result.get("userid"), "document_id": document_id, - "filename": result.get("filename"), - "file_path": result.get("file_path"), - "upload_time": result.get("upload_time"), - "file_type": result.get("file_type"), + "filename": result.get("filename", ""), + "file_path": result.get("file_path", ""), + "upload_time": result.get("upload_time", ""), + "file_type": result.get("file_type", ""), "knowledge_base_id": kb_id } - if userid not in users_knowledge_bases: - users_knowledge_bases[userid] = {} - if kb_id not in users_knowledge_bases[userid]: - users_knowledge_bases[userid][kb_id] = [] - users_knowledge_bases[userid][kb_id].append(file_info) - debug( - f"找到文件: userid={userid}, knowledge_base_id={kb_id}, document_id={document_id}, filename={result.get('filename')}") + files_by_knowledge_base.setdefault(kb_id, []).append(file_info) + else: + file_info = { + "document_id": document_id, + "userid": result.get("userid", ""), + "knowledge_base_id": kb_id, + "text": result.get("text", ""), + } + files_by_knowledge_base.append(file_info) - info(f"找到 {len(seen_document_ids)} 个文件,涉及 {len(users_knowledge_bases)} 个用户") + debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}") + + info(f"找到 {len(seen_document_ids)} 个文件") return { "status": "success", - "users_knowledge_bases": users_knowledge_bases, - "collection_name": collection_name, - "message": f"成功列出 {len(users_knowledge_bases)} 个用户的知识库和文件", + "files_by_knowledge_base": 
files_by_knowledge_base, + "collection_name": collection_name + } + + except Exception as e: + error(f"列出用户文件失败: {str(e)}") + return {"status": "error", "message": str(e)} + + async def _list_all_knowledge_bases(self, db_type: str = "") -> Dict[str, Any]: + try: + debug(f"列出所有知识库: db_type={db_type}") + + collection_names = utility.list_collections() + if db_type: + target = "ragdb" if db_type == "ragdb" else f"ragdb_{db_type}" + collection_names = [n for n in collection_names if n == target] + else: + collection_names = [n for n in collection_names if n.startswith("ragdb")] + + if not collection_names: + return { + "status": "success", + "collections": {}, + "collection_names": "none", + "message": "未找到任何以 ragdb 开头的集合", + "status_code": 200 + } + + collections = {} + total_documents = 0 + skipped_collections = [] + + for collection_name in collection_names: + if collection_name.startswith("ragdb_"): + db_part = collection_name[len("ragdb_"):] + if "_" in db_part or len(db_part) > 100: + skipped_collections.append(collection_name) + continue + + try: + collection = Collection(collection_name) + collection.load() + except Exception as e: + error(f"加载集合 {collection_name} 失败: {str(e)}") + skipped_collections.append(collection_name) + continue + + schema_fields = [f.name for f in collection.schema.fields] + debug(f"集合{collection_name}的schme是:{schema_fields}") + is_ragdb = collection_name == "ragdb" + output_fields = ( + ["document_id", "knowledge_base_id", "filename", "text", "userid"] + if is_ragdb else + ["document_id", "userid", "knowledge_base_id", "text"] + ) + output_fields = [f for f in output_fields if f in schema_fields] + + results = collection.query(expr="", output_fields=output_fields, limit=10000) + debug(f"集合{collection_name}的查询到{len(results)}条记录") + for r in results: + debug(f"集合{collection_name}找到文档:{r}") + # 正确初始化 + if is_ragdb: + collections[collection_name] = {"knowledge_bases": {}} + else: + collections[collection_name] = {"knowledge_bases": []} + + 
seen = set() + for result in results: + doc_id = result.get("document_id") or result.get("id") + if not doc_id or doc_id in seen: + continue + seen.add(doc_id) + + kb_id = result.get("knowledge_base_id") + if is_ragdb and not kb_id: + continue + + if is_ragdb: + file_info = { + "document_id": doc_id, + "text": result.get("text", "")[:50], + "userid": result.get("userid", ""), + "filename": result.get("filename", ""), + "knowledge_base_id": kb_id + } + collections[collection_name]["knowledge_bases"].setdefault(kb_id, []).append(file_info) + else: + file_info = { + "document_id": doc_id, + "userid": result.get("userid", ""), + "knowledge_base_id": kb_id, + "text": result.get("text", ""), + } + collections[collection_name]["knowledge_bases"].append(file_info) + + total_documents += 1 + + debug(f"找到 {total_documents} 个文档,跨 {len(collections)} 个集合") + return { + "status": "success", + "collections": collections, + "collection_names": ",".join(collection_names), + "message": f"成功列出 {len(collections)} 个集合,{total_documents} 个文档", "status_code": 200 } except Exception as e: - error(f"列出所有用户知识库失败: {str(e)}") + error(f"列出所有知识库失败: {str(e)}") return { "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"列出所有用户知识库失败: {str(e)}", + "collections": {}, + "collection_names": "none", + "message": f"列出所有知识库失败: {str(e)}", "status_code": 400 } diff --git a/llmengine/mm_embedding.py b/llmengine/mm_embedding.py index 14bf100..0d9cd33 100644 --- a/llmengine/mm_embedding.py +++ b/llmengine/mm_embedding.py @@ -1,275 +1,361 @@ -# embed_all_unified.py -""" -Unified multimodal embedder (text, image, video, audio) -Features: -- All modalities mapped to the same embedding space (CLIP or CLAP) -- GPU/CPU/MPS auto detection -- FP16 autocast for speed -- Batch processing -- Video frame sampling + average pooling -- Audio resampling + CLAP embedding -- L2 normalized output for similarity search - 
-model_name='/data/ymq/models/laion/CLIP-ViT-B-32-laion2B-s34B-b79K' - -impput: - -text: -{ - "type":"text, - "text":"...." -} -image: -""" - - import os -from pathlib import Path +import argparse import numpy as np import torch from PIL import Image import av import librosa -from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from math import ceil -from appPublic.jsonConfig import getConfig -from appPublic.worker import awaitify -from ahserver.webapp import webapp -from ahserver.serverenv import ServerEnv +from traceback import format_exc +from appPublic.jsonConfig import getConfig +from appPublic.registerfunction import RegisterFunction +from appPublic.log import debug, exception + +from ahserver.serverenv import ServerEnv +from ahserver.webapp import webserver + +from transformers import CLIPProcessor, CLIPModel +from sklearn.preprocessing import normalize +from sklearn.cluster import DBSCAN +import base64 +import traceback +import uuid + +helptext = """CLIP 多模态统一嵌入服务 +API 地址: +POST http://localhost:8883/v1/embed +功能: +将 文本 / 图片 / 视频 / 音频 统一转为 1024 维 L2 归一化向量 +支持人脸检测 + 去重(基于 DBSCAN 余弦距离) + +输入格式 (JSON): +{ + "inputs": [ + "文本字符串", + "/path/to/image.jpg", + "/path/to/video.mp4", + "/path/to/audio.wav" + ] +} + +输出格式: +{ + "data": { ... 
}, + "object": "embedding.result", + "model": "CLIP-ViT-H-14-laion2B-s32B-b79K" +} + +特性: +- 自动识别文件类型(文本/图像/音频/视频) +- 视频抽帧 + 人脸去重 +- 文件不存在 → 自动降级为文本嵌入 +- 所有向量 L2 归一化,可直接余弦相似度 +- 人脸向量支持聚类去重(eps=0.4) + +文档查看: +GET http://localhost:8883/docs +""" try: import face_recognition FACE_LIB_AVAILABLE = True except Exception: + debug('人脸识别库导入失败') FACE_LIB_AVAILABLE = False -# ------------------- Configuration ------------------- +# ------------------- 配置 ------------------- DEVICE = "cuda" if torch.cuda.is_available() else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu" USE_FP16 = DEVICE == "cuda" - -def choose_device(): - if torch.cuda.is_available(): - return "cuda" - if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): - return "mps" - return "cpu" - -# Unified model for all modalities -CLIP_MODEL_NAME = "openai/clip-vit-large-patch14" -FRAME_SAMPLE_RATE = 1.0 # fps for video +CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K" +FRAME_SAMPLE_RATE = 1.0 FRAME_LIMIT = 64 -AUDIO_SR = 16000 # resample audio +AUDIO_SR = 16000 +IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image") +AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio") +VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video") -# ------------------- Load model ------------------- -from transformers import CLIPProcessor, CLIPModel - -# ------------------- Utils ------------------- - -def deduplicate_faces(face_embeddings, eps=0.4, min_samples=2): - emb_norm = normalize(face_embeddings) - clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(emb_norm) - unique_faces = [] - for label in set(clustering.labels_): - if label == -1: # 噪声 - continue - cluster_embs = emb_norm[clustering.labels_ == label] - unique_faces.append(np.mean(cluster_embs, axis=0)) - return np.array(unique_faces) +for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]: + d.mkdir(parents=True, exist_ok=True) +# 
------------------- 工具函数 ------------------- def l2_normalize(v): norm = np.linalg.norm(v) return v / norm if norm > 1e-10 else v def chunked(lst, n): for i in range(0, len(lst), n): - yield lst[i:i+n] + yield lst[i:i + n] +def deduplicate_faces(face_embeddings, eps=0.4, min_samples=2): + if len(face_embeddings) == 0: + return np.array([]) + emb_norm = normalize(face_embeddings) + clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(emb_norm) + unique = [] + for label in set(clustering.labels_): + if label == -1: + continue + cluster = emb_norm[clustering.labels_ == label] + unique.append(np.mean(cluster, axis=0)) + return np.array(unique) + +# ------------------- 主模型类 ------------------- class MM_Embedding: - def __init__(self, model_name): - self.model = CLIPModel.from_pretrained(model_name).to(DEVICE) - self.processor = CLIPProcessor.from_pretrained(model_name) - if USE_FP16: - self.model.half() + def __init__(self, model_name): + debug(f"Loading CLIP model: {model_name}") + self.model_name = Path(model_name).name + self.model = CLIPModel.from_pretrained( + model_name, + torch_dtype=torch.float16 if USE_FP16 else torch.float32, + device_map="auto" if USE_FP16 else None + ).to(DEVICE).eval() + self.processor = CLIPProcessor.from_pretrained(model_name) - def detect_faces(self, img): - faces = self.extract_faces(img) - face_vecs = self.embed_faces(img) - return face_vecs, faces - - # ------------------- Image ------------------- - def embed_images(self, paths, batch_size=16): - results = {} - for batch in chunked(paths, batch_size): - imgs = [Image.open(p).convert("RGB") for p in batch] - inputs = self.processor(images=imgs, return_tensors="pt", padding=True).to(DEVICE) - with torch.no_grad(): - if USE_FP16: - with torch.cuda.amp.autocast(): - feats = self.model.get_image_features(**inputs) - else: - feats = self.model.get_image_features(**inputs) - feats = feats.cpu().numpy() - faces_list = [] - for img in imgs: - 
faces_list.append(self.detect_faces(img)) + def embed_batch(self, inputs): + if not isinstance(inputs, list): + raise ValueError("inputs must be a list") + if len(inputs) == 0: + return {} - for p, v, fs in zip(batch, feats, faces_list): - results[p] = { - 'type':'image', - 'path': p, - 'faces': fs[1], - 'face_vecs': fs[0], - 'face_count':len(fs[0]), - 'vector': l2_normalize(v) - } - return results + groups = {"image": [], "video": [], "audio": [], "text": []} + results = {} + for item in inputs: + # ------------------- 新格式:字典 ------------------- + if isinstance(item, dict): + typ = item.get("type") + data_uri = item.get("data") - # ------------------- Text ------------------- - def embed_texts(self, texts, batch_size=64): - results = {} - for batch in chunked(texts, batch_size): - inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True).to(DEVICE) - with torch.no_grad(): - if USE_FP16: - with torch.cuda.amp.autocast(): - feats = self.model.get_text_features(**inputs) - else: - feats = self.model.get_text_features(**inputs) - feats = feats.cpu().numpy() - for t, v in zip(batch, feats): - results[t] = { - "type": "text", - "vector": l2_normalize(v) - } - return results + if typ == "text": + content = item.get("content", "") + if content: + groups["text"].append(content) + continue - # ------------------- Video ------------------- - def embed_videos(self, paths, frame_rate=FRAME_SAMPLE_RATE, frame_limit=FRAME_LIMIT): - results = {} - for p in paths: - container = av.open(p) - frames = [] - fps = float(container.streams.video[0].average_rate) if container.streams.video else 30.0 - step = max(1, int(fps / max(1, frame_rate))) - count = 0 - for i, frame in enumerate(container.decode(video=0)): - if i % step == 0: - frames.append(frame.to_image().convert("RGB")) - count += 1 - if count >= frame_limit: - break - container.close() - if not frames: - results[p] = None - continue - # batch embed - emb_list = [] - faces_list = [] - for batch in 
chunked(frames, 16): - inputs = self.processor(images=batch, return_tensors="pt", padding=True).to(DEVICE) - with torch.no_grad(): - if USE_FP16: - with torch.cuda.amp.autocast(): - feats = self.model.get_image_features(**inputs) - else: - feats = self.model.get_image_features(**inputs) - for img in batch: - faces_list += self.detect_faces(img)[0] - emb_list.append(feats.cpu().numpy()) - face_vecs = deduplicate_faces(faces_list) - emb_array = np.vstack(emb_list) - video_vec = l2_normalize(emb_array.mean(axis=0)) - # face_vecs = - results[p] = { - "type": "video", - "path": p, - "vector": video_vec, - "face_count": len(face_vecs), - "face_vecs": face_vecs - } - return results + if typ in {"image", "video", "audio"} and data_uri: + try: + # 1. 提取 base64 + try: + header, b64 = data_uri.split(",", 1) + debug(f"header: {header},b64: {b64}") + binary = base64.b64decode(b64) + except Exception as e: + error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}") - # ------------------- Audio ------------------- - def embed_audios(self, paths, batch_size=4): - results = {} - for p in paths: - y, sr = librosa.load(p, sr=AUDIO_SR, mono=True) - # convert to mel spectrogram image - S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=224) - S_db = librosa.power_to_db(S, ref=np.max) - img = Image.fromarray(np.uint8((S_db - S_db.min())/(S_db.max()-S_db.min()+1e-9)*255)).convert("RGB").resize((224,224)) - inputs = self.processor(images=img, return_tensors="pt").to(DEVICE) - with torch.no_grad(): - if USE_FP16: - with torch.cuda.amp.autocast(): - feat = self.model.get_image_features(**inputs) - else: - feat = self.model.get_image_features(**inputs) - results[p] = l2_normalize(feat.cpu().numpy()[0]) - return results + # 2. 
确定扩展名 + mime_to_ext = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/webp": ".webp", + "video/mp4": ".mp4", + "video/webm": ".webm", + "audio/mpeg": ".mp3", + "audio/wav": ".wav", + "audio/ogg": ".ogg", + } + mime = header.split(";")[0].split(":")[1] + ext = mime_to_ext.get(mime, ".bin") - def extract_faces(self, img: Image.Image): - """返回裁剪后的人脸区域列表""" + # 3. 生成唯一文件名 + 存储 + uid = uuid.uuid4().hex[:12] + if typ == "image": + save_dir = IMAGE_DIR + fake_path = save_dir / f"{uid}{ext}" + fake_path = str(fake_path) + elif typ == "video": + save_dir = VIDEO_DIR + fake_path = save_dir / f"{uid}{ext}" + fake_path = str(fake_path) + elif typ == "audio": + save_dir = AUDIO_DIR + fake_path = save_dir / f"{uid}{ext}" + fake_path = str(fake_path) + + Path(fake_path).write_bytes(binary) + debug(f"保存多媒体文件: {fake_path} ({len(binary) / 1024 / 1024:.2f}MB)") + + # 4. 放入对应 group(CLIP 直接用路径) + if typ == "image": + groups["image"].append(fake_path) + elif typ == "video": + groups["video"].append(fake_path) + elif typ == "audio": + groups["audio"].append(fake_path) + + # 记录原始来源(可选) + results[fake_path] = {"type": typ, "source": "data_uri", "original_mime": mime} + continue + + except Exception as e: + results[id(item)] = {"type": "error", "error": f"data URI 解码失败: {e}"} + continue + + if groups["image"]: + results.update(self._embed_images(groups["image"])) + if groups["video"]: + results.update(self._embed_videos(groups["video"])) + if groups["audio"]: + results.update(self._embed_audios(groups["audio"])) + if groups["text"]: + results.update(self._embed_texts(groups["text"])) + debug(f"最终返回结果是:{results}") + return results + + def _embed_texts(self, texts): + results = {} + for batch in chunked(texts, 64): + inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True).to(DEVICE) + with torch.no_grad(): + with torch.amp.autocast('cuda', enabled=USE_FP16): + feats = self.model.get_text_features(**inputs) + feats = feats.cpu().numpy() + for t, v in 
zip(batch, feats): + results[t] = {"type": "text", "vector": l2_normalize(v).tolist()} + return results + + def _embed_images(self, paths): + results = {} + for batch in chunked(paths, 16): + imgs = [Image.open(p).convert("RGB") for p in batch] + inputs = self.processor(images=imgs, return_tensors="pt", padding=True).to(DEVICE) + with torch.no_grad(): + with torch.amp.autocast('cuda', enabled=USE_FP16): + feats = self.model.get_image_features(**inputs) + feats = feats.cpu().numpy() + for p, v, img in zip(batch, feats, imgs): + face_vecs, _ = self._detect_faces(img) + results[p] = { + "type": "image", + "path": p, + "vector": l2_normalize(v).tolist(), + "face_count": len(face_vecs), + "face_vecs": [vec.tolist() for vec in face_vecs] + } + return results + + def _embed_videos(self, paths): + results = {} + for p in paths: + try: + container = av.open(p) + frames = [] + fps = float(container.streams.video[0].average_rate) or 30.0 + step = max(1, int(fps / FRAME_SAMPLE_RATE)) + for i, frame in enumerate(container.decode(video=0)): + if i % step == 0: + frames.append(frame.to_image().convert("RGB")) + if len(frames) >= FRAME_LIMIT: + break + container.close() + + if not frames: + results[p] = None + continue + + emb_list = [] + all_faces = [] + for batch in chunked(frames, 16): + inputs = self.processor(images=batch, return_tensors="pt", padding=True).to(DEVICE) + with torch.no_grad(): + with torch.amp.autocast('cuda', enabled=USE_FP16): + feats = self.model.get_image_features(**inputs) + for img in batch: + fv, _ = self._detect_faces(img) + all_faces.extend(fv) + emb_list.append(feats.cpu().numpy()) + + face_vecs = deduplicate_faces(all_faces) + video_vec = l2_normalize(np.vstack(emb_list).mean(axis=0)) + + results[p] = { + "type": "video", + "path": p, + "vector": video_vec.tolist(), + "face_count": len(face_vecs), + "face_vecs": [vec.tolist() for vec in face_vecs] + } + except Exception as e: + exception(f"Video {p} failed: {e}") + results[p] = None + return results 
+ + def _embed_audios(self, paths): + results = {} + for p in paths: + try: + y, sr = librosa.load(p, sr=AUDIO_SR, mono=True) + S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=224) + S_db = librosa.power_to_db(S, ref=np.max) + norm_val = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-9) + img = Image.fromarray(np.uint8(norm_val * 255)).convert("RGB").resize((224, 224)) + inputs = self.processor(images=img, return_tensors="pt").to(DEVICE) + with torch.no_grad(): + with torch.amp.autocast('cuda', enabled=USE_FP16): + feats = self.model.get_image_features(**inputs) + results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()} + except Exception as e: + exception(f"Audio {p} failed: {e}") + results[p] = None + return results + + def _detect_faces(self, img): arr = np.array(img) - face_locs = face_recognition.face_locations(arr) - faces = [] - for (top, right, bottom, left) in face_locs: - face = arr[top:bottom, left:right] - faces.append(Image.fromarray(face)) - return faces + locs = face_recognition.face_locations(arr) + debug(f'图片的人脸位置信息:{locs}') + encodings = face_recognition.face_encodings(arr, known_face_locations=locs) + vecs = [l2_normalize(np.array(e)) for e in encodings] if encodings else [] + debug(f'图片的人脸向量是:{vecs}') + return vecs, [] - def embed_faces(self, img: Image.Image): - """提取人脸向量(face_recognition + CLIP)""" - arr = np.array(img) - encodings = face_recognition.face_encodings(arr) - if not encodings: - return [] - return [l2_normalize(np.array(e)) for e in encodings] +# ------------------- API 路由(完全模仿 m2m) ------------------- +async def embed(request, params_kw, *params, **kw): + debug(f'{params_kw=}') - # ------------------- Dispatcher ------------------- - def embed_batch(self, inputs): - groups = {"image":[], "video":[], "audio":[], "text":[]} - for item in inputs: - p = Path(item) - ext = item.lower() - if p.exists(): - if any(ext.endswith(e) for e in [".jpg",".jpeg",".png",".bmp",".webp",".heic"]): - 
groups["image"].append(item) - elif any(ext.endswith(e) for e in [".mp4",".mov",".avi",".mkv"]): - groups["video"].append(item) - elif any(ext.endswith(e) for e in [".mp3",".wav",".flac"]): - groups["audio"].append(item) - else: - groups["text"].append(item) - else: - groups["text"].append(item) - outputs = {} - if groups["image"]: - outputs.update(embed_images(groups["image"])) - if groups["video"]: - outputs.update(embed_videos(groups["video"])) - if groups["audio"]: - outputs.update(embed_audios(groups["audio"])) - if groups["text"]: - outputs.update(embed_texts(groups["text"])) - return outputs + se = ServerEnv() + engine = se.engine + # 从 params_kw 获取参数 + inputs = getattr(params_kw, 'inputs', None) + if not inputs or not isinstance(inputs, list): + raise Exception("`inputs` must be a non-empty list") + + # 调用嵌入 + raw_result = engine.embed_batch(inputs) + + # 构建标准响应 + ret = { + "data": raw_result, + "object": "embedding.result", + "model": engine.model_name + } + return ret + +async def docs(request, *args, **kw): + return helptext + +# ------------------- 服务初始化 ------------------- def init(): - env = ServerEnv() - config = getConfig() - env.mm_model = MM_Embedding(config.model_name) - env.embeded_batch = awaitify(env.mm_model.embeded_batch) -# ------------------- CLI ------------------- -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("inputs", nargs="+", help="file paths or text strings") - parser.add_argument("--out", default="embeddings.npy") + rf = RegisterFunction() + rf.register('embed', embed) + rf.register('docs', docs) + debug("Registered: POST /v1/embed") + +# ------------------- 服务启动 ------------------- +def main(): + parser = argparse.ArgumentParser(prog="CLIP Embedding Service") + parser.add_argument('model_path', nargs='?', help="CLIP model path") + parser.add_argument('-p', '--port', type=int, default=8883) + parser.add_argument('-w', '--workdir', default=os.getcwd()) args = 
parser.parse_args() - embeddings = embed_batch(args.inputs) - # save dict of name->vector - out_dict = {k:v.tolist() for k,v in embeddings.items()} - np.save(args.out, out_dict) - print(f"Saved embeddings to {args.out}") + config = getConfig() + model_name = args.model_path or config.get("model_name") or CLIP_MODEL_NAME + se = ServerEnv() + se.engine = MM_Embedding(model_name) + + debug(f"Starting embedding service on port {args.port}") + webserver(init, args.workdir, args.port) + +if __name__ == '__main__': + main() \ No newline at end of file