import numpy as np
import torch
import torchaudio
from ctc_segmentation import (
    CtcSegmentationParameters,
    ctc_segmentation,
    prepare_text,
)
from transformers import AutoProcessor, Wav2Vec2ForCTC


class AlignEngine:
    """Character-level forced alignment using a Wav2Vec2 CTC model.

    Loads a local CTC acoustic model once, then aligns a transcript to an
    audio file, returning per-character start/end times in seconds.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
    ):
        """
        Args:
            model_path: local model directory (e.g. /models/mms-aligner).
            device: "cuda" or "cpu".
            dtype: "float16" or "float32". fp16 is only applied on CUDA;
                on CPU it silently falls back to fp32.
        """
        self.device = device
        self.dtype = torch.float16 if dtype == "float16" else torch.float32

        # Load processor + model from the local path.
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # BUG FIX: the original called .half() whenever device == "cuda",
        # ignoring the dtype flag entirely (self.dtype was set but never
        # used). Honor dtype="float32" on GPU; force fp32 on CPU.
        if self.device == "cuda" and self.dtype == torch.float16:
            self.model = self.model.half()
        else:
            self.dtype = torch.float32

        # Cache the CTC label list (index -> token) to avoid recomputation.
        vocab = self.processor.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    def load_audio(self, audio_path):
        """Load an audio file as a mono 16 kHz 1-D float tensor.

        ROBUSTNESS FIX: multi-channel input is down-mixed to mono by
        averaging channels; the original .squeeze() left stereo audio as a
        (2, n) tensor, which breaks the downstream feature extractor.
        """
        speech, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )
        if speech.dim() > 1 and speech.size(0) > 1:
            speech = speech.mean(dim=0)
        return speech.squeeze()

    # -----------------------------
    # Logits computation
    # -----------------------------
    def get_logits(self, speech):
        """Run the acoustic model; return raw CTC logits as a
        (frames, vocab) float32 numpy array."""
        inputs = self.processor(
            speech, sampling_rate=self.sample_rate, return_tensors="pt"
        )
        # BUG FIX: when the model runs in half precision the input must be
        # cast to the same dtype, otherwise torch raises a dtype mismatch.
        input_values = inputs.input_values.to(self.device, dtype=self.dtype)
        with torch.no_grad():
            logits = self.model(input_values).logits
        # .float() so fp16 outputs become fp32 before numpy conversion.
        return logits[0].detach().cpu().float().numpy()

    # -----------------------------
    # Main alignment (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """Align *text* to the audio and return per-character timestamps.

        Returns:
            [
                {"char": "你", "start": 0.1, "end": 0.2, "prob": ...},
                ...
            ]
            with start/end in seconds.
        """
        speech = self.load_audio(audio_path)
        logits = self.get_logits(speech)

        # Chinese / multilingual → force per-character alignment.
        text = text.replace(" ", "")
        chars = list(text)
        # ROBUSTNESS FIX: empty transcript → no segments (the original
        # would have produced degenerate indexing downstream).
        if not chars:
            return []

        # BUG FIX: ctc_segmentation expects log posteriors, not raw
        # logits — apply log-softmax over the vocab axis first.
        log_probs = torch.log_softmax(
            torch.from_numpy(logits), dim=-1
        ).numpy()

        config = CtcSegmentationParameters()
        config.char_list = self.labels
        # BUG FIX: tell ctc_segmentation the real seconds-per-frame so the
        # returned timings are already in seconds. The original left
        # index_duration at its default, then multiplied the (already
        # time-scaled) timings by audio_duration / len(timings) — but
        # len(timings) is the number of ground-truth rows, not frames, so
        # the resulting timestamps were wrong on both counts.
        audio_duration = speech.shape[-1] / self.sample_rate
        config.index_duration = audio_duration / log_probs.shape[0]

        ground_truth_mat, utt_begin_indices = prepare_text(config, [chars])

        timings, char_probs, state_list = ctc_segmentation(
            config, log_probs, ground_truth_mat
        )

        results = []
        base = utt_begin_indices[0]
        for i, c in enumerate(chars):
            # timings are in seconds already (see index_duration above).
            results.append({
                "char": c,
                "start": float(timings[base + i]),
                "end": float(timings[base + i + 1]),
                "prob": float(char_probs[base + i]),
            })
        return results