返回

音频检测模型长音频评估优化:解决训练测试时长不匹配问题

Ai

解决音频检测模型在长音频评估中的问题:训练数据为 1.1 秒,测试数据更长

最近我在做大学项目的音频检测模型时碰到了一个问题。 模型是用 1.1 秒时长的音频训练的,在用窗口滑动方法处理较长的音频进行评估时,效果并不好,即便模型的测试准确率非常高,大部分时候能达到 97%。

学校给出的处理这类数据的方法如下:“最后一步是为接下来的挑战阶段做准备:定性评估你最好的分类器在 Moodle 上提供的 4 个场景中的性能。这些场景的时长为 6-24 秒,包含由设备关键词和动作关键词组成的完整语音命令(例如,“Radio aus”)。要将上一阶段的分类器应用于这些较长的录音,首先需要将它们切割成较短的 1.1 秒片段。为此,使用跳跃大小为 1 帧的滑动窗口来提取 44 帧的特征序列(44 个序列步长对应于 1.1 秒)。下图说明了从长度为 N 的梅尔频谱图中提取片段(红色窗口)的过程,其中跳跃大小为 h 帧,窗口大小为 w。”

目前我的代码如下:

import torchaudio
import torch
from collections import Counter

# 假设这些函数/类已定义
def normalize(tensor):
  #对张量在每个通道进行归一化。
    tensor_minusmean = tensor - tensor.mean()
    return tensor_minusmean/tensor_minusmean.abs().max()

class AudioUtil():
  def open(audio_file):
    sig, sr = torchaudio.load(audio_file)
    return (sig, sr)
  def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None):
    sig,sr = aud
    top_db = 80
    spec = torchaudio.transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)
    spec = torchaudio.transforms.AmplitudeToDB(top_db=top_db)(spec)
    return spec

class AudioClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = torch.nn.ReLU()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = torch.nn.ReLU()
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(32 * 16 * 11, 128)  # 调整这里的输入大小
        self.relu3 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(128, 4)  #  4 classes
        
    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# Dummy class to represent 你的数据集的
class DummyDS:
  def __init__(self):
        self.idx2label = {0: "Radio", 1: "Küche", 2: "Heizung", 3: "Ofen"}
ds = DummyDS()

def sliding_window_segments(signal, frame_size, hop_size):
    segments = []
    for start in range(0, signal.size(1) - frame_size + 1, hop_size):
        end = start + frame_size
        segments.append(signal[:, start:end])
    return segments

device = torch.device("mps") # 或者 "cuda" 如果你有可用的GPU
model = AudioClassifier()
model.to(device)
model.eval()

sample_path = "evaluation_data/2_Florian_Heizung_aus.mp3" #请替换为你的路径
aud = AudioUtil.open(sample_path)
melspec = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=320)

frame_size = 44
hop_size = 1
segments = sliding_window_segments(melspec, frame_size, hop_size)

predictions = []
idx2label = ds.idx2label

for segment in segments:
    segment = segment.unsqueeze(0)
    segment = normalize(segment)
    with torch.no_grad():
        output = model(segment.to(device))
        predicted_label_idx = torch.argmax(output, dim=1).cpu().item()
        predictions.append(idx2label[predicted_label_idx])

counter = Counter(predictions)
print(counter)
most_common_prediction = counter.most_common(1)[0][0]

print(f"Predicted label for the long audio: {most_common_prediction}")

运行这段代码得到的结果是 Counter({'Heizung': 14, 'Ofen': 7}),预测的长音频标签是:Heizung。但实际上,音频中只包含 "Heizung" 这个词。

有人能给我一些建议或解决这个问题的办法吗? 我尝试过使用滑动窗口方法,但效果不如预期。

一、 问题原因分析

问题的核心在于训练数据和测试数据在时长上的不匹配。训练数据都是 1.1 秒的音频片段,而测试数据是 6-24 秒的长音频。直接用滑动窗口切割长音频并进行预测,会带来以下问题:

  1. 上下文缺失: 1.1 秒的窗口可能无法捕捉到完整的语音命令,尤其是当命令持续时间超过 1.1 秒时,会导致模型基于不完整的信息进行预测。
  2. 噪声影响: 长音频中可能包含静音、背景噪声等,这些非语音片段会被模型误判为某个类别。
  3. 边界效应: 滑动窗口在起始和结束位置可能会切割到语音命令的开头或结尾,导致识别错误。
  4. 单一投票局限: 仅仅依赖所有窗口中的多数投票结果可能会忽略时间序列上的信息,而且很容易受到连续误判的影响。

二、 解决方案

针对以上问题,我们可以尝试以下几种解决方案:

1. 平滑预测结果 (Smoothing Predictions)

  • 原理: 对连续窗口的预测结果进行平滑处理,减少噪声和误判的影响。可以采用滑动平均或中值滤波等方法。
  • 代码示例:
import numpy as np

def smooth_predictions(predictions, window_size=5):
    """
    使用滑动平均平滑预测结果。
    """
    smoothed_predictions = np.convolve(predictions, np.ones(window_size)/window_size, mode='valid')
    return smoothed_predictions

# 将预测的标签转换为数值表示,例如使用 label encoding
numerical_predictions = [list(ds.idx2label.keys())[list(ds.idx2label.values()).index(p)] for p in predictions]
smoothed_predictions = smooth_predictions(numerical_predictions, window_size=5) # 5 可以调整

# 将平滑后的数值预测转换回标签
smoothed_labels = [ds.idx2label[int(round(p))] for p in smoothed_predictions]

counter = Counter(smoothed_labels)
print(counter)
most_common_prediction = counter.most_common(1)[0][0]

