first commit
This commit is contained in:
commit
ebe895f7fb
31
README.md
Normal file
31
README.md
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# 音频文本对齐服务
|
||||||
|
本服务部署在有GPU的主机上
|
||||||
|
|
||||||
|
## api
|
||||||
|
请求格式
|
||||||
|
```
|
||||||
|
curl -X POST https://server:port/align \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-F "text=音频中的文字" \
|
||||||
|
-F "audio_file=@/path/to/音频文件"
|
||||||
|
```
|
||||||
|
|
||||||
|
输出:
|
||||||
|
```
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sentence": "世界啊你好",
|
||||||
|
"start": 0.123,
|
||||||
|
"end": 1.45,
|
||||||
|
"chars":[
|
||||||
|
{
|
||||||
|
"char": "世",
|
||||||
|
"start":0.123,
|
||||||
|
"end": 0.543
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
133
app/align.py
Normal file
133
app/align.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, Wav2Vec2ForCTC
|
||||||
|
from ctc_segmentation import (
|
||||||
|
ctc_segmentation,
|
||||||
|
CtcSegmentationParameters,
|
||||||
|
prepare_text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AlignEngine:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: str = "float16"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
model_path: 本地模型路径(例如 /models/mms-aligner)
|
||||||
|
device: cuda / cpu
|
||||||
|
dtype: float16 / float32
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# dtype处理
|
||||||
|
if dtype == "float16":
|
||||||
|
self.dtype = torch.float16
|
||||||
|
else:
|
||||||
|
self.dtype = torch.float32
|
||||||
|
|
||||||
|
# 加载processor + model(本地)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||||
|
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
||||||
|
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
if self.device == "cuda":
|
||||||
|
self.model = self.model.half()
|
||||||
|
|
||||||
|
# vocab缓存(避免重复计算)
|
||||||
|
vocab = self.processor.tokenizer.get_vocab()
|
||||||
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
|
self.labels = [self.inv_vocab[i] for i in range(len(self.inv_vocab))]
|
||||||
|
|
||||||
|
self.sample_rate = 16000
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 音频加载
|
||||||
|
# -----------------------------
|
||||||
|
def load_audio(self, audio_path):
|
||||||
|
speech, sr = torchaudio.load(audio_path)
|
||||||
|
|
||||||
|
if sr != self.sample_rate:
|
||||||
|
speech = torchaudio.functional.resample(
|
||||||
|
speech, sr, self.sample_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
return speech.squeeze()
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# logits计算
|
||||||
|
# -----------------------------
|
||||||
|
def get_logits(self, speech):
|
||||||
|
inputs = self.processor(
|
||||||
|
speech,
|
||||||
|
sampling_rate=self.sample_rate,
|
||||||
|
return_tensors="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
input_values = inputs.input_values.to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.model(input_values).logits
|
||||||
|
|
||||||
|
return logits[0].detach().cpu().numpy()
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 主对齐函数(逐字)
|
||||||
|
# -----------------------------
|
||||||
|
def align(self, audio_path: str, text: str):
|
||||||
|
"""
|
||||||
|
返回逐字对齐结果:
|
||||||
|
[
|
||||||
|
{"char": "你", "start": 0.1, "end": 0.2},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
speech = self.load_audio(audio_path)
|
||||||
|
logits = self.get_logits(speech)
|
||||||
|
|
||||||
|
# 中文/多语言 → 强制逐字
|
||||||
|
text = text.replace(" ", "")
|
||||||
|
chars = list(text)
|
||||||
|
|
||||||
|
config = CtcSegmentationParameters()
|
||||||
|
config.char_list = self.labels
|
||||||
|
|
||||||
|
ground_truth_mat, utt_begin_indices = prepare_text(
|
||||||
|
config, [chars]
|
||||||
|
)
|
||||||
|
|
||||||
|
timings, char_probs, state_list = ctc_segmentation(
|
||||||
|
config,
|
||||||
|
logits,
|
||||||
|
ground_truth_mat
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每帧时间
|
||||||
|
audio_duration = speech.shape[-1] / self.sample_rate
|
||||||
|
frame_duration = audio_duration / len(timings)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
base = utt_begin_indices[0]
|
||||||
|
|
||||||
|
for i, c in enumerate(chars):
|
||||||
|
start = timings[base + i]
|
||||||
|
end = timings[base + i + 1]
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"char": c,
|
||||||
|
"start": float(start * frame_duration),
|
||||||
|
"end": float(end * frame_duration),
|
||||||
|
"prob": float(char_probs[base + i])
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
39
app/aligner.py
Normal file
39
app/aligner.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from align import AlignEngine
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.webapp import webapp
|
||||||
|
from ahserver.filestorage import FileStorage
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.jsonConfig import getConfig
|
||||||
|
|
||||||
|
async def align(audio_webpath, text):
|
||||||
|
env = ServerEnv()
|
||||||
|
fs = FileStorage()
|
||||||
|
audio_path = fs.realPath(audio_webpath)
|
||||||
|
align = awaitfy(env.align_engine.align)
|
||||||
|
s = await align(audio_path, text)
|
||||||
|
lines = text.split('\n')
|
||||||
|
c_pos = 0
|
||||||
|
sentences = []
|
||||||
|
for l in lines:
|
||||||
|
if l:
|
||||||
|
segment={
|
||||||
|
'sentence': l,
|
||||||
|
'start': s[c_pos]['start']
|
||||||
|
'chars':[]
|
||||||
|
}
|
||||||
|
for c in l:
|
||||||
|
c_pos += 1
|
||||||
|
segment['chars'].append(s[c_pos])
|
||||||
|
|
||||||
|
segment['end'] = s[c_pos -1]['end']
|
||||||
|
sentences.append(segment)
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
def init():
|
||||||
|
env = ServerEnv()
|
||||||
|
config = getConfig()
|
||||||
|
env.align_engine = AlignEngine(config.align_model)
|
||||||
|
env.align = align
|
||||||
|
|
||||||
|
if __name_ == '__main__':
|
||||||
|
webapp(init)
|
||||||
36
conf/config.json
Normal file
36
conf/config.json
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
{
|
||||||
|
"align_model": "/data/ymq/models/MahmoudAshraf/mms-300m-1130-forced-aligner",
|
||||||
|
"website":{
|
||||||
|
"paths":[
|
||||||
|
["$[workdir]$/wwwroot",""]
|
||||||
|
],
|
||||||
|
"client_max_size":1000000000,
|
||||||
|
"host":"0.0.0.0",
|
||||||
|
"port":8080,
|
||||||
|
"coding":"utf-8",
|
||||||
|
"indexes":[
|
||||||
|
"index.html",
|
||||||
|
"index.ui",
|
||||||
|
"index.dspy"
|
||||||
|
],
|
||||||
|
"processors":[
|
||||||
|
[".proxy","proxy"],
|
||||||
|
[".tmpl.js","tmpl"],
|
||||||
|
[".tmpl.css","tmpl"],
|
||||||
|
[".html.tmpl","tmpl"],
|
||||||
|
[".tmpl","tmpl"],
|
||||||
|
[".app","app"],
|
||||||
|
[".ui","bui"],
|
||||||
|
[".dspy","dspy"]
|
||||||
|
],
|
||||||
|
"startswiths":[
|
||||||
|
{
|
||||||
|
"leading":"/idfile",
|
||||||
|
"registerfunction":"idfile"
|
||||||
|
},{
|
||||||
|
"leading":"/i18n_getmsgs",
|
||||||
|
"registerfunction":"i18n"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
0
files/README.md
Normal file
0
files/README.md
Normal file
0
logs/README.md
Normal file
0
logs/README.md
Normal file
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
torch
|
||||||
|
torchaudio
|
||||||
|
transformers
|
||||||
|
librosa
|
||||||
|
ctc-segmentation
|
||||||
Loading…
x
Reference in New Issue
Block a user