This commit is contained in:
yumoqing 2026-04-17 19:00:07 +08:00
parent 29ace2ac85
commit 6f6e32e80f
2 changed files with 193 additions and 141 deletions

View File

@ -4,130 +4,171 @@ import numpy as np
from transformers import AutoProcessor, Wav2Vec2ForCTC from transformers import AutoProcessor, Wav2Vec2ForCTC
from ctc_segmentation import ( from ctc_segmentation import (
ctc_segmentation, ctc_segmentation,
CtcSegmentationParameters, CtcSegmentationParameters,
prepare_text prepare_text
) )
class AlignEngine: class AlignEngine:
def __init__( def __init__(
self, self,
model_path: str, model_path: str,
device: str = "cuda", device: str = "cuda",
dtype: str = "float16" dtype: str = "float32"
): ):
""" """
model_path: 本地模型路径例如 /models/mms-aligner model_path: 本地模型路径
device: cuda / cpu device: cuda / cpu
dtype: float16 / float32 dtype: float16 / float32
""" """
self.device = device # ---------------------------
# 1. device 自动兜底
# ---------------------------
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
# dtype处理 self.device = torch.device(device)
if dtype == "float16":
self.dtype = torch.float16
else:
self.dtype = torch.float32
# 加载processor + model本地 # ---------------------------
self.processor = AutoProcessor.from_pretrained(model_path) # 2. dtype 统一管理
self.model = Wav2Vec2ForCTC.from_pretrained(model_path) # ---------------------------
if self.device.type == "cpu":
# CPU 强制 fp32避免 half kernel 问题)
self.dtype = torch.float32
else:
if dtype == "float16":
self.dtype = torch.float16
else:
self.dtype = torch.float32
self.model.to(self.device) # ---------------------------
self.model.eval() # 3. 加载 processor不需要 dtype
# ---------------------------
self.processor = AutoProcessor.from_pretrained(model_path)
if self.device == "cuda": # ---------------------------
self.model = self.model.half() # 4. 加载 model关键修复点
# ---------------------------
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
# vocab缓存避免重复计算 # ⭐ 正确方式:先上 device再统一 dtype
vocab = self.processor.tokenizer.get_vocab() self.model = self.model.to(self.device)
self.inv_vocab = {v: k for k, v in vocab.items()}
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
self.sample_rate = 16000 # ⚠️ CPU 不允许 half
if self.device.type == "cuda":
self.model = self.model.to(dtype=self.dtype)
# ----------------------------- self.model.eval()
# 音频加载
# -----------------------------
def load_audio(self, audio_path):
speech, sr = torchaudio.load(audio_path)
if sr != self.sample_rate: # ---------------------------
speech = torchaudio.functional.resample( # 5. vocab cache
speech, sr, self.sample_rate # ---------------------------
) vocab = self.processor.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
return speech.squeeze() # ---------------------------
# 6. audio config
# ---------------------------
self.sample_rate = 16000
# ----------------------------- # -----------------------------
# logits计算 # 音频加载
# ----------------------------- # -----------------------------
def get_logits(self, speech): def load_audio(self, audio_path):
inputs = self.processor( speech, sr = torchaudio.load(audio_path)
speech,
sampling_rate=self.sample_rate,
return_tensors="pt"
)
input_values = inputs.input_values.to(self.device) if sr != self.sample_rate:
speech = torchaudio.functional.resample(
speech, sr, self.sample_rate
)
with torch.no_grad(): return speech.squeeze()
logits = self.model(input_values).logits
return logits[0].detach().cpu().numpy() # -----------------------------
# logits计算
# -----------------------------
def get_logits(self, speech):
# ----------------------------- # 1. 确保是 tensor
# 主对齐函数(逐字) input_values = torch.tensor(speech)
# -----------------------------
def align(self, audio_path: str, text: str):
"""
返回逐字对齐结果
[
{"char": "", "start": 0.1, "end": 0.2},
...
]
"""
speech = self.load_audio(audio_path) # 2. 去掉多余维度(🔥关键)
logits = self.get_logits(speech) input_values = input_values.squeeze()
# 中文/多语言 → 强制逐字 # 3. 如果变成 1D → 加 batch
text = text.replace(" ", "") if input_values.dim() == 1:
chars = list(text) input_values = input_values.unsqueeze(0)
config = CtcSegmentationParameters() # 4. 强制 shape = [B, T]
config.char_list = self.labels if input_values.dim() != 2:
raise ValueError(f"Invalid input shape: {input_values.shape}")
ground_truth_mat, utt_begin_indices = prepare_text( # 5. move device + dtype
config, [chars] input_values = input_values.to(self.device).to(self.dtype)
)
timings, char_probs, state_list = ctc_segmentation( # 6. inference
config, with torch.no_grad():
logits, logits = self.model(input_values).logits
ground_truth_mat
)
# 每帧时间 logits = logits.detach().cpu().numpy()
audio_duration = speech.shape[-1] / self.sample_rate if logits.ndim == 3:
frame_duration = audio_duration / len(timings) logits = logits[0]
return logits
results = [] # -----------------------------
# 主对齐函数(逐字)
# -----------------------------
def align(self, audio_path: str, text: str):
"""
返回逐字对齐结果
[
{"char": "", "start": 0.1, "end": 0.2},
...
]
"""
base = utt_begin_indices[0] speech = self.load_audio(audio_path)
logits = self.get_logits(speech)
for i, c in enumerate(chars): # 中文/多语言 → 强制逐字
start = timings[base + i] text = text.replace(" ", "")
end = timings[base + i + 1] chars = list(text)
results.append({ config = CtcSegmentationParameters()
"char": c, config.char_list = self.labels
"start": float(start * frame_duration),
"end": float(end * frame_duration),
"prob": float(char_probs[base + i])
})
return results ground_truth_mat, utt_begin_indices = prepare_text(
config, [chars]
)
timings, char_probs, state_list = ctc_segmentation(
config,
logits,
ground_truth_mat
)
# 每帧时间
audio_duration = speech.shape[-1] / self.sample_rate
frame_duration = audio_duration / len(timings)
results = []
T = min(len(chars), len(char_probs))
for i in range(T):
c = chars[i]
start = (i / T) * audio_duration
end = ((i + 1) / T) * audio_duration
results.append({
"char": c,
"start": start,
"end": end,
"prob": float(char_probs[i])
})
return results