print(f"Smoothed predicted label for the long audio: {most_common_prediction}")
  • 改进 : 使用中值滤波可以更有效地抑制孤立的错误预测。
    from scipy.signal import medfilt
    
    def smooth_predictions_median(predictions, window_size=5):
        numerical_predictions = [list(ds.idx2label.keys())[list(ds.idx2label.values()).index(p)] for p in predictions]
        smoothed_predictions = medfilt(numerical_predictions, kernel_size=window_size)
        return smoothed_predictions
    

2. 基于置信度的阈值过滤 (Confidence Thresholding)

  • 原理: 模型预测时会给出每个类别的概率(置信度)。可以设定一个阈值,只保留置信度高于阈值的预测结果,过滤掉低置信度的预测,以减少误判。
  • 代码示例:
confidence_threshold = 0.7  # 可调整
filtered_predictions = []

for segment in segments:
    segment = segment.unsqueeze(0)
    segment = normalize(segment)
    with torch.no_grad():
        output = model(segment.to(device))
        probabilities = torch.nn.functional.softmax(output, dim=1) # 获取概率
        max_prob, predicted_label_idx = torch.max(probabilities, dim=1)
        max_prob = max_prob.cpu().item()
        predicted_label_idx = predicted_label_idx.cpu().item()
        
        if max_prob > confidence_threshold:
            filtered_predictions.append(ds.idx2label[predicted_label_idx])

counter = Counter(filtered_predictions)
print(counter)
if(len(counter) > 0):
    most_common_prediction = counter.most_common(1)[0][0]
    print(f"Predicted label for the long audio (with thresholding): {most_common_prediction}")
else:
    print("No predictions above the confidence threshold.")

3. 集成HMM/CRF (隐马尔可夫模型/条件随机场) (进阶)

  • 原理: 考虑音频信号的时序特性。HMM 和 CRF 都是处理序列数据的概率模型,可以将音频帧序列的预测结果作为观测序列,通过 HMM/CRF 模型进行建模,得到更合理的标签序列。这个方案对序列标注任务比较合适.

  • 这个实现比较复杂,可能超出一般大学项目难度, 不再这里展示具体的实现代码了。但是我可以告诉你大致的步骤。

    • HMM

      1. 定义状态 : 对应你的音频类别 (e.g., "Radio", "Küche"). 还要加一个 "silence" 或 "background" 状态。
      2. 估计转移概率 : 根据训练数据或先验知识估计状态之间的转移概率 (比如, "Radio" 后面接 "aus" 的概率)。
      3. 估计发射概率 : 根据模型的输出 (softmax 后的概率) 估计每个状态下观察到特定输出的概率。
      4. 维特比算法 : 使用维特比算法 (Viterbi algorithm) 解码最可能的状态序列,也就是最终的标签序列。
    • CRF :

      1. 特征函数 : 定义特征函数,这些函数可以考虑当前帧的预测、前后帧的预测以及其他上下文信息。
      2. 训练 CRF : 使用标注好的长音频数据训练 CRF 模型 (这需要你对长音频进行标注,比较麻烦)。
      3. 解码 : 使用训练好的 CRF 模型对新的长音频进行解码,得到标签序列。

4. 加权投票

  • 原理:
    • 不同窗口位置对于整个音频的贡献可能是不一样的。对靠近中间的窗口可以赋予更大的权重,对于边缘的窗口赋予较小的权重。
    • 还可以根据预测概率(置信度)调整权重,对于置信度高的赋予更大权重。
  • 代码实例:
weighted_predictions = []

for i, segment in enumerate(segments):
    segment = segment.unsqueeze(0)
    segment = normalize(segment)
    with torch.no_grad():
        output = model(segment.to(device))
        probabilities = torch.nn.functional.softmax(output, dim=1)  # 获取概率
        max_prob, predicted_label_idx = torch.max(probabilities, dim=1)
        max_prob = max_prob.cpu().item()
        predicted_label_idx = predicted_label_idx.cpu().item()
         # 高斯权重
        center = len(segments) / 2
        weight = np.exp(-((i - center) ** 2) / (2 * (len(segments) / 4) **  2))
        #还可以乘上置信度
        weight = weight * max_prob


        weighted_predictions.append((ds.idx2label[predicted_label_idx], weight))

#统计加权后的投票结果
weighted_counts = {}
for label, weight in weighted_predictions:
    if label not in weighted_counts:
        weighted_counts[label] = 0
    weighted_counts[label] += weight

# 获取最高权重的标签
if len(weighted_counts)>0:
    most_likely_label = max(weighted_counts, key=weighted_counts.get)
    print(f"Predicted label: {most_likely_label}")
else:
    print("No prediction")

5. 调整滑动窗口参数

  • 原理 : 不同的hop_size, 会有不同的结果,小的hop_size会导致更多的重叠, 捕获的信息更多, 但是计算成本更高, 反之,大的hop_size 会减少计算量,可能会漏掉部分信息。
  • 可以试验不同的 hop_size, 和frame_size。 比如hop_size 设置为帧长度的一半(重叠50%)等等。
    • 需要注意的是, 改变frame_size 可能需要重新训练你的模型。

三、安全建议

  • 如果是应用到真实的场景, 尽量使用 HTTPS 等安全协议来传输音频数据,避免数据泄露。
  • 对于用户上传的音频数据,要进行安全检查,防止恶意音频。
  • 不要在生产环境中直接使用示例代码。
    通过调整或者综合采用以上方法,可以大幅改善预测结果。需要根据实际情况调整各个参数的数值,或者设计其他的组合应用策略,进行多次实验,找到最适合你的数据的处理方式。