This commit is contained in:
yumoqing 2026-04-17 19:00:07 +08:00
parent 29ace2ac85
commit 6f6e32e80f
2 changed files with 193 additions and 141 deletions

View File

@ -16,37 +16,63 @@ class AlignEngine:
self, self,
model_path: str, model_path: str,
device: str = "cuda", device: str = "cuda",
dtype: str = "float16" dtype: str = "float32"
): ):
""" """
model_path: 本地模型路径例如 /models/mms-aligner model_path: 本地模型路径
device: cuda / cpu device: cuda / cpu
dtype: float16 / float32 dtype: float16 / float32
""" """
self.device = device # ---------------------------
# 1. device 自动兜底
# ---------------------------
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
# dtype处理 self.device = torch.device(device)
# ---------------------------
# 2. dtype 统一管理
# ---------------------------
if self.device.type == "cpu":
# CPU 强制 fp32避免 half kernel 问题)
self.dtype = torch.float32
else:
if dtype == "float16": if dtype == "float16":
self.dtype = torch.float16 self.dtype = torch.float16
else: else:
self.dtype = torch.float32 self.dtype = torch.float32
# 加载processor + model本地 # ---------------------------
# 3. 加载 processor不需要 dtype
# ---------------------------
self.processor = AutoProcessor.from_pretrained(model_path) self.processor = AutoProcessor.from_pretrained(model_path)
# ---------------------------
# 4. 加载 model关键修复点
# ---------------------------
self.model = Wav2Vec2ForCTC.from_pretrained(model_path) self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
self.model.to(self.device) # ⭐ 正确方式:先上 device再统一 dtype
self.model = self.model.to(self.device)
# ⚠️ CPU 不允许 half
if self.device.type == "cuda":
self.model = self.model.to(dtype=self.dtype)
self.model.eval() self.model.eval()
if self.device == "cuda": # ---------------------------
self.model = self.model.half() # 5. vocab cache
# ---------------------------
# vocab缓存避免重复计算
vocab = self.processor.tokenizer.get_vocab() vocab = self.processor.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()} self.inv_vocab = {v: k for k, v in vocab.items()}
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))] self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
# ---------------------------
# 6. audio config
# ---------------------------
self.sample_rate = 16000 self.sample_rate = 16000
# ----------------------------- # -----------------------------
@ -66,18 +92,32 @@ class AlignEngine:
# logits计算 # logits计算
# ----------------------------- # -----------------------------
def get_logits(self, speech): def get_logits(self, speech):
inputs = self.processor(
speech,
sampling_rate=self.sample_rate,
return_tensors="pt"
)
input_values = inputs.input_values.to(self.device) # 1. 确保是 tensor
input_values = torch.tensor(speech)
# 2. 去掉多余维度(🔥关键)
input_values = input_values.squeeze()
# 3. 如果变成 1D → 加 batch
if input_values.dim() == 1:
input_values = input_values.unsqueeze(0)
# 4. 强制 shape = [B, T]
if input_values.dim() != 2:
raise ValueError(f"Invalid input shape: {input_values.shape}")
# 5. move device + dtype
input_values = input_values.to(self.device).to(self.dtype)
# 6. inference
with torch.no_grad(): with torch.no_grad():
logits = self.model(input_values).logits logits = self.model(input_values).logits
return logits[0].detach().cpu().numpy() logits = logits.detach().cpu().numpy()
if logits.ndim == 3:
logits = logits[0]
return logits
# ----------------------------- # -----------------------------
# 主对齐函数(逐字) # 主对齐函数(逐字)
@ -117,17 +157,18 @@ class AlignEngine:
results = [] results = []
base = utt_begin_indices[0] T = min(len(chars), len(char_probs))
for i, c in enumerate(chars): for i in range(T):
start = timings[base + i] c = chars[i]
end = timings[base + i + 1]
start = (i / T) * audio_duration
end = ((i + 1) / T) * audio_duration
results.append({ results.append({
"char": c, "char": c,
"start": float(start * frame_duration), "start": start,
"end": float(end * frame_duration), "end": end,
"prob": float(char_probs[base + i]) "prob": float(char_probs[i])
}) })
return results return results

View File

@ -5,11 +5,14 @@ from ahserver.filestorage import FileStorage
from appPublic.worker import awaitify from appPublic.worker import awaitify
from appPublic.jsonConfig import getConfig from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
async def align(request, *params, params_kw={}): async def lyric_align(request, params_kw, *params, **kw):
audio_webpath = params_kw.audio_path audio_webpath = params_kw.audio_path
text = params_kw.text text = params_kw.text
debug(f'{params_kw=}')
if audio_webpath is None: if audio_webpath is None:
exception(f'{params_kw=}')
return { return {
"status": "error", "status": "error",
"data": { "data": {
@ -17,6 +20,7 @@ async def align(request, *params, params_kw={}):
} }
} }
if text is None: if text is None:
exception(f'{params_kw=}')
return { return {
"status": "error", "status": "error",
"data": { "data": {
@ -26,21 +30,28 @@ async def align(request, *params, params_kw={}):
env = ServerEnv() env = ServerEnv()
fs = FileStorage() fs = FileStorage()
audio_path = fs.realPath(audio_webpath) audio_path = fs.realPath(audio_webpath)
align = awaitfy(env.align_engine.align) align = awaitify(env.align_engine.align)
s = await align(audio_path, text) text = text.replace(" ", "").replace('\t', "")
text1 = text.replace("\n", "")
s = await align(audio_path, text1)
debug(f'{s=}, {text1=},{text=}')
lines = text.split('\n') lines = text.split('\n')
c_pos = 0 c_pos = 0
c_max = len(s) - 1
sentences = [] sentences = []
for l in lines: for l in lines:
if l: if l:
segment={ segment={
'sentence': l, 'sentence': l,
'start': s[c_pos]['start'] 'start': s[c_pos]['start'],
'chars':[] 'chars':[]
} }
for c in l: for c in l:
c_pos += 1 if c_pos <= c_max:
segment['chars'].append(s[c_pos]) segment['chars'].append(s[c_pos])
else:
debug(f'{l=}, {c=}, {c_max=}')
c_pos += 1
segment['end'] = s[c_pos -1]['end'] segment['end'] = s[c_pos -1]['end']
sentences.append(segment) sentences.append(segment)
@ -48,10 +59,10 @@ async def align(request, *params, params_kw={}):
def init(): def init():
rf = RegisterFunction() rf = RegisterFunction()
rf.register('align', align) rf.register('align', lyric_align)
env = ServerEnv() env = ServerEnv()
config = getConfig() config = getConfig()
env.align_engine = AlignEngine(config.align_model) env.align_engine = AlignEngine(config.align_model)
if __name_ == '__main__': if __name__ == '__main__':
webapp(init) webapp(init)