View File

@ -5,53 +5,64 @@ from ahserver.filestorage import FileStorage
from appPublic.worker import awaitify from appPublic.worker import awaitify
from appPublic.jsonConfig import getConfig from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
async def align(request, *params, params_kw={}): async def lyric_align(request, params_kw, *params, **kw):
audio_webpath = params_kw.audio_path audio_webpath = params_kw.audio_path
text = params_kw.text text = params_kw.text
if audio_webpath is None: debug(f'{params_kw=}')
return { if audio_webpath is None:
"status": "error", exception(f'{params_kw=}')
"data": { return {
"message": "audio_path is None" "status": "error",
} "data": {
} "message": "audio_path is None"
if text is None: }
return { }
"status": "error", if text is None:
"data": { exception(f'{params_kw=}')
"message": "text is None" return {
} "status": "error",
} "data": {
env = ServerEnv() "message": "text is None"
fs = FileStorage() }
audio_path = fs.realPath(audio_webpath) }
align = awaitfy(env.align_engine.align) env = ServerEnv()
s = await align(audio_path, text) fs = FileStorage()
lines = text.split('\n') audio_path = fs.realPath(audio_webpath)
c_pos = 0 align = awaitify(env.align_engine.align)
sentences = [] text = text.replace(" ", "").replace('\t', "")
for l in lines: text1 = text.replace("\n", "")
if l: s = await align(audio_path, text1)
segment={ debug(f'{s=}, {text1=},{text=}')
'sentence': l, lines = text.split('\n')
'start': s[c_pos]['start'] c_pos = 0
'chars':[] c_max = len(s) - 1
} sentences = []
for c in l: for l in lines:
c_pos += 1 if l:
segment['chars'].append(s[c_pos]) segment={
'sentence': l,
'start': s[c_pos]['start'],
'chars':[]
}
for c in l:
if c_pos <= c_max:
segment['chars'].append(s[c_pos])
else:
debug(f'{l=}, {c=}, {c_max=}')
c_pos += 1
segment['end'] = s[c_pos -1]['end'] segment['end'] = s[c_pos -1]['end']
sentences.append(segment) sentences.append(segment)
return sentences return sentences
def init(): def init():
rf = RegisterFunction() rf = RegisterFunction()
rf.register('align', align) rf.register('align', lyric_align)
env = ServerEnv() env = ServerEnv()
config = getConfig() config = getConfig()
env.align_engine = AlignEngine(config.align_model) env.align_engine = AlignEngine(config.align_model)
if __name_ == '__main__': if __name__ == '__main__':
webapp(init) webapp(init)