commit ebe895f7fb51539a7246cfcf595617a1b17345ae Author: yumoqing Date: Fri Apr 17 15:16:12 2026 +0800 first commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..ad537a2 --- /dev/null +++ b/README.md @@ -0,0 +1,31 @@ +# 音频文本对齐服务 +本服务部署在有GPU的主机上 + +## api +请求格式 +``` +curl -X POST https://server:port/align \ + -H "Content-Type: application/json" \ + -F "text=音频中的文字" \ + -F "audio_file=@/path/to/音频文件" +``` + +输出: +``` +[ + { + "sentence": "世界啊你好", + "start": 0.123, + "end": 1.45, + "chars":[ + { + "char": "世", + "start":0.123, + "end": 0.543 + }, + ... + ] + } + ... +] +``` diff --git a/app/align.py b/app/align.py new file mode 100644 index 0000000..d9923c8 --- /dev/null +++ b/app/align.py @@ -0,0 +1,133 @@ +import torch +import torchaudio +import numpy as np + +from transformers import AutoProcessor, Wav2Vec2ForCTC +from ctc_segmentation import ( + ctc_segmentation, + CtcSegmentationParameters, + prepare_text +) + + +class AlignEngine: + + def __init__( + self, + model_path: str, + device: str = "cuda", + dtype: str = "float16" + ): + """ + model_path: 本地模型路径(例如 /models/mms-aligner) + device: cuda / cpu + dtype: float16 / float32 + """ + + self.device = device + + # dtype处理 + if dtype == "float16": + self.dtype = torch.float16 + else: + self.dtype = torch.float32 + + # 加载processor + model(本地) + self.processor = AutoProcessor.from_pretrained(model_path) + self.model = Wav2Vec2ForCTC.from_pretrained(model_path) + + self.model.to(self.device) + self.model.eval() + + if self.device == "cuda": + self.model = self.model.half() + + # vocab缓存(避免重复计算) + vocab = self.processor.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))] + + self.sample_rate = 16000 + + # ----------------------------- + # 音频加载 + # ----------------------------- + def load_audio(self, audio_path): + speech, sr = torchaudio.load(audio_path) + + if sr != self.sample_rate: + speech = torchaudio.functional.resample( + speech, sr, self.sample_rate + ) + + return speech.squeeze() + + # ----------------------------- + # logits计算 + # ----------------------------- + def get_logits(self, speech): + inputs = self.processor( + speech, + sampling_rate=self.sample_rate, + return_tensors="pt" + ) + + input_values = inputs.input_values.to(self.device) + + with torch.no_grad(): + logits = self.model(input_values).logits + + return logits[0].detach().cpu().numpy() + + # ----------------------------- + # 主对齐函数(逐字) + # ----------------------------- + def align(self, audio_path: str, text: str): + """ + 返回逐字对齐结果: + [ + {"char": "你", "start": 0.1, "end": 0.2}, + ... + ] + """ + + speech = self.load_audio(audio_path) + logits = self.get_logits(speech) + + # 中文/多语言 → 强制逐字 + text = text.replace(" ", "") + chars = list(text) + + config = CtcSegmentationParameters() + config.char_list = self.labels + + ground_truth_mat, utt_begin_indices = prepare_text( + config, [chars] + ) + + timings, char_probs, state_list = ctc_segmentation( + config, + logits, + ground_truth_mat + ) + + # 每帧时间 + audio_duration = speech.shape[-1] / self.sample_rate + frame_duration = audio_duration / len(timings) + + results = [] + + base = utt_begin_indices[0] + + for i, c in enumerate(chars): + start = timings[base + i] + end = timings[base + i + 1] + + results.append({ + "char": c, + "start": float(start * frame_duration), + "end": float(end * frame_duration), + "prob": float(char_probs[base + i]) + }) + + return results diff --git a/app/aligner.py b/app/aligner.py new file mode 100644 index 0000000..e6458f4 --- /dev/null +++ b/app/aligner.py @@ -0,0 +1,39 @@ +from align import AlignEngine +from ahserver.serverenv import ServerEnv +from ahserver.webapp import webapp +from ahserver.filestorage import FileStorage +from appPublic.worker import awaitify +from appPublic.jsonConfig import getConfig + +async def align(audio_webpath, text): + env = ServerEnv() + fs = FileStorage() + audio_path = fs.realPath(audio_webpath) + align = awaitfy(env.align_engine.align) + s = await align(audio_path, text) + lines = text.split('\n') + c_pos = 0 + sentences = [] + for l in lines: + if l: + segment={ + 'sentence': l, + 'start': s[c_pos]['start'] + 'chars':[] + } + for c in l: + c_pos += 1 + segment['chars'].append(s[c_pos]) + + segment['end'] = s[c_pos -1]['end'] + sentences.append(segment) + return sentences + +def init(): + env = ServerEnv() + config = getConfig() + env.align_engine = AlignEngine(config.align_model) + env.align = align + +if __name_ == '__main__': + webapp(init) diff --git a/conf/config.json b/conf/config.json new file mode 100644 index 0000000..acbeec7 --- /dev/null +++ b/conf/config.json @@ -0,0 +1,36 @@ +{ + "align_model": "/data/ymq/models/MahmoudAshraf/mms-300m-1130-forced-aligner", + "website":{ + "paths":[ + ["$[workdir]$/wwwroot",""] + ], + "client_max_size":1000000000, + "host":"0.0.0.0", + "port":8080, + "coding":"utf-8", + "indexes":[ + "index.html", + "index.ui", + "index.dspy" + ], + "processors":[ + [".proxy","proxy"], + [".tmpl.js","tmpl"], + [".tmpl.css","tmpl"], + [".html.tmpl","tmpl"], + [".tmpl","tmpl"], + [".app","app"], + [".ui","bui"], + [".dspy","dspy"] + ], + "startswiths":[ + { + "leading":"/idfile", + "registerfunction":"idfile" + },{ + "leading":"/i18n_getmsgs", + "registerfunction":"i18n" + } + ] + } +} diff --git a/files/README.md b/files/README.md new file mode 100644 index 0000000..e69de29 diff --git a/logs/README.md b/logs/README.md new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..72e88c7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch +torchaudio +transformers +librosa +ctc-segmentation