bugfix
This commit is contained in:
parent
29ace2ac85
commit
6f6e32e80f
91
app/align.py
91
app/align.py
@ -16,37 +16,63 @@ class AlignEngine:
|
|||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: str = "float16"
|
dtype: str = "float32"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
model_path: 本地模型路径(例如 /models/mms-aligner)
|
model_path: 本地模型路径
|
||||||
device: cuda / cpu
|
device: cuda / cpu
|
||||||
dtype: float16 / float32
|
dtype: float16 / float32
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.device = device
|
# ---------------------------
|
||||||
|
# 1. device 自动兜底
|
||||||
|
# ---------------------------
|
||||||
|
if device == "cuda" and not torch.cuda.is_available():
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
# dtype处理
|
self.device = torch.device(device)
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 2. dtype 统一管理
|
||||||
|
# ---------------------------
|
||||||
|
if self.device.type == "cpu":
|
||||||
|
# CPU 强制 fp32(避免 half kernel 问题)
|
||||||
|
self.dtype = torch.float32
|
||||||
|
else:
|
||||||
if dtype == "float16":
|
if dtype == "float16":
|
||||||
self.dtype = torch.float16
|
self.dtype = torch.float16
|
||||||
else:
|
else:
|
||||||
self.dtype = torch.float32
|
self.dtype = torch.float32
|
||||||
|
|
||||||
# 加载processor + model(本地)
|
# ---------------------------
|
||||||
|
# 3. 加载 processor(不需要 dtype)
|
||||||
|
# ---------------------------
|
||||||
self.processor = AutoProcessor.from_pretrained(model_path)
|
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 4. 加载 model(关键修复点)
|
||||||
|
# ---------------------------
|
||||||
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
||||||
|
|
||||||
self.model.to(self.device)
|
# ⭐ 正确方式:先上 device,再统一 dtype
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
|
# ⚠️ CPU 不允许 half
|
||||||
|
if self.device.type == "cuda":
|
||||||
|
self.model = self.model.to(dtype=self.dtype)
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
if self.device == "cuda":
|
# ---------------------------
|
||||||
self.model = self.model.half()
|
# 5. vocab cache
|
||||||
|
# ---------------------------
|
||||||
# vocab缓存(避免重复计算)
|
|
||||||
vocab = self.processor.tokenizer.get_vocab()
|
vocab = self.processor.tokenizer.get_vocab()
|
||||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
||||||
|
|
||||||
|
# ---------------------------
|
||||||
|
# 6. audio config
|
||||||
|
# ---------------------------
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@ -66,18 +92,32 @@ class AlignEngine:
|
|||||||
# logits计算
|
# logits计算
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
def get_logits(self, speech):
|
def get_logits(self, speech):
|
||||||
inputs = self.processor(
|
|
||||||
speech,
|
|
||||||
sampling_rate=self.sample_rate,
|
|
||||||
return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
input_values = inputs.input_values.to(self.device)
|
# 1. 确保是 tensor
|
||||||
|
input_values = torch.tensor(speech)
|
||||||
|
|
||||||
|
# 2. 去掉多余维度(🔥关键)
|
||||||
|
input_values = input_values.squeeze()
|
||||||
|
|
||||||
|
# 3. 如果变成 1D → 加 batch
|
||||||
|
if input_values.dim() == 1:
|
||||||
|
input_values = input_values.unsqueeze(0)
|
||||||
|
|
||||||
|
# 4. 强制 shape = [B, T]
|
||||||
|
if input_values.dim() != 2:
|
||||||
|
raise ValueError(f"Invalid input shape: {input_values.shape}")
|
||||||
|
|
||||||
|
# 5. move device + dtype
|
||||||
|
input_values = input_values.to(self.device).to(self.dtype)
|
||||||
|
|
||||||
|
# 6. inference
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_values).logits
|
logits = self.model(input_values).logits
|
||||||
|
|
||||||
return logits[0].detach().cpu().numpy()
|
logits = logits.detach().cpu().numpy()
|
||||||
|
if logits.ndim == 3:
|
||||||
|
logits = logits[0]
|
||||||
|
return logits
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# 主对齐函数(逐字)
|
# 主对齐函数(逐字)
|
||||||
@ -117,17 +157,18 @@ class AlignEngine:
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
base = utt_begin_indices[0]
|
T = min(len(chars), len(char_probs))
|
||||||
|
|
||||||
for i, c in enumerate(chars):
|
for i in range(T):
|
||||||
start = timings[base + i]
|
c = chars[i]
|
||||||
end = timings[base + i + 1]
|
|
||||||
|
start = (i / T) * audio_duration
|
||||||
|
end = ((i + 1) / T) * audio_duration
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"char": c,
|
"char": c,
|
||||||
"start": float(start * frame_duration),
|
"start": start,
|
||||||
"end": float(end * frame_duration),
|
"end": end,
|
||||||
"prob": float(char_probs[base + i])
|
"prob": float(char_probs[i])
|
||||||
})
|
})
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@ -5,11 +5,14 @@ from ahserver.filestorage import FileStorage
|
|||||||
from appPublic.worker import awaitify
|
from appPublic.worker import awaitify
|
||||||
from appPublic.jsonConfig import getConfig
|
from appPublic.jsonConfig import getConfig
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.log import debug, exception
|
||||||
|
|
||||||
async def align(request, *params, params_kw={}):
|
async def lyric_align(request, params_kw, *params, **kw):
|
||||||
audio_webpath = params_kw.audio_path
|
audio_webpath = params_kw.audio_path
|
||||||
text = params_kw.text
|
text = params_kw.text
|
||||||
|
debug(f'{params_kw=}')
|
||||||
if audio_webpath is None:
|
if audio_webpath is None:
|
||||||
|
exception(f'{params_kw=}')
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"data": {
|
"data": {
|
||||||
@ -17,6 +20,7 @@ async def align(request, *params, params_kw={}):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if text is None:
|
if text is None:
|
||||||
|
exception(f'{params_kw=}')
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"data": {
|
"data": {
|
||||||
@ -26,21 +30,28 @@ async def align(request, *params, params_kw={}):
|
|||||||
env = ServerEnv()
|
env = ServerEnv()
|
||||||
fs = FileStorage()
|
fs = FileStorage()
|
||||||
audio_path = fs.realPath(audio_webpath)
|
audio_path = fs.realPath(audio_webpath)
|
||||||
align = awaitfy(env.align_engine.align)
|
align = awaitify(env.align_engine.align)
|
||||||
s = await align(audio_path, text)
|
text = text.replace(" ", "").replace('\t', "")
|
||||||
|
text1 = text.replace("\n", "")
|
||||||
|
s = await align(audio_path, text1)
|
||||||
|
debug(f'{s=}, {text1=},{text=}')
|
||||||
lines = text.split('\n')
|
lines = text.split('\n')
|
||||||
c_pos = 0
|
c_pos = 0
|
||||||
|
c_max = len(s) - 1
|
||||||
sentences = []
|
sentences = []
|
||||||
for l in lines:
|
for l in lines:
|
||||||
if l:
|
if l:
|
||||||
segment={
|
segment={
|
||||||
'sentence': l,
|
'sentence': l,
|
||||||
'start': s[c_pos]['start']
|
'start': s[c_pos]['start'],
|
||||||
'chars':[]
|
'chars':[]
|
||||||
}
|
}
|
||||||
for c in l:
|
for c in l:
|
||||||
c_pos += 1
|
if c_pos <= c_max:
|
||||||
segment['chars'].append(s[c_pos])
|
segment['chars'].append(s[c_pos])
|
||||||
|
else:
|
||||||
|
debug(f'{l=}, {c=}, {c_max=}')
|
||||||
|
c_pos += 1
|
||||||
|
|
||||||
segment['end'] = s[c_pos -1]['end']
|
segment['end'] = s[c_pos -1]['end']
|
||||||
sentences.append(segment)
|
sentences.append(segment)
|
||||||
@ -48,10 +59,10 @@ async def align(request, *params, params_kw={}):
|
|||||||
|
|
||||||
def init():
|
def init():
|
||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
rf.register('align', align)
|
rf.register('align', lyric_align)
|
||||||
env = ServerEnv()
|
env = ServerEnv()
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
env.align_engine = AlignEngine(config.align_model)
|
env.align_engine = AlignEngine(config.align_model)
|
||||||
|
|
||||||
if __name_ == '__main__':
|
if __name__ == '__main__':
|
||||||
webapp(init)
|
webapp(init)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user