import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import (
    ctc_segmentation,
    CtcSegmentationParameters,
    prepare_text,
    determine_utterance_segments,
)


class AlignEngine:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float32"
    ):
        """
        model_path: local model path
        device: cuda / cpu
        dtype: float16 / float32
        """
        # ---------------------------
        # 1. Automatic device fallback
        # ---------------------------
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)

        # ---------------------------
        # 2. Unified dtype handling
        # ---------------------------
        if self.device.type == "cpu":
            # Force fp32 on CPU (avoids missing half-precision kernels)
            self.dtype = torch.float32
        else:
            if dtype == "float16":
                self.dtype = torch.float16
            else:
                self.dtype = torch.float32

        # ---------------------------
        # 3. Load processor (no dtype needed)
        # ---------------------------
        self.processor = AutoProcessor.from_pretrained(model_path)

        # ---------------------------
        # 4. Load model (key fix)
        # ---------------------------
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)

        # Correct order: move to the device first, then cast the dtype
        self.model = self.model.to(self.device)

        # Half precision is not supported on CPU
        if self.device.type == "cuda":
            self.model = self.model.to(dtype=self.dtype)

        self.model.eval()

        # ---------------------------
        # 5. Vocab cache
        # ---------------------------
        vocab = self.processor.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]

        # ---------------------------
        # 6. Audio config
        # ---------------------------
        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    def load_audio(self, audio_path):
        speech, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )
        return speech.squeeze()

    # -----------------------------
    # Logits computation (returns log posteriors)
    # -----------------------------
    def get_logits(self, speech):
        # 1. Make sure we have a tensor (as_tensor avoids copying an existing one)
        input_values = torch.as_tensor(speech)

        # 2. Drop redundant dimensions (important)
        input_values = input_values.squeeze()

        # 3. If it collapsed to 1D, add a batch dimension back
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)

        # 4. Enforce shape [B, T]
        if input_values.dim() != 2:
            raise ValueError(f"Invalid input shape: {input_values.shape}")

        # 5. Move to device and cast dtype
        input_values = input_values.to(self.device).to(self.dtype)

        # 6. Inference; ctc_segmentation expects log posteriors, so apply log-softmax
        with torch.no_grad():
            logits = self.model(input_values).logits
            log_probs = torch.log_softmax(logits, dim=-1)

        log_probs = log_probs.detach().cpu().float().numpy()
        if log_probs.ndim == 3:
            log_probs = log_probs[0]
        return log_probs

    # -----------------------------
    # Main alignment function (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """
        Returns a per-character alignment:
        [
            {"char": "你", "start": 0.1, "end": 0.2, "prob": 0.93},
            ...
        ]
        """
        speech = self.load_audio(audio_path)
        log_probs = self.get_logits(speech)

        # Chinese / multilingual text -> force character-level alignment
        text = text.replace(" ", "")
        chars = list(text)

        config = CtcSegmentationParameters()
        config.char_list = self.labels
        # Seconds per CTC output frame (samples per frame / sample rate)
        config.index_duration = speech.shape[-1] / log_probs.shape[0] / self.sample_rate

        # Treat every character as its own segment to get per-character boundaries
        ground_truth_mat, utt_begin_indices = prepare_text(config, chars)

        timings, char_probs, state_list = ctc_segmentation(
            config, log_probs, ground_truth_mat
        )

        # Convert the frame-level timings into (start, end, confidence) per character
        segments = determine_utterance_segments(
            config, utt_begin_indices, char_probs, timings, chars
        )

        results = []
        for char, (start, end, prob) in zip(chars, segments):
            results.append({
                "char": char,
                "start": float(start),
                "end": float(end),
                "prob": float(prob),
            })
        return results
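

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the engine itself). The checkpoint path,
# audio file, and transcript below are placeholders/assumptions; substitute a
# CTC-fine-tuned Chinese wav2vec2 model and your own data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    engine = AlignEngine(
        model_path="/path/to/wav2vec2-ctc-chinese",  # hypothetical local checkpoint
        device="cuda",
        dtype="float16",
    )

    # Hypothetical inputs: a 16 kHz-compatible wav file and its transcript
    alignment = engine.align("/path/to/sample.wav", "你好世界")

    for item in alignment:
        print(f'{item["start"]:.2f}s - {item["end"]:.2f}s  {item["char"]}  p={item["prob"]:.3f}')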