aligner/app/align.py
2026-04-17 15:16:12 +08:00

134 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import (
ctc_segmentation,
CtcSegmentationParameters,
prepare_text
)
class AlignEngine:
    """Character-level forced aligner built on a Wav2Vec2 CTC model.

    The engine loads a local Wav2Vec2 CTC checkpoint, computes frame-level
    CTC outputs for an audio file and uses ``ctc_segmentation`` to assign a
    start/end time to every character of a transcript.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16"
    ):
        """Load the processor and the CTC model from a local directory.

        Args:
            model_path: Local model path (for example ``/models/mms-aligner``).
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"`` or
                ``"cpu"``.
            dtype: ``"float16"`` or ``"float32"``. Half precision is only
                applied on CUDA devices; any other combination runs in
                float32.
        """
        self.device = device

        # Honour the requested dtype, but never run half precision off-GPU.
        use_half = dtype == "float16" and str(device).startswith("cuda")
        self.dtype = torch.float16 if use_half else torch.float32

        # Load processor + model from the local path.
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()

        # Cache the vocabulary so it is not rebuilt on every call.
        tokenizer = self.processor.tokenizer
        vocab = tokenizer.get_vocab()
        self.inv_vocab = {index: token for token, index in vocab.items()}
        # Size by the largest id so gaps in the id space cannot raise KeyError;
        # missing ids get an empty placeholder that never matches a character.
        size = max(self.inv_vocab) + 1 if self.inv_vocab else 0
        self.labels = [self.inv_vocab.get(i, "") for i in range(size)]
        self._label_set = set(self.labels)

        # Wav2Vec2 CTC models use the padding token as the CTC blank.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        self.blank_id = pad_id if pad_id is not None else 0

        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    @staticmethod
    def _to_mono(speech):
        """Average the leading channel axis away; 1-D input is returned as is."""
        if speech.dim() > 1:
            speech = speech.mean(dim=0)
        return speech

    def load_audio(self, audio_path):
        """Load an audio file as a mono 1-D tensor at ``self.sample_rate``.

        Args:
            audio_path: Path of the audio file to read.

        Returns:
            A 1-D tensor of samples, downmixed to mono and resampled when the
            file's sample rate differs from ``self.sample_rate``.
        """
        speech, sr = torchaudio.load(audio_path)
        # Downmix first: a plain squeeze() would leave stereo as (2, N) and
        # the processor would then treat the channels as a batch.
        speech = self._to_mono(speech)
        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )
        return speech

    # -----------------------------
    # Logits computation
    # -----------------------------
    def get_logits(self, speech):
        """Run the CTC model on one waveform.

        Args:
            speech: 1-D waveform sampled at ``self.sample_rate``.

        Returns:
            A float32 NumPy array of raw logits with shape
            ``(frames, vocab_size)``.
        """
        inputs = self.processor(
            speech,
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        )
        # Inputs must match the model's dtype, otherwise a half-precision
        # model rejects the float32 tensor produced by the processor.
        input_values = inputs.input_values.to(
            device=self.device, dtype=self.dtype
        )
        with torch.no_grad():
            logits = self.model(input_values).logits
        return logits[0].detach().float().cpu().numpy()

    @staticmethod
    def _log_softmax(logits):
        """Return float32 log-probabilities over the last axis of ``logits``."""
        shifted = np.asarray(logits, dtype=np.float32)
        # Subtract the row maximum for numerical stability.
        shifted = shifted - shifted.max(axis=-1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    @staticmethod
    def _vocab_symbol(ch, vocab, excluded):
        """Map one character to a single-character vocabulary symbol.

        The character itself is tried first, then its lower- and upper-case
        form. Returns ``None`` when no candidate is a usable vocabulary entry.
        """
        for candidate in (ch, ch.lower(), ch.upper()):
            if (
                len(candidate) == 1
                and candidate in vocab
                and candidate not in excluded
            ):
                return candidate
        return None

    @staticmethod
    def _build_results(chars, positions, timings, frame_log_probs,
                       first_index, index_duration):
        """Turn segmentation output into one result dict per character.

        Args:
            chars: All transcript characters, in order.
            positions: Indices into ``chars`` of the characters that were
                aligned, in order.
            timings: Onset time in seconds of every ground-truth token.
            frame_log_probs: Per-frame log-probabilities along the best path.
            first_index: Ground-truth index of the first aligned character.
            index_duration: Duration of one CTC frame in seconds.

        Returns:
            A list of ``{"char", "start", "end", "prob"}`` dicts. Characters
            that were not aligned get a zero-length span at the end of the
            preceding character and ``prob`` 0.0.
        """
        n_frames = len(frame_log_probs)
        spans = {}
        for offset, char_idx in enumerate(positions):
            token = first_index + offset
            start = float(timings[token])
            # A token ends where the next ground-truth token begins.
            end = max(float(timings[token + 1]), start)
            prob = 0.0
            if n_frames and index_duration > 0:
                lo = int(round(start / index_duration))
                lo = min(max(lo, 0), n_frames - 1)
                hi = int(round(end / index_duration))
                hi = min(max(hi, lo + 1), n_frames)
                mean_log_prob = np.mean(frame_log_probs[lo:hi])
                prob = min(float(np.exp(mean_log_prob)), 1.0)
            spans[char_idx] = (start, end, prob)

        results = []
        cursor = 0.0
        for char_idx, ch in enumerate(chars):
            if char_idx in spans:
                start, end, prob = spans[char_idx]
                cursor = end
            else:
                start = end = cursor
                prob = 0.0
            results.append({
                "char": ch,
                "start": start,
                "end": end,
                "prob": prob
            })
        return results

    # -----------------------------
    # Main alignment function (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """Align every character of ``text`` to the audio.

        Args:
            audio_path: Path of the audio file.
            text: Transcript; all whitespace is removed before alignment.

        Returns:
            One dict per remaining character, in transcript order::

                [
                    {"char": "a", "start": 0.1, "end": 0.2, "prob": 0.9},
                    ...
                ]

            ``start``/``end`` are seconds from the beginning of the audio and
            ``prob`` is the mean path probability over the character's
            frames. Characters that cannot be aligned (not in the model
            vocabulary, or excluded by ``ctc_segmentation``) are reported
            with a zero-length span and ``prob`` 0.0.

        Raises:
            ValueError: If the model produced no frames for the audio.
        """
        # Chinese / multilingual text: force per-character alignment and drop
        # every kind of whitespace, not just the ASCII space.
        chars = list("".join(text.split()))
        if not chars:
            return []

        speech = self.load_audio(audio_path)
        # ctc_segmentation works in the log domain, so normalise the logits.
        log_probs = self._log_softmax(self.get_logits(speech))
        n_frames = log_probs.shape[0]
        if n_frames == 0:
            raise ValueError("audio is too short to produce any CTC frames")

        config = CtcSegmentationParameters()
        config.char_list = self.labels
        config.blank = self.blank_id
        # Seconds per CTC frame; with this set, timings come back in seconds.
        audio_duration = speech.shape[-1] / self.sample_rate
        config.index_duration = audio_duration / n_frames

        # Keep only symbols prepare_text is guaranteed to retain, so that
        # ground-truth positions map one-to-one onto `positions`.
        symbols = []
        positions = []
        for char_idx, ch in enumerate(chars):
            symbol = self._vocab_symbol(
                ch, self._label_set, config.excluded_characters
            )
            if symbol is not None:
                symbols.append(symbol)
                positions.append(char_idx)

        if not symbols:
            return self._build_results(
                chars, [], np.zeros(0), np.zeros(0), 0, config.index_duration
            )

        ground_truth_mat, utt_begin_indices = prepare_text(
            config, [symbols]
        )
        timings, char_probs, _ = ctc_segmentation(
            config,
            log_probs,
            ground_truth_mat
        )

        # utt_begin_indices[0] points at the separator before the utterance;
        # the first character is the next ground-truth entry.
        first_index = utt_begin_indices[0] + 1
        return self._build_results(
            chars, positions, timings, char_probs,
            first_index, config.index_duration
        )