yumoqing b9728f9bf8 refactor: 改为 PyTorch GPU 加速版本
- librosa 替换为 torch + torchaudio
- 音频直接加载到 GPU
- 预计算 STFT 共享给所有分析器(避免重复计算)
- 单首歌评估: ~200MB 显存, ~2秒 (4090)
- 评估完成自动释放 GPU 显存
2026-06-03 18:13:52 +08:00

70 lines
2.1 KiB
Python

"""Audio quality analysis - GPU (PyTorch) version.

Scores a waveform tensor on estimated SNR, sample clipping and spectral
flatness (see ``analyze_quality``).
"""
import torch
from . import safe_float
def _estimate_snr_db(y, frame_length=2048, hop_length=512):
    """Estimate the signal-to-noise ratio of ``y`` in dB.

    Frames ``y`` with a sliding window and takes the per-frame RMS. It then
    compares the mean RMS of the loudest 10% of frames (signal level) with
    that of the quietest 10% (noise floor). Returns 30.0 when there are 10 or
    fewer frames, and 60.0 when the noise floor is exactly zero.
    """
    n_frames = (len(y) - frame_length) // hop_length + 1
    if n_frames <= 10:
        return 30.0
    # (n_frames, frame_length) strided view of y: one row per analysis frame.
    frames = y.unfold(0, frame_length, hop_length)
    rms = frames.pow(2).mean(dim=1).sqrt()
    rms_sorted, _ = torch.sort(rms)
    # Use the same frame count at both ends. Do not write `[-n // 10:]`: it
    # parses as `(-n) // 10`, which rounds toward -inf and pulls one extra
    # frame into the loud decile whenever n is not a multiple of 10.
    k = rms_sorted.numel() // 10  # >= 1 because n_frames > 10
    noise_floor = rms_sorted[:k].mean()
    signal_level = rms_sorted[-k:].mean()
    if noise_floor > 0:
        return safe_float(20 * torch.log10(signal_level / noise_floor))
    return 60.0


def _clipping_score(y):
    """Score 0-10 from the fraction of samples whose magnitude exceeds 0.99."""
    n_samples = len(y)
    if n_samples == 0:
        # No samples means nothing is clipped. This also avoids a
        # ZeroDivisionError, in line with the other "not enough data" defaults.
        return 10.0
    clipped = (torch.abs(y) > 0.99).sum().item() / n_samples
    if clipped < 0.001:
        return 10.0
    # NOTE(review): whenever clipped >= 0.001, `10 - clipped * 10000` is <= 0.
    # The linear ramp therefore never applies and the score is effectively
    # binary (10 or 0). It is kept as-is to preserve existing scores; confirm
    # the intended slope.
    return max(0.0, 10.0 - clipped * 10000)


def _frequency_balance_score(stft_result):
    """Map the mean spectral flatness of ``stft_result`` to a 0-10 score.

    Without an STFT, a flatness of 0.05 is assumed, which scores 8.0.
    """
    if stft_result is None:
        flatness_mean = 0.05
    else:
        # Assumes stft_result is (freq, time), e.g. torch.stft output for a
        # mono signal. TODO: confirm against the caller that computes it.
        magnitude = stft_result.abs()
        # Flatness per frame = geometric mean / arithmetic mean over frequency.
        geometric_mean = torch.exp(
            torch.log(magnitude.clamp(min=1e-10)).mean(dim=0)
        )
        arithmetic_mean = magnitude.mean(dim=0)
        flatness = geometric_mean / arithmetic_mean.clamp(min=1e-10)
        flatness_mean = safe_float(flatness.mean())
    if 0.001 <= flatness_mean <= 0.1:
        return 8.0
    if flatness_mean < 0.001:
        return 6.0
    return 5.0


def _snr_to_score(snr_db):
    """Map SNR (dB) to 0-10: >=40 -> 10, 20..40 -> 6..10 linear, <20 -> 0..6."""
    if snr_db >= 40:
        return 10.0
    if snr_db >= 20:
        return 6.0 + (snr_db - 20) / 20 * 4.0
    return max(snr_db / 20 * 6.0, 0)


def analyze_quality(y, sr, stft_result=None):
    """Analyze audio-quality dimensions using torch ops on ``y``'s device.

    Args:
        y: 1-D waveform tensor. This presumably is a mono float signal in
            roughly [-1, 1], since the clipping check uses |y| > 0.99.
            TODO: confirm against callers.
        sr: Sample rate. Unused here; presumably kept so all analyzers share
            one signature.
        stft_result: Optional precomputed complex STFT of ``y``, shared so it
            is not recomputed. When None, frequency balance uses a default.

    Returns:
        dict with "snr" (dB), "clipping" (0-10), "frequency_balance" (0-10)
        and the weighted overall "score" (0-10), each rounded to 2 decimals.
    """
    snr_db = _estimate_snr_db(y)
    scores = {
        "snr": round(snr_db, 2),
        "clipping": round(_clipping_score(y), 2),
        "frequency_balance": round(_frequency_balance_score(stft_result), 2),
    }
    # Weighted blend: SNR 40%, clipping 35%, frequency balance 25%.
    score = (
        0.40 * _snr_to_score(snr_db)
        + 0.35 * scores["clipping"]
        + 0.25 * scores["frequency_balance"]
    )
    scores["score"] = round(min(score, 10), 2)
    return scores