aligner/app/align.py
2026-04-17 19:00:07 +08:00

175 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import (
ctc_segmentation,
CtcSegmentationParameters,
prepare_text
)
class AlignEngine:
def __init__(
self,
model_path: str,
device: str = "cuda",
dtype: str = "float32"
):
"""
model_path: 本地模型路径
device: cuda / cpu
dtype: float16 / float32
"""
# ---------------------------
# 1. device 自动兜底
# ---------------------------
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
self.device = torch.device(device)
# ---------------------------
# 2. dtype 统一管理
# ---------------------------
if self.device.type == "cpu":
# CPU 强制 fp32避免 half kernel 问题)
self.dtype = torch.float32
else:
if dtype == "float16":
self.dtype = torch.float16
else:
self.dtype = torch.float32
# ---------------------------
# 3. 加载 processor不需要 dtype
# ---------------------------
self.processor = AutoProcessor.from_pretrained(model_path)
# ---------------------------
# 4. 加载 model关键修复点
# ---------------------------
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
# ⭐ 正确方式:先上 device再统一 dtype
self.model = self.model.to(self.device)
# ⚠️ CPU 不允许 half
if self.device.type == "cuda":
self.model = self.model.to(dtype=self.dtype)
self.model.eval()
# ---------------------------
# 5. vocab cache
# ---------------------------
vocab = self.processor.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
# ---------------------------
# 6. audio config
# ---------------------------
self.sample_rate = 16000
# -----------------------------
# 音频加载
# -----------------------------
def load_audio(self, audio_path):
speech, sr = torchaudio.load(audio_path)
if sr != self.sample_rate:
speech = torchaudio.functional.resample(
speech, sr, self.sample_rate
)
return speech.squeeze()
# -----------------------------
# logits计算
# -----------------------------
def get_logits(self, speech):
# 1. 确保是 tensor
input_values = torch.tensor(speech)
# 2. 去掉多余维度(🔥关键)
input_values = input_values.squeeze()
# 3. 如果变成 1D → 加 batch
if input_values.dim() == 1:
input_values = input_values.unsqueeze(0)
# 4. 强制 shape = [B, T]
if input_values.dim() != 2:
raise ValueError(f"Invalid input shape: {input_values.shape}")
# 5. move device + dtype
input_values = input_values.to(self.device).to(self.dtype)
# 6. inference
with torch.no_grad():
logits = self.model(input_values).logits
logits = logits.detach().cpu().numpy()
if logits.ndim == 3:
logits = logits[0]
return logits
# -----------------------------
# 主对齐函数(逐字)
# -----------------------------
def align(self, audio_path: str, text: str):
"""
返回逐字对齐结果:
[
{"char": "", "start": 0.1, "end": 0.2},
...
]
"""
speech = self.load_audio(audio_path)
logits = self.get_logits(speech)
# 中文/多语言 → 强制逐字
text = text.replace(" ", "")
chars = list(text)
config = CtcSegmentationParameters()
config.char_list = self.labels
ground_truth_mat, utt_begin_indices = prepare_text(
config, [chars]
)
timings, char_probs, state_list = ctc_segmentation(
config,
logits,
ground_truth_mat
)
# 每帧时间
audio_duration = speech.shape[-1] / self.sample_rate
frame_duration = audio_duration / len(timings)
results = []
T = min(len(chars), len(char_probs))
for i in range(T):
c = chars[i]
start = (i / T) * audio_duration
end = ((i + 1) / T) * audio_duration
results.append({
"char": c,
"start": start,
"end": end,
"prob": float(char_probs[i])
})
return results