# --- app/align.py (post-patch, reconstructed from collapsed diff) ---
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import (
    ctc_segmentation,
    CtcSegmentationParameters,
    prepare_text,
)


class AlignEngine:
    """Forced alignment of text against audio with a Wav2Vec2 CTC model.

    Loads the model once and produces per-character timings via the
    ctc_segmentation package.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float32"
    ):
        """
        model_path: local model directory (e.g. /models/mms-aligner)
        device: "cuda" or "cpu"; silently falls back to CPU when CUDA
                is unavailable
        dtype: "float16" or "float32"; fp16 is honoured only on CUDA
        """
        # 1. Device fallback: never crash just because CUDA is absent.
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)

        # 2. Unified dtype handling: CPU is forced to fp32 because many
        #    half-precision kernels are unsupported or broken on CPU.
        if self.device.type == "cpu":
            self.dtype = torch.float32
        else:
            self.dtype = torch.float16 if dtype == "float16" else torch.float32

        # 3. Processor needs no dtype handling.
        self.processor = AutoProcessor.from_pretrained(model_path)

        # 4. Model: move to device first, THEN cast — casting before the
        #    move would create half tensors on CPU.
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model = self.model.to(self.device)
        if self.device.type == "cuda":
            self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

        # 5. Vocab cache: ctc_segmentation needs the label list ordered
        #    by vocabulary index; build it once instead of per call.
        vocab = self.processor.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]

        # 6. Audio config.
        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    def load_audio(self, audio_path):
        """Load an audio file and resample it to ``self.sample_rate``.

        Returns a 1-D waveform tensor for mono input.  NOTE(review):
        stereo input keeps its channel dim after squeeze() — assumed
        mono upstream; confirm.
        """
        speech, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )
        return speech.squeeze()

    # -----------------------------
    # Logits computation
    # -----------------------------
    def get_logits(self, speech):
        """Run the CTC model and return frame logits as a 2-D numpy array."""
        # as_tensor avoids the needless copy (and UserWarning) that
        # torch.tensor() emits when `speech` is already a Tensor —
        # which it always is, coming from load_audio().
        input_values = torch.as_tensor(speech)

        # Normalise shape to [batch, time].
        input_values = input_values.squeeze()
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)
        if input_values.dim() != 2:
            raise ValueError(f"Invalid input shape: {input_values.shape}")

        # NOTE(review): the feature-extractor normalisation step of
        # self.processor is bypassed here (the old code called it);
        # confirm the model tolerates raw waveforms.
        input_values = input_values.to(self.device).to(self.dtype)

        with torch.no_grad():
            logits = self.model(input_values).logits

        logits = logits.detach().cpu().numpy()
        if logits.ndim == 3:
            logits = logits[0]  # drop the batch dimension
        return logits

    # -----------------------------
    # Main alignment entry point (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """Return per-character alignment results:

        [
            {"char": "你", "start": 0.1, "end": 0.2, "prob": 0.93},
            ...
        ]
        """
        # Strip ALL whitespace (space, tab, newline) — whitespace is
        # never aligned character-by-character, and stray '\n' would
        # poison prepare_text().  (Old code stripped spaces only.)
        chars = list("".join(text.split()))
        if not chars:
            # Nothing to align: skip the expensive model call entirely.
            return []

        speech = self.load_audio(audio_path)
        logits = self.get_logits(speech)

        config = CtcSegmentationParameters()
        config.char_list = self.labels

        ground_truth_mat, utt_begin_indices = prepare_text(
            config, [chars]
        )

        timings, char_probs, state_list = ctc_segmentation(
            config,
            logits,
            ground_truth_mat
        )

        audio_duration = speech.shape[-1] / self.sample_rate

        # NOTE(review): `timings` from ctc_segmentation is ignored and
        # the characters are spread uniformly across the audio — an
        # approximation presumably adopted because the raw timings were
        # unreliable; confirm before relying on exact boundaries.
        # (The dead `frame_duration` computation was removed.)
        results = []
        n = min(len(chars), len(char_probs))
        for i in range(n):
            results.append({
                "char": chars[i],
                "start": (i / n) * audio_duration,
                "end": ((i + 1) / n) * audio_duration,
                "prob": float(char_probs[i]),
            })
        return results


# --- app/aligner.py (post-patch) : visible module imports ---
# NOTE(review): the diff hunk starts at line 5 of aligner.py, so the
# imports of ServerEnv / webapp live in lines not shown here.
from ahserver.filestorage import FileStorage
from appPublic.worker import awaitify
from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
async def lyric_align(request, params_kw, *params, **kw):
    """HTTP handler: align lyric text against an audio file.

    params_kw.audio_path: web path of the audio file
    params_kw.text:       lyric text, one sentence per line

    Returns a list of
        {"sentence": ..., "start": ..., "end": ..., "chars": [...]}
    or an {"status": "error", ...} dict when input is missing or the
    alignment produced no results.
    """
    audio_webpath = params_kw.audio_path
    text = params_kw.text
    debug(f'{params_kw=}')
    if audio_webpath is None:
        exception(f'{params_kw=}')
        return {
            "status": "error",
            "data": {
                "message": "audio_path is None"
            }
        }
    if text is None:
        exception(f'{params_kw=}')
        return {
            "status": "error",
            "data": {
                "message": "text is None"
            }
        }
    env = ServerEnv()
    fs = FileStorage()
    audio_path = fs.realPath(audio_webpath)
    align = awaitify(env.align_engine.align)
    # Keep newlines for sentence splitting below; the aligner itself
    # receives the fully flattened string.
    text = text.replace(" ", "").replace('\t', "")
    text1 = text.replace("\n", "")
    s = await align(audio_path, text1)
    debug(f'{s=}, {text1=},{text=}')
    if not s:
        # An empty alignment would make every index below go out of
        # range; fail loudly with an error response instead.
        exception(f'empty alignment result: {text1=}')
        return {
            "status": "error",
            "data": {
                "message": "alignment produced no results"
            }
        }
    lines = text.split('\n')
    c_pos = 0
    c_max = len(s) - 1
    sentences = []
    for l in lines:
        if l:
            segment = {
                'sentence': l,
                # Clamp: the alignment may be shorter than the text
                # (the old code raised IndexError here in that case).
                'start': s[min(c_pos, c_max)]['start'],
                'chars': []
            }
            for c in l:
                if c_pos <= c_max:
                    segment['chars'].append(s[c_pos])
                else:
                    debug(f'{l=}, {c=}, {c_max=}')
                c_pos += 1
            # c_pos - 1 can run past the last alignment entry when the
            # result is shorter than the text; clamp instead of raising.
            segment['end'] = s[min(c_pos - 1, c_max)]['end']
            sentences.append(segment)
    return sentences


def init():
    """Register the handler and build the shared AlignEngine instance."""
    rf = RegisterFunction()
    rf.register('align', lyric_align)
    env = ServerEnv()
    config = getConfig()
    env.align_engine = AlignEngine(config.align_model)


if __name__ == '__main__':
    webapp(init)