"""Character-level forced alignment with a Wav2Vec2 CTC model and ctc_segmentation."""
import numpy as np
import torch
import torchaudio
from ctc_segmentation import (
    CtcSegmentationParameters,
    ctc_segmentation,
    determine_utterance_segments,
    prepare_text,
)
from transformers import AutoProcessor, Wav2Vec2ForCTC
class AlignEngine:
    """Character-level forced aligner built on a Wav2Vec2 CTC model.

    Runs CTC inference with a local Wav2Vec2 checkpoint and aligns a
    transcript to the audio with the ``ctc_segmentation`` package,
    returning per-character start/end times in seconds.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float32",
    ):
        """
        Args:
            model_path: local path of the Wav2Vec2 model/processor.
            device: "cuda" or "cpu"; silently falls back to CPU when
                CUDA is unavailable.
            dtype: "float16" or "float32"; half precision is only
                honoured on CUDA (CPU half kernels are unreliable).
        """
        # 1. Device fallback: CPU when CUDA was requested but absent.
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)

        # 2. Unified dtype policy: CPU is forced to fp32; on CUDA honour
        #    the requested dtype (anything but "float16" means fp32).
        if self.device.type == "cpu" or dtype != "float16":
            self.dtype = torch.float32
        else:
            self.dtype = torch.float16

        # 3. Processor (feature extractor + tokenizer); dtype-agnostic.
        self.processor = AutoProcessor.from_pretrained(model_path)

        # 4. Load the model; move to device first, then cast — CPU must
        #    never be cast to half.
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model = self.model.to(self.device)
        if self.device.type == "cuda":
            self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

        # 5. Cache the CTC label list in token-id order; this is the
        #    char_list ctc_segmentation needs.
        vocab = self.processor.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]

        # 6. Wav2Vec2 checkpoints are trained on 16 kHz audio.
        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load *audio_path*, resample to 16 kHz, return mono 1-D float tensor."""
        speech, sr = torchaudio.load(audio_path)

        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )

        # FIX: squeeze() is a no-op on stereo [2, T] and the 2-D tensor
        # would later be treated as batch=2 — downmix channels instead.
        if speech.dim() > 1:
            speech = speech.mean(dim=0)
        return speech

    # -----------------------------
    # Logits computation
    # -----------------------------
    def get_logits(self, speech) -> np.ndarray:
        """Run CTC inference and return a [frames, vocab] logit matrix (numpy).

        Accepts a 1-D waveform (tensor, ndarray or list); a leading batch
        dimension of size 1 is tolerated.
        """
        # as_tensor avoids the copy (and the UserWarning) that
        # torch.tensor() emits when speech is already a tensor.
        input_values = torch.as_tensor(speech)

        # Drop redundant dimensions, then ensure shape [B, T].
        input_values = input_values.squeeze()
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)
        if input_values.dim() != 2:
            raise ValueError(f"Invalid input shape: {input_values.shape}")

        input_values = input_values.to(self.device).to(self.dtype)

        with torch.no_grad():
            logits = self.model(input_values).logits

        logits = logits.detach().cpu().numpy()
        if logits.ndim == 3:
            logits = logits[0]  # strip batch dim
        return logits

    # -----------------------------
    # Main alignment (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """Align *text* to the audio character by character.

        Returns:
            [{"char": "你", "start": 0.1, "end": 0.2, "prob": ...}, ...]
            with times in seconds. Empty after stripping spaces -> [].
        """
        speech = self.load_audio(audio_path)
        logits = self.get_logits(speech)

        # Chinese / multilingual: drop spaces and align per character.
        text = text.replace(" ", "")
        chars = list(text)
        if not chars:
            return []

        # FIX: ctc_segmentation expects log posteriors, not raw logits.
        log_probs = torch.log_softmax(
            torch.from_numpy(logits).float(), dim=-1
        ).numpy()

        config = CtcSegmentationParameters()
        config.char_list = self.labels
        # FIX: seconds of audio per CTC output frame — without this the
        # returned timings are in the wrong unit entirely.
        config.index_duration = (
            speech.shape[-1] / log_probs.shape[0] / self.sample_rate
        )

        # One "utterance" per character so segment boundaries are
        # per-character.
        ground_truth_mat, utt_begin_indices = prepare_text(config, chars)

        timings, char_probs, state_list = ctc_segmentation(
            config, log_probs, ground_truth_mat
        )

        # FIX: the previous version discarded the computed timings and
        # spread characters uniformly over the audio duration, which is
        # not an alignment. Use the CTC segment boundaries instead.
        segments = determine_utterance_segments(
            config, utt_begin_indices, char_probs, timings, chars
        )

        return [
            {
                "char": c,
                "start": float(start),
                "end": float(end),
                "prob": float(conf),
            }
            for c, (start, end, conf) in zip(chars, segments)
        ]