bugfix
This commit is contained in:
parent
29ace2ac85
commit
6f6e32e80f
233
app/align.py
233
app/align.py
@ -4,130 +4,171 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers import AutoProcessor, Wav2Vec2ForCTC
|
from transformers import AutoProcessor, Wav2Vec2ForCTC
|
||||||
from ctc_segmentation import (
|
from ctc_segmentation import (
|
||||||
ctc_segmentation,
|
ctc_segmentation,
|
||||||
CtcSegmentationParameters,
|
CtcSegmentationParameters,
|
||||||
prepare_text
|
prepare_text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AlignEngine:
|
class AlignEngine:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: str = "float16"
|
dtype: str = "float32"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
model_path: 本地模型路径(例如 /models/mms-aligner)
|
model_path: 本地模型路径
|
||||||
device: cuda / cpu
|
device: cuda / cpu
|
||||||
dtype: float16 / float32
|
dtype: float16 / float32
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.device = device
|
# ---------------------------
|
||||||
|
# 1. device 自动兜底
|
||||||
|
# ---------------------------
|
||||||
|
if device == "cuda" and not torch.cuda.is_available():
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
# dtype处理
|
self.device = torch.device(device)
|
||||||
if dtype == "float16":
|
|
||||||
self.dtype = torch.float16
|
|
||||||
else:
|
|
||||||
self.dtype = torch.float32
|
|
||||||
|
|
||||||
# 加载processor + model(本地)
|
# ---------------------------
|
||||||
self.processor = AutoProcessor.from_pretrained(model_path)
|
# 2. dtype 统一管理
|
||||||
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
# ---------------------------
|
||||||
|
if self.device.type == "cpu":
|
||||||
|
# CPU 强制 fp32(避免 half kernel 问题)
|
||||||
|
self.dtype = torch.float32
|
||||||
|
else:
|
||||||
|
if dtype == "float16":
|
||||||
|
self.dtype = torch.float16
|
||||||
|
else:
|
||||||
|
self.dtype = torch.float32
|
||||||
|
|
||||||
self.model.to(self.device)
|
# ---------------------------
|
||||||
self.model.eval()
|
# 3. 加载 processor(不需要 dtype)
|
||||||
|
# ---------------------------
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||||
|
|
||||||
if self.device == "cuda":
|
# ---------------------------
|
||||||
self.model = self.model.half()
|
# 4. 加载 model(关键修复点)
|
||||||
|
# ---------------------------
|
||||||
|
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
||||||
|
|
||||||
# vocab缓存(避免重复计算)
|
# ⭐ 正确方式:先上 device,再统一 dtype
|
||||||
vocab = self.processor.tokenizer.get_vocab()
|
self.model = self.model.to(self.device)
|
||||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
|
||||||
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
|
||||||
|
|
||||||
self.sample_rate = 16000
|
# ⚠️ CPU 不允许 half
|
||||||
|
if self.device.type == "cuda":
|
||||||
|
self.model = self.model.to(dtype=self.dtype)
|
||||||
|
|
||||||
# -----------------------------
|
self.model.eval()
|
||||||
# 音频加载
|
|
||||||
# -----------------------------
|
|
||||||
def load_audio(self, audio_path):
|
|
||||||
speech, sr = torchaudio.load(audio_path)
|
|
||||||
|
|
||||||
if sr != self.sample_rate:
|
# ---------------------------
|
||||||
speech = torchaudio.functional.resample(
|
# 5. vocab cache
|
||||||
speech, sr, self.sample_rate
|
# ---------------------------
|
||||||
)
|
vocab = self.processor.tokenizer.get_vocab()
|
||||||
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
|
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
||||||
|
|
||||||
return speech.squeeze()
|
# ---------------------------
|
||||||
|
# 6. audio config
|
||||||
|
# ---------------------------
|
||||||
|
self.sample_rate = 16000
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# logits计算
|
# 音频加载
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
def get_logits(self, speech):
|
def load_audio(self, audio_path):
|
||||||
inputs = self.processor(
|
speech, sr = torchaudio.load(audio_path)
|
||||||
speech,
|
|
||||||
sampling_rate=self.sample_rate,
|
|
||||||
return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
input_values = inputs.input_values.to(self.device)
|
if sr != self.sample_rate:
|
||||||
|
speech = torchaudio.functional.resample(
|
||||||
|
speech, sr, self.sample_rate
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
return speech.squeeze()
|
||||||
logits = self.model(input_values).logits
|
|
||||||
|
|
||||||
return logits[0].detach().cpu().numpy()
|
# -----------------------------
|
||||||
|
# logits计算
|
||||||
|
# -----------------------------
|
||||||
|
def get_logits(self, speech):
|
||||||
|
|
||||||
# -----------------------------
|
# 1. 确保是 tensor
|
||||||
# 主对齐函数(逐字)
|
input_values = torch.tensor(speech)
|
||||||
# -----------------------------
|
|
||||||
def align(self, audio_path: str, text: str):
|
|
||||||
"""
|
|
||||||
返回逐字对齐结果:
|
|
||||||
[
|
|
||||||
{"char": "你", "start": 0.1, "end": 0.2},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
speech = self.load_audio(audio_path)
|
# 2. 去掉多余维度(🔥关键)
|
||||||
logits = self.get_logits(speech)
|
input_values = input_values.squeeze()
|
||||||
|
|
||||||
# 中文/多语言 → 强制逐字
|
# 3. 如果变成 1D → 加 batch
|
||||||
text = text.replace(" ", "")
|
if input_values.dim() == 1:
|
||||||
chars = list(text)
|
input_values = input_values.unsqueeze(0)
|
||||||
|
|
||||||
config = CtcSegmentationParameters()
|
# 4. 强制 shape = [B, T]
|
||||||
config.char_list = self.labels
|
if input_values.dim() != 2:
|
||||||
|
raise ValueError(f"Invalid input shape: {input_values.shape}")
|
||||||
|
|
||||||
ground_truth_mat, utt_begin_indices = prepare_text(
|
# 5. move device + dtype
|
||||||
config, [chars]
|
input_values = input_values.to(self.device).to(self.dtype)
|
||||||
)
|
|
||||||
|
|
||||||
timings, char_probs, state_list = ctc_segmentation(
|
# 6. inference
|
||||||
config,
|
with torch.no_grad():
|
||||||
logits,
|
logits = self.model(input_values).logits
|
||||||
ground_truth_mat
|
|
||||||
)
|
|
||||||
|
|
||||||
# 每帧时间
|
logits = logits.detach().cpu().numpy()
|
||||||
audio_duration = speech.shape[-1] / self.sample_rate
|
if logits.ndim == 3:
|
||||||
frame_duration = audio_duration / len(timings)
|
logits = logits[0]
|
||||||
|
return logits
|
||||||
|
|
||||||
results = []
|
# -----------------------------
|
||||||
|
# 主对齐函数(逐字)
|
||||||
|
# -----------------------------
|
||||||
|
def align(self, audio_path: str, text: str):
|
||||||
|
"""
|
||||||
|
返回逐字对齐结果:
|
||||||
|
[
|
||||||
|
{"char": "你", "start": 0.1, "end": 0.2},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
base = utt_begin_indices[0]
|
speech = self.load_audio(audio_path)
|
||||||
|
logits = self.get_logits(speech)
|
||||||
|
|
||||||
for i, c in enumerate(chars):
|
# 中文/多语言 → 强制逐字
|
||||||
start = timings[base + i]
|
text = text.replace(" ", "")
|
||||||
end = timings[base + i + 1]
|
chars = list(text)
|
||||||
|
|
||||||
results.append({
|
config = CtcSegmentationParameters()
|
||||||
"char": c,
|
config.char_list = self.labels
|
||||||
"start": float(start * frame_duration),
|
|
||||||
"end": float(end * frame_duration),
|
|
||||||
"prob": float(char_probs[base + i])
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
ground_truth_mat, utt_begin_indices = prepare_text(
|
||||||
|
config, [chars]
|
||||||
|
)
|
||||||
|
|
||||||
|
timings, char_probs, state_list = ctc_segmentation(
|
||||||
|
config,
|
||||||
|
logits,
|
||||||
|
ground_truth_mat
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每帧时间
|
||||||
|
audio_duration = speech.shape[-1] / self.sample_rate
|
||||||
|
frame_duration = audio_duration / len(timings)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
T = min(len(chars), len(char_probs))
|
||||||
|
|
||||||
|
for i in range(T):
|
||||||
|
c = chars[i]
|
||||||
|
|
||||||
|
start = (i / T) * audio_duration
|
||||||
|
end = ((i + 1) / T) * audio_duration
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"char": c,
|
||||||
|
"start": start,
|
||||||
|
"end": end,
|
||||||
|
"prob": float(char_probs[i])
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|||||||
101
app/aligner.py
101
app/aligner.py
@ -5,53 +5,64 @@ from ahserver.filestorage import FileStorage
|
|||||||
from appPublic.worker import awaitify
|
from appPublic.worker import awaitify
|
||||||
from appPublic.jsonConfig import getConfig
|
from appPublic.jsonConfig import getConfig
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.log import debug, exception
|
||||||
|
|
||||||
async def align(request, *params, params_kw={}):
|
async def lyric_align(request, params_kw, *params, **kw):
|
||||||
audio_webpath = params_kw.audio_path
|
audio_webpath = params_kw.audio_path
|
||||||
text = params_kw.text
|
text = params_kw.text
|
||||||
if audio_webpath is None:
|
debug(f'{params_kw=}')
|
||||||
return {
|
if audio_webpath is None:
|
||||||
"status": "error",
|
exception(f'{params_kw=}')
|
||||||
"data": {
|
return {
|
||||||
"message": "audio_path is None"
|
"status": "error",
|
||||||
}
|
"data": {
|
||||||
}
|
"message": "audio_path is None"
|
||||||
if text is None:
|
}
|
||||||
return {
|
}
|
||||||
"status": "error",
|
if text is None:
|
||||||
"data": {
|
exception(f'{params_kw=}')
|
||||||
"message": "text is None"
|
return {
|
||||||
}
|
"status": "error",
|
||||||
}
|
"data": {
|
||||||
env = ServerEnv()
|
"message": "text is None"
|
||||||
fs = FileStorage()
|
}
|
||||||
audio_path = fs.realPath(audio_webpath)
|
}
|
||||||
align = awaitfy(env.align_engine.align)
|
env = ServerEnv()
|
||||||
s = await align(audio_path, text)
|
fs = FileStorage()
|
||||||
lines = text.split('\n')
|
audio_path = fs.realPath(audio_webpath)
|
||||||
c_pos = 0
|
align = awaitify(env.align_engine.align)
|
||||||
sentences = []
|
text = text.replace(" ", "").replace('\t', "")
|
||||||
for l in lines:
|
text1 = text.replace("\n", "")
|
||||||
if l:
|
s = await align(audio_path, text1)
|
||||||
segment={
|
debug(f'{s=}, {text1=},{text=}')
|
||||||
'sentence': l,
|
lines = text.split('\n')
|
||||||
'start': s[c_pos]['start']
|
c_pos = 0
|
||||||
'chars':[]
|
c_max = len(s) - 1
|
||||||
}
|
sentences = []
|
||||||
for c in l:
|
for l in lines:
|
||||||
c_pos += 1
|
if l:
|
||||||
segment['chars'].append(s[c_pos])
|
segment={
|
||||||
|
'sentence': l,
|
||||||
|
'start': s[c_pos]['start'],
|
||||||
|
'chars':[]
|
||||||
|
}
|
||||||
|
for c in l:
|
||||||
|
if c_pos <= c_max:
|
||||||
|
segment['chars'].append(s[c_pos])
|
||||||
|
else:
|
||||||
|
debug(f'{l=}, {c=}, {c_max=}')
|
||||||
|
c_pos += 1
|
||||||
|
|
||||||
segment['end'] = s[c_pos -1]['end']
|
segment['end'] = s[c_pos -1]['end']
|
||||||
sentences.append(segment)
|
sentences.append(segment)
|
||||||
return sentences
|
return sentences
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
rf.register('align', align)
|
rf.register('align', lyric_align)
|
||||||
env = ServerEnv()
|
env = ServerEnv()
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
env.align_engine = AlignEngine(config.align_model)
|
env.align_engine = AlignEngine(config.align_model)
|
||||||
|
|
||||||
if __name_ == '__main__':
|
if __name__ == '__main__':
|
||||||
webapp(init)
|
webapp(init)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user