from typing import Optional, Tuple, Union
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .abstract_model import AbstractModel
from .timm_wrapper import TimmModel
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
from speechbrain.lobes.features import Leaf as LeafSb
[docs]
class LEAFNet(AbstractModel):
def __init__(
self,
output_dim: int,
leaf_filters: int = 40,
kernel_size: int = 25,
stride: float = 0.0625,
window_stride: int = 10,
padding_kernel_size: int = 25,
sample_rate: int = 16000,
min_freq: int = 60,
max_freq: int = 7800,
efficientnet_type: str = "efficientnet_b0",
mode: str = "interspeech",
initialization: str = "mel",
generator_seed: int = 42,
transfer: bool = False,
) -> None:
"""EfficientNet with LEAF-Is frontend.
Used to reproduce work from:
https://www.isca-archive.org/interspeech_2023/meng23c_interspeech.html
Also see original LEAF and PCEN papers (c.f. speechbrain).
We take and slightly adapt the LEAF frontend implementation from:
https://github.com/Hanyu-Meng/Adapting-LEAF
Args:
output_dim: Output dimension.
leaf_filters: Number of LEAF filterbanks to train. Defaults to 40.
kernel_size: Size of kernels applied by LEAF (in ms).
Defaults to 25.
stride: Stride of LEAF (in ms). Defaults to 0.0625.
window_stride: Stride of lowpass filter (in ms). Defaults to 10.
padding_kernel_size: Size of lowpass filter (in ms).
Defaults to 25.
sample_rate: Used to compute LEAF params. Defaults to 16000.
min_freq: Minimum freq analyzed by LEAF. Defaults to 60.
max_freq: Maximum freq analyzed by LEAF. Defaults to 7800.
efficientnet_type: EfficientNet type to use from timm.
Defaults to "efficientnet_b0".
mode: Implementation according to "interspeech" paper (Meng et al.)
or "speech_brain". Defaults to "interspeech".
initialization: Filterbank initialisation in ["mel", "bark",
"linear-constant", "constant", "uniform", "zeros"].
Defaults to "mel".
generator_seed: Seed for random generator. Defaults to 42.
transfer: Whether to use EfficientNet weights from ImageNet.
Defaults to False.
Raises:
ValueError: If efficientnet_type is not supported.
ValueError: If mode is not supported.
"""
super().__init__(output_dim)
# convert LEAF params from ms to samples
self.mode = mode
self.initialization = initialization
self.leaf_filters = leaf_filters
self.min_freq = min_freq
self.max_freq = max_freq
self.sample_rate = sample_rate
self.generator_seed = generator_seed
self.kernel_size = kernel_size
self.stride = stride
self.window_stride = window_stride
self.padding_kernel_size = padding_kernel_size
self.efficientnet_type = efficientnet_type
self.transfer = transfer
kernel_size_sample = int(sample_rate * kernel_size / 1000)
kernel_size_sample += 1 - (kernel_size % 2) # make odd
# convert pooling params from ms to samples
if mode == "interspeech":
self.leaf = LeafIs(
n_filters=leaf_filters,
min_freq=min_freq,
max_freq=max_freq,
sample_rate=sample_rate,
window_len=padding_kernel_size,
window_stride=window_stride,
)
elif mode == "speech_brain":
self.leaf = LeafSb(
out_channels=leaf_filters,
in_channels=1,
sample_rate=sample_rate,
min_freq=min_freq,
use_pcen=True,
learnable_pcen=True,
use_legacy_complex=True,
skip_transpose=True,
)
else:
raise ValueError(
"Only options 'interspeech' and 'speech_brain' are available"
f" for mode, but got mode='{mode}'."
)
self._initialise_filterbank()
if not self.efficientnet_type.startswith("efficientnet_"):
raise ValueError(
"Only EfficientNet models are supported, but got"
f" efficientnet_type='{efficientnet_type}'."
)
self.classifier = TimmModel(
output_dim=self.output_dim,
timm_name=self.efficientnet_type,
transfer=self.transfer,
in_chans=1,
)
def embeddings(self, x: torch.Tensor) -> torch.Tensor:
return self.leaf(x)
def forward(self, x: torch.Tensor):
x = self.leaf(x)
x = x.unsqueeze(1)
x = self.classifier(x)
return x
def _initialise_filterbank(self) -> None:
if self.initialization == "mel":
return
elif self.initialization == "linear-constant":
center_frequencies = torch.linspace(
self.min_freq, self.max_freq, self.leaf_filters
)
bandwidths = torch.ones((self.leaf_filters,)) * np.exp(
(np.log(self.max_freq) - np.log(self.min_freq)) / 2.5
+ np.log(self.min_freq)
)
elif self.initialization == "zeros":
center_frequencies = torch.zeros((self.leaf_filters,))
bandwidths = torch.ones((self.leaf_filters,)) * np.exp(
np.log(self.max_freq - self.min_freq) / 2
)
elif self.initialization == "constant":
center_frequencies = torch.ones((self.leaf_filters,)) * np.exp(
(np.log(self.max_freq) - np.log(self.min_freq)) / 2
+ np.log(self.min_freq)
)
bandwidths = torch.ones((self.leaf_filters,)) * np.exp(
(np.log(self.max_freq) - np.log(self.min_freq)) / 2.5
+ np.log(self.min_freq)
)
elif self.initialization == "uniform":
generator = torch.Generator()
generator.manual_seed(self.generator_seed)
center_frequencies = (
torch.rand((self.leaf_filters,), generator=generator)
* (self.max_freq - self.min_freq)
+ self.min_freq
)
center_frequencies = torch.sort(center_frequencies).values
# Estimation
bandwidths = (
torch.rand((self.leaf_filters,), generator=generator)
* (self.max_freq - self.min_freq)
/ 10
+ self.min_freq
)
elif self.initialization == "bark":
center_frequencies, bandwidths = bark_scale_filterbank(
self.min_freq, self.max_freq, self.leaf_filters
)
else:
raise ValueError(
"Only options 'mel', 'linear-constant', 'constant', 'uniform', "
f"'bark', and 'zeros' are available for initialization "
f", but got initialization='{self.initialization}'."
)
# adjustment for sample_rate
center_frequencies *= 2 * np.pi / self.sample_rate
# bandwiths to sigmas
bandwidths = (self.sample_rate / 2.0) / bandwidths
if self.mode == "interspeech":
center_freqs_param = "filterbank.center_freqs"
bandwith_param = "filterbank.bandwidths"
self.leaf.state_dict()[center_freqs_param].copy_(
center_frequencies
)
self.leaf.state_dict()[bandwith_param].copy_(bandwidths)
elif self.mode == "speech_brain":
filterbank_param = "complex_conv.kernel"
self.leaf.state_dict()[filterbank_param][:, 0].copy_(
center_frequencies
)
self.leaf.state_dict()[filterbank_param][:, 1].copy_(bandwidths)
def hz_to_bark(
f: Union[int, float, np.ndarray],
) -> Union[int, float, np.ndarray]:
"""Convert frequency from Hz to Bark scale according to Traunmüller.
Args:
f: Frequency in Hz.
Returns:
Frequency in Bark.
"""
b = 26.81 * f / (1960 + f) - 0.53
return b
def bark_to_hz(
b: Union[int, float, np.ndarray],
) -> Union[int, float, np.ndarray]:
"""Approximate conversion from Bark to Hz (inverse of the above).
Ajdusted centers for reconversion according to Traunmüller.
Args:
b: Frequency in Bark.
Returns:
Frequency in Hz.
"""
z_1 = b[b < 2] + 0.15 * (2 - b[b < 2])
z_2 = b[b >= 2]
z_2 = z_2[z_2 <= 20.1]
z_3 = b[b > 20.1] + 0.22 * (b[b > 20.1] - 20.1)
z = np.hstack((z_1, z_2, z_3))
f = 1960 * (z + 0.53) / (26.28 - z)
return f # 600 * np.sinh(z / 6)
def bark_scale_filterbank(
min_freq: Union[int, float],
max_freq: Union[int, float],
num_bands: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate center frequencies and bandwidths for a Bark scale filterbank.
Args:
min_freq: Minimum frequency in Hz.
max_freq: Maximum frequency in Hz.
num_bands: Number of bands.
Returns:
Tuple of center frequencies and bandwidths.
"""
# Convert min and max frequencies to Bark scale
min_bark = hz_to_bark(min_freq)
max_bark = hz_to_bark(max_freq)
# Evenly distribute bands in Bark scale
bark_centers = np.linspace(
min_bark, max_bark + (max_bark - min_bark) / num_bands, num_bands + 1
)
# Convert center frequencies back to Hz
center_freqs = bark_to_hz(bark_centers)
bandwidths = np.diff(center_freqs)
return torch.Tensor(center_freqs[:-1]), torch.Tensor(bandwidths)
def mel_filter_params(
n_filters: int,
min_freq: float,
max_freq: float,
sample_rate: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Analytically calculate the center frequencies and sigmas of a mel
filter bank.
Args:
n_filters: Number of filters for the filterbank.
min_freq: Minimum cutoff for the frequencies.
max_freq: Maximum cutoff for the frequencies.
sample_rate: Sample rate to use for the calculation.
Returns:
Center frequencies and sigmas.
"""
min_mel = 1127 * np.log1p(min_freq / 700.0)
max_mel = 1127 * np.log1p(max_freq / 700.0)
peaks_mel = torch.linspace(min_mel, max_mel, n_filters + 2)
peaks_hz = 700 * (torch.expm1(peaks_mel / 1127))
center_freqs = peaks_hz[1:-1] * (2 * np.pi / sample_rate)
bandwidths = peaks_hz[2:] - peaks_hz[:-2]
sigmas = (sample_rate / 2.0) / bandwidths
return center_freqs, sigmas
def gabor_filters(
size: int, center_freqs: torch.Tensor, sigmas: torch.Tensor
) -> torch.Tensor:
"""Calculate a gabor function from given center frequencies and bandwidths
that can be used as kernel/filters for an 1D convolution.
Args:
size: Kernel/filter size.
center_freqs: Center frequencies.
sigmas: Bandwidths.
Returns:
Kernel/filter that can be used for an 1D convolution.
"""
t = torch.arange(-(size // 2), (size + 1) // 2, device=center_freqs.device)
denominator = 1.0 / (np.sqrt(2 * np.pi) * sigmas)
gaussian = torch.exp(torch.outer(1.0 / (2.0 * sigmas**2), -(t**2)))
sinusoid = torch.exp(1j * torch.outer(center_freqs, t))
return denominator[:, np.newaxis] * sinusoid * gaussian
def gauss_windows(size: int, sigmas: torch.Tensor) -> torch.Tensor:
"""Calculate a gaussian lowpass function from given bandwidths that can be
used as kernel/filter for an 1D convolution.
Args:
size: Kernel/filter size.
sigmas: Bandwidths.
Returns:
Kernel/filter that can be used for an 1D convolution.
"""
t = torch.arange(0, size, device=sigmas.device)
numerator = t * (2 / (size - 1)) - 1
return torch.exp(-0.5 * (numerator / sigmas[:, np.newaxis]) ** 2)
class GaborFilterbank(nn.Module):
def __init__(
self,
n_filters: int,
min_freq: float,
max_freq: float,
sample_rate: int,
filter_size: int,
pool_size: int,
pool_stride: int,
pool_init: float = 0.4,
) -> None:
"""Torch module that functions as a gabor filterbank.
Initializes n_filters center frequencies and bandwidths that are based
on a mel filterbank. The parameters are used to calculate Gabor filters
for a 1D convolution over the input signal. The squared modulus is
taken from the results. To reduce the temporal resolution a gaussian
lowpass filter is calculated from pooling_widths, which are used to
perform a pooling operation. The center frequencies, bandwidths and
pooling_widths are learnable parameters.
Args:
n_filters: Number of filters.
min_freq: Minimum frequency for the mel filterbank initialization.
max_freq: Maximum frequency for the mel filterbank initialization.
sample_rate: Sample rate for the mel filterbank initialization.
filter_size: Size of the kernels/filters for gabor convolution.
pool_size: Size of the kernels/filters for pooling convolution.
pool_stride: Stride of the pooling convolution.
pool_init: Initial value for the gaussian lowpass function.
Defaults to 0.4.
"""
super(GaborFilterbank, self).__init__()
self.n_filters = n_filters
self.filter_size = filter_size
self.pool_size = pool_size
self.pool_stride = pool_stride
center_freqs, bandwidths = mel_filter_params(
n_filters, min_freq, max_freq, sample_rate
)
self.center_freqs = nn.Parameter(center_freqs)
self.bandwidths = nn.Parameter(bandwidths)
self.pooling_widths = nn.Parameter(
torch.full((n_filters,), float(pool_init))
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# compute filters
center_freqs = self.center_freqs.clamp(min=0.0, max=np.pi)
z = np.sqrt(2 * np.log(2)) / np.pi
bandwidths = self.bandwidths.clamp(min=4 * z, max=self.filter_size * z)
filters = gabor_filters(self.filter_size, center_freqs, bandwidths)
filters = torch.cat((filters.real, filters.imag), dim=0).unsqueeze(1)
# convolve with filters
x = F.conv1d(x, filters, padding=self.filter_size // 2)
# compute squared modulus
x = x**2
x = x[:, : self.n_filters] + x[:, self.n_filters :]
# compute pooling windows
pooling_widths = self.pooling_widths.clamp(
min=2.0 / self.pool_size, max=0.5
)
windows = gauss_windows(self.pool_size, pooling_widths).unsqueeze(1)
# apply temporal pooling
x = F.conv1d(
x,
windows,
stride=self.pool_stride,
padding=self.filter_size // 2,
groups=self.n_filters,
)
return x
class PCEN(nn.Module):
def __init__(
self,
num_bands: int,
s: float = 0.025,
alpha: float = 1.0,
delta: float = 1.0,
r: float = 1.0,
eps: float = 1e-6,
learn_logs: bool = False,
clamp: Optional[float] = None,
) -> None:
"""Trainable PCEN (Per-Channel Energy Normalization) layer.
.. math::
Y = (\\frac{X}{(\\epsilon + M)^\\alpha} + \\delta)^r - \\delta^r
M_t = (1 - s) M_{t - 1} + s X_t
Args:
num_bands: Number of frequency bands (before last input dimension).
s: Initial value for :math:`s`.
alpha: Initial value for :math:`alpha`
delta: Initial value for :math:`delta`
r: Initial value for :math:`r`
eps: Value for :math:`eps`
learn_logs: If false-ish, instead of learning the logarithm of each
parameter (as in the PCEN paper), learn the inverse of :math:`r`
and all other parameters directly (as in the LEAF paper).
clamp: If given, clamps the input to the given minimum value before
applying PCEN.
"""
super(PCEN, self).__init__()
if learn_logs:
# learns logarithm of each parameter
s = np.log(s)
alpha = np.log(alpha)
delta = np.log(delta)
r = np.log(r)
else:
# learns inverse of r, and all other parameters directly
r = 1.0 / r
self.learn_logs = learn_logs
self.s = nn.Parameter(torch.full((num_bands,), float(s)))
self.alpha = nn.Parameter(torch.full((num_bands,), float(alpha)))
self.delta = nn.Parameter(torch.full((num_bands,), float(delta)))
self.r = nn.Parameter(torch.full((num_bands,), float(r)))
self.eps = torch.as_tensor(eps)
self.clamp = clamp
def forward(self, x: torch.Tensor) -> torch.Tensor:
# clamp if needed
if self.clamp is not None:
x = x.clamp(min=self.clamp)
# prepare parameters
if self.learn_logs:
# learns logarithm of each parameter
s = self.s.exp()
alpha = self.alpha.exp()
delta = self.delta.exp()
r = self.r.exp()
else:
# learns inverse of r, and all other parameters directly
s = self.s
alpha = self.alpha.clamp(max=1)
delta = self.delta.clamp(min=0) # unclamped in original LEAF impl.
r = 1.0 / self.r.clamp(min=1)
# broadcast over channel dimension
alpha = alpha[:, np.newaxis]
delta = delta[:, np.newaxis]
r = r[:, np.newaxis]
# compute smoother
smoother = [x[..., 0]] # initialize the smoother with the first frame
for frame in range(1, x.shape[-1]):
smoother.append((1 - s) * smoother[-1] + s * x[..., frame])
smoother = torch.stack(smoother, -1)
# stable reformulation due to Vincent Lostanlen; original formula was:
# return (input / (self.eps + smoother)**alpha + delta)**r - delta**r
smoother = torch.exp(
-alpha * (torch.log(self.eps) + torch.log1p(smoother / self.eps))
)
return (x * smoother + delta) ** r - delta**r
class LeafIs(nn.Module):
def __init__(
self,
n_filters: int = 40,
min_freq: float = 60.0,
max_freq: float = 7800.0,
sample_rate: int = 16000,
window_len: float = 25.0,
window_stride: float = 10.0,
compression: Optional[torch.nn.Module] = None,
) -> None:
"""LEAF frontend, a learnable front-end that takes an audio waveform as
input and outputs a learnable spectral representation.
Initially approximates the computation of standard mel-filterbanks.
Args:
n_filters: Number of filters. Defaults to 40.
min_freq: Minimum frequency. Defaults to 60.0.
max_freq: Maximum frequency. Defaults to 7800.0.
sample_rate: Sample Rate for filterbanc initialization.
Defaults to 16000.
window_len: Kernel/filter size of the convolutions in ms.
Defaults to 25.0.
window_stride: Stride used for the pooling convolution in ms.
Defaults to 10.0.
compression: Compression function. If None, PCEN is used.
Defaults to None.
"""
super(LeafIs, self).__init__()
# convert window sizes from milliseconds to samples
window_size = int(sample_rate * window_len / 1000)
window_size += 1 - (window_size % 2) # make odd
window_stride = int(sample_rate * window_stride / 1000)
self.filterbank = GaborFilterbank(
n_filters,
min_freq,
max_freq,
sample_rate,
filter_size=window_size,
pool_size=window_size,
pool_stride=window_stride,
)
self.compression = (
compression
if compression
else PCEN(
n_filters,
s=0.04,
alpha=0.96,
delta=2,
r=0.5,
eps=1e-12,
learn_logs=False,
clamp=1e-5,
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
while x.ndim < 3:
x = x[:, np.newaxis]
x = self.filterbank(x)
x = self.compression(x)
return x