bugfix
This commit is contained in:
parent
29ace2ac85
commit
6f6e32e80f
91
app/align.py
91
app/align.py
@ -16,37 +16,63 @@ class AlignEngine:
|
||||
self,
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
dtype: str = "float16"
|
||||
dtype: str = "float32"
|
||||
):
|
||||
"""
|
||||
model_path: 本地模型路径(例如 /models/mms-aligner)
|
||||
model_path: 本地模型路径
|
||||
device: cuda / cpu
|
||||
dtype: float16 / float32
|
||||
"""
|
||||
|
||||
self.device = device
|
||||
# ---------------------------
|
||||
# 1. device 自动兜底
|
||||
# ---------------------------
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
device = "cpu"
|
||||
|
||||
# dtype处理
|
||||
self.device = torch.device(device)
|
||||
|
||||
# ---------------------------
|
||||
# 2. dtype 统一管理
|
||||
# ---------------------------
|
||||
if self.device.type == "cpu":
|
||||
# CPU 强制 fp32(避免 half kernel 问题)
|
||||
self.dtype = torch.float32
|
||||
else:
|
||||
if dtype == "float16":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.float32
|
||||
|
||||
# 加载processor + model(本地)
|
||||
# ---------------------------
|
||||
# 3. 加载 processor(不需要 dtype)
|
||||
# ---------------------------
|
||||
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||
|
||||
# ---------------------------
|
||||
# 4. 加载 model(关键修复点)
|
||||
# ---------------------------
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
||||
|
||||
self.model.to(self.device)
|
||||
# ⭐ 正确方式:先上 device,再统一 dtype
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# ⚠️ CPU 不允许 half
|
||||
if self.device.type == "cuda":
|
||||
self.model = self.model.to(dtype=self.dtype)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
if self.device == "cuda":
|
||||
self.model = self.model.half()
|
||||
|
||||
# vocab缓存(避免重复计算)
|
||||
# ---------------------------
|
||||
# 5. vocab cache
|
||||
# ---------------------------
|
||||
vocab = self.processor.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
||||
|
||||
# ---------------------------
|
||||
# 6. audio config
|
||||
# ---------------------------
|
||||
self.sample_rate = 16000
|
||||
|
||||
# -----------------------------
|
||||
@ -66,18 +92,32 @@ class AlignEngine:
|
||||
# logits计算
|
||||
# -----------------------------
|
||||
def get_logits(self, speech):
|
||||
inputs = self.processor(
|
||||
speech,
|
||||
sampling_rate=self.sample_rate,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
input_values = inputs.input_values.to(self.device)
|
||||
# 1. 确保是 tensor
|
||||
input_values = torch.tensor(speech)
|
||||
|
||||
# 2. 去掉多余维度(🔥关键)
|
||||
input_values = input_values.squeeze()
|
||||
|
||||
# 3. 如果变成 1D → 加 batch
|
||||
if input_values.dim() == 1:
|
||||
input_values = input_values.unsqueeze(0)
|
||||
|
||||
# 4. 强制 shape = [B, T]
|
||||
if input_values.dim() != 2:
|
||||
raise ValueError(f"Invalid input shape: {input_values.shape}")
|
||||
|
||||
# 5. move device + dtype
|
||||
input_values = input_values.to(self.device).to(self.dtype)
|
||||
|
||||
# 6. inference
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_values).logits
|
||||
|
||||
return logits[0].detach().cpu().numpy()
|
||||
logits = logits.detach().cpu().numpy()
|
||||
if logits.ndim == 3:
|
||||
logits = logits[0]
|
||||
return logits
|
||||
|
||||
# -----------------------------
|
||||
# 主对齐函数(逐字)
|
||||
@ -117,17 +157,18 @@ class AlignEngine:
|
||||
|
||||
results = []
|
||||
|
||||
base = utt_begin_indices[0]
|
||||
T = min(len(chars), len(char_probs))
|
||||
|
||||
for i, c in enumerate(chars):
|
||||
start = timings[base + i]
|
||||
end = timings[base + i + 1]
|
||||
for i in range(T):
|
||||
c = chars[i]
|
||||
|
||||
start = (i / T) * audio_duration
|
||||
end = ((i + 1) / T) * audio_duration
|
||||
|
||||
results.append({
|
||||
"char": c,
|
||||
"start": float(start * frame_duration),
|
||||
"end": float(end * frame_duration),
|
||||
"prob": float(char_probs[base + i])
|
||||
"start": start,
|
||||
"end": end,
|
||||
"prob": float(char_probs[i])
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
@ -5,11 +5,14 @@ from ahserver.filestorage import FileStorage
|
||||
from appPublic.worker import awaitify
|
||||
from appPublic.jsonConfig import getConfig
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from appPublic.log import debug, exception
|
||||
|
||||
async def align(request, *params, params_kw={}):
|
||||
async def lyric_align(request, params_kw, *params, **kw):
|
||||
audio_webpath = params_kw.audio_path
|
||||
text = params_kw.text
|
||||
debug(f'{params_kw=}')
|
||||
if audio_webpath is None:
|
||||
exception(f'{params_kw=}')
|
||||
return {
|
||||
"status": "error",
|
||||
"data": {
|
||||
@ -17,6 +20,7 @@ async def align(request, *params, params_kw={}):
|
||||
}
|
||||
}
|
||||
if text is None:
|
||||
exception(f'{params_kw=}')
|
||||
return {
|
||||
"status": "error",
|
||||
"data": {
|
||||
@ -26,21 +30,28 @@ async def align(request, *params, params_kw={}):
|
||||
env = ServerEnv()
|
||||
fs = FileStorage()
|
||||
audio_path = fs.realPath(audio_webpath)
|
||||
align = awaitfy(env.align_engine.align)
|
||||
s = await align(audio_path, text)
|
||||
align = awaitify(env.align_engine.align)
|
||||
text = text.replace(" ", "").replace('\t', "")
|
||||
text1 = text.replace("\n", "")
|
||||
s = await align(audio_path, text1)
|
||||
debug(f'{s=}, {text1=},{text=}')
|
||||
lines = text.split('\n')
|
||||
c_pos = 0
|
||||
c_max = len(s) - 1
|
||||
sentences = []
|
||||
for l in lines:
|
||||
if l:
|
||||
segment={
|
||||
'sentence': l,
|
||||
'start': s[c_pos]['start']
|
||||
'start': s[c_pos]['start'],
|
||||
'chars':[]
|
||||
}
|
||||
for c in l:
|
||||
c_pos += 1
|
||||
if c_pos <= c_max:
|
||||
segment['chars'].append(s[c_pos])
|
||||
else:
|
||||
debug(f'{l=}, {c=}, {c_max=}')
|
||||
c_pos += 1
|
||||
|
||||
segment['end'] = s[c_pos -1]['end']
|
||||
sentences.append(segment)
|
||||
@ -48,10 +59,10 @@ async def align(request, *params, params_kw={}):
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('align', align)
|
||||
rf.register('align', lyric_align)
|
||||
env = ServerEnv()
|
||||
config = getConfig()
|
||||
env.align_engine = AlignEngine(config.align_model)
|
||||
|
||||
if __name_ == '__main__':
|
||||
if __name__ == '__main__':
|
||||
webapp(init)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user