This commit is contained in:
yumoqing 2026-04-17 19:00:07 +08:00
parent 29ace2ac85
commit 6f6e32e80f
2 changed files with 193 additions and 141 deletions

View File

@ -4,130 +4,171 @@ import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import (
ctc_segmentation,
CtcSegmentationParameters,
prepare_text
ctc_segmentation,
CtcSegmentationParameters,
prepare_text
)
class AlignEngine:
    """Character-level forced aligner built on a Wav2Vec2 CTC model.

    Wraps a local HuggingFace ``Wav2Vec2ForCTC`` checkpoint plus the
    ``ctc_segmentation`` package to map each character of a transcript to a
    time span inside an audio file.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float32"
    ):
        """
        model_path: local model directory, e.g. /models/mms-aligner
        device: "cuda" or "cpu"
        dtype: "float16" or "float32"
        """
        # 1. Device fallback: silently degrade to CPU when CUDA was
        #    requested but is not available.
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)

        # 2. Resolve dtype once.  CPU is forced to fp32 — half-precision
        #    kernels are generally unavailable on CPU.
        if self.device.type == "cpu":
            self.dtype = torch.float32
        elif dtype == "float16":
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        # 3. Load the processor (dtype does not apply to it).
        self.processor = AutoProcessor.from_pretrained(model_path)

        # 4. Load the model: move to the device FIRST, then cast.  The cast
        #    is applied only on CUDA (CPU stays fp32, see step 2).
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model = self.model.to(self.device)
        if self.device.type == "cuda":
            self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

        # 5. Cache the vocabulary as an index-ordered label list so it is
        #    not recomputed on every align() call.
        vocab = self.processor.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]

        # 6. Audio config: the model expects 16 kHz input.
        self.sample_rate = 16000

    # -----------------------------
    # Audio loading
    # -----------------------------
    def load_audio(self, audio_path):
        """Load an audio file, resample to ``self.sample_rate`` and return
        the squeezed waveform tensor."""
        speech, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            speech = torchaudio.functional.resample(
                speech, sr, self.sample_rate
            )
        return speech.squeeze()

    # -----------------------------
    # Logits computation
    # -----------------------------
    def get_logits(self, speech):
        """Run the CTC model and return per-frame logits as a 2-D numpy
        array of shape (frames, vocab)."""
        # as_tensor avoids a needless copy (and the torch.tensor-on-tensor
        # warning) when `speech` is already a tensor.
        input_values = torch.as_tensor(speech).squeeze()
        # Promote a 1-D waveform to a [1, T] batch.
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)
        # Anything other than [B, T] at this point is a caller error.
        if input_values.dim() != 2:
            raise ValueError(f"Invalid input shape: {input_values.shape}")
        input_values = input_values.to(self.device).to(self.dtype)
        with torch.no_grad():
            logits = self.model(input_values).logits
        logits = logits.detach().cpu().numpy()
        if logits.ndim == 3:
            logits = logits[0]  # drop the batch axis
        return logits

    # -----------------------------
    # Main alignment (per character)
    # -----------------------------
    def align(self, audio_path: str, text: str):
        """Return per-character alignment results:

        [
            {"char": "x", "start": 0.1, "end": 0.2, "prob": 0.9},
            ...
        ]
        """
        speech = self.load_audio(audio_path)
        logits = self.get_logits(speech)

        # Chinese / multilingual input: force character-level alignment.
        text = text.replace(" ", "")
        chars = list(text)
        if not chars:
            # Nothing to align; also avoids a zero division below.
            return []

        config = CtcSegmentationParameters()
        config.char_list = self.labels

        ground_truth_mat, _utt_begin_indices = prepare_text(
            config, [chars]
        )
        _timings, char_probs, _state_list = ctc_segmentation(
            config,
            logits,
            ground_truth_mat
        )

        audio_duration = speech.shape[-1] / self.sample_rate

        # NOTE(review): timestamps are distributed uniformly over the audio
        # duration rather than taken from the CTC `timings` — presumably a
        # deliberate fallback; confirm before relying on per-char precision.
        results = []
        T = min(len(chars), len(char_probs))
        for i in range(T):
            results.append({
                "char": chars[i],
                "start": (i / T) * audio_duration,
                "end": ((i + 1) / T) * audio_duration,
                "prob": float(char_probs[i])
            })
        return results

View File

@ -5,53 +5,64 @@ from ahserver.filestorage import FileStorage
from appPublic.worker import awaitify
from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
async def lyric_align(request, params_kw, *params, **kw):
    """HTTP handler: align lyrics text against an audio file, per line.

    params_kw.audio_path: web path of the audio file
    params_kw.text: lyrics text, one sentence per line

    Returns a list of
    {"sentence": ..., "start": ..., "end": ..., "chars": [...]}
    or an error dict when a required parameter is missing.
    """
    audio_webpath = params_kw.audio_path
    text = params_kw.text
    debug(f'{params_kw=}')
    if audio_webpath is None:
        exception(f'{params_kw=}')
        return {
            "status": "error",
            "data": {
                "message": "audio_path is None"
            }
        }
    if text is None:
        exception(f'{params_kw=}')
        return {
            "status": "error",
            "data": {
                "message": "text is None"
            }
        }
    env = ServerEnv()
    fs = FileStorage()
    audio_path = fs.realPath(audio_webpath)
    align = awaitify(env.align_engine.align)
    # The aligner consumes a single whitespace-free string; keep the
    # newline-bearing copy to rebuild per-line segments afterwards.
    text = text.replace(" ", "").replace('\t', "")
    text1 = text.replace("\n", "")
    s = await align(audio_path, text1)
    debug(f'{s=}, {text1=},{text=}')
    lines = text.split('\n')
    c_pos = 0
    c_max = len(s) - 1
    sentences = []
    for l in lines:
        if not l:
            continue
        # Clamp indices: the aligner may return fewer entries than there
        # are characters (s can even be empty), so never index past c_max.
        segment = {
            'sentence': l,
            'start': s[min(c_pos, c_max)]['start'] if c_max >= 0 else 0.0,
            'chars': []
        }
        for c in l:
            if c_pos <= c_max:
                segment['chars'].append(s[c_pos])
            else:
                debug(f'{l=}, {c=}, {c_max=}')
            c_pos += 1
        segment['end'] = s[min(c_pos - 1, c_max)]['end'] if c_max >= 0 else 0.0
        sentences.append(segment)
    return sentences
def init():
    """Server init hook: register the align endpoint and build the engine."""
    rf = RegisterFunction()
    # Route name 'align' is served by the lyric_align handler.
    rf.register('align', lyric_align)
    env = ServerEnv()
    config = getConfig()
    # config.align_model: local filesystem path of the alignment model.
    env.align_engine = AlignEngine(config.align_model)


if __name__ == '__main__':
    webapp(init)