
A layer that computes spectrograms of the input signal. The spectrograms are based on the Short-Time Fourier Transform (STFT) and are computed with convolution kernels, which allows parallelization on GPUs and makes the kernels trainable for fine-tuning. The layer supports several output modes (e.g., log-scaled magnitude, phase, power spectral density) and offers flexible windowing, padding, and scaling options for the STFT calculation.

Usage

layer_stft_spectrogram(
  object,
  mode = "log",
  frame_length = 256L,
  frame_step = NULL,
  fft_length = NULL,
  window = "hann",
  periodic = FALSE,
  scaling = "density",
  padding = "valid",
  expand_dims = FALSE,
  data_format = NULL,
  ...
)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

mode

String, the output type of the spectrogram. Can be one of "log", "magnitude", "psd", "real", "imag", "angle", "stft". Defaults to "log".

frame_length

Integer, the length of each frame (window) for the STFT, in samples. Defaults to 256.

frame_step

Integer, the step size (hop length) between consecutive frames, in samples. If not provided, defaults to half the frame length (frame_length %/% 2).

fft_length

Integer, the size of the Fast Fourier Transform (FFT) applied to each frame. Should be greater than or equal to frame_length; a power of two is recommended. Defaults to the smallest power of two that is greater than or equal to frame_length.
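The two defaults described above can be reproduced with a few lines of base R. This is only a sketch of the documented rules; the helper names default_frame_step() and default_fft_length() are hypothetical and are not part of keras3.

```r
# Hypothetical helpers mirroring the documented defaults (not keras3 functions).

# Default hop length: half the frame length, using integer division.
default_frame_step <- function(frame_length) {
  frame_length %/% 2L
}

# Default FFT size: smallest power of two >= frame_length.
default_fft_length <- function(frame_length) {
  n <- 1L
  while (n < frame_length) {
    n <- n * 2L
  }
  n
}

default_frame_step(256L)
default_fft_length(256L)
default_fft_length(400L)
```

With frame_length = 400, for instance, the frame would be zero-padded up to an FFT size of 512 under this rule.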

window

(String or array_like), the windowing function to apply to each frame. Can be "hann" (default), "hamming", or a custom window provided as an array_like.

periodic

Boolean, if TRUE, the window function will be treated as periodic. Defaults to FALSE.
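To illustrate the difference between a symmetric and a periodic window, the sketch below builds a Hann window both ways in base R. It follows the usual signal-processing convention (a periodic window of length n equals the first n points of a symmetric window of length n + 1). The helper hann_window() is hypothetical, and whether the layer uses exactly this formula internally is an assumption.

```r
# Hypothetical helper; uses the common textbook definition of the Hann window.
hann_window <- function(n, periodic = FALSE) {
  # Symmetric windows divide by n - 1, periodic windows by n.
  denom <- if (periodic) n else n - 1L
  0.5 - 0.5 * cos(2 * pi * (0:(n - 1L)) / denom)
}

sym <- hann_window(8L)                   # starts and ends at zero
per <- hann_window(8L, periodic = TRUE)  # starts at zero, does not end at zero

# The periodic window matches a symmetric window that is one sample longer,
# with its last sample dropped.
all.equal(per, hann_window(9L)[1:8])
```

Periodic windows are generally preferred for spectral analysis with overlapping frames, while symmetric windows are the usual choice for filter design.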

scaling

String, the type of scaling applied to the window. Can be "density", "spectrum", or NULL. Defaults to "density".

padding

String, padding strategy. Can be "valid" or "same". Defaults to "valid".

expand_dims

Boolean, if TRUE, expands the output so that each spectrogram becomes a two-dimensional array, making it compatible with image models. Defaults to FALSE.

data_format

String, either "channels_last" or "channels_first". The ordering of the dimensions in the inputs. "channels_last" corresponds to inputs with shape (batch, height, width, channels) while "channels_first" corresponds to inputs with shape (batch, channels, height, width). Defaults to "channels_last".

...

For forward/backward compatibility.

Value

The return value depends on the value provided for the first argument. If object is:

  • a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

  • a keras_input(), then the output tensor from calling layer(input) is returned.

  • NULL or missing, then a Layer instance is returned.

Examples

Apply it as a non-trainable preprocessing layer on 3 audio tracks of 1 channel, 10 seconds and sampled at 16 kHz.

layer <- layer_stft_spectrogram(
  mode = 'log',
  frame_length = 256,
  frame_step = 128, # 50% overlap
  fft_length = 512,
  window = "hann",
  padding = "valid",
  trainable = FALSE # non-trainable, preprocessing only
)
random_uniform(shape=c(3, 160000, 1)) |> layer() |> op_shape()

## shape(3, 1249, 257)

Apply it as a trainable processing layer on 3 stereo audio tracks (2 channels), 10 seconds long and sampled at 16 kHz. The kernels are initialized exactly as in the non-trainable layer, but can then be trained jointly with the rest of a model.

layer <- layer_stft_spectrogram(
  mode = 'log',
  frame_length = 256,
  frame_step = 128,   # 50% overlap
  fft_length = 512,
  window = "hamming", # hamming windowing function
  padding = "same",   # padding to preserve the time dimension
  trainable = TRUE    # trainable, this is the default in keras
)
random_uniform(shape=c(3, 160000, 2)) |> layer() |> op_shape()

## shape(3, 1250, 514)

Similar to the last example, but with an extra dimension so that the output can be treated as an image. Here the layer is applied to a signal with 3 input channels, so the output is an image tensor that can be fed directly to an image model.

layer <- layer_stft_spectrogram(
  mode = 'log',
  frame_length = 256,
  frame_step = 128,
  fft_length = 512,
  padding = "same",
  expand_dims = TRUE  # this adds the extra dimension
)
random_uniform(shape=c(3, 160000, 3)) |> layer() |> op_shape()

## shape(3, 1250, 257, 3)
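As a rough illustration of what several of the mode values represent, the following base R sketch takes the FFT of a single frame and derives the corresponding quantities from the complex result. It uses stats::fft() directly rather than the layer's convolution kernels and applies no window or scaling, so the values are not meant to match the layer's output numerically. The "log" and "stft" modes are left out because their exact definitions are not stated on this page.

```r
# One frame of a real signal: a cosine completing exactly 4 cycles in 16 samples.
frame_length <- 16L
n <- 0:(frame_length - 1L)
frame <- cos(2 * pi * 4 * n / frame_length)

# Complex FFT of the frame; keep the one-sided half (frame_length %/% 2 + 1 bins).
spec <- fft(frame)[1:(frame_length %/% 2L + 1L)]

# Quantities corresponding to some of the `mode` options (illustrative only).
outputs <- list(
  real      = Re(spec),     # real part
  imag      = Im(spec),     # imaginary part
  magnitude = Mod(spec),    # absolute value
  psd       = Mod(spec)^2,  # squared magnitude (unscaled power)
  angle     = Arg(spec)     # phase in radians
)

# Zero-based index of the strongest frequency bin.
peak_bin <- which.max(outputs$magnitude) - 1L
peak_bin
```

The energy concentrates in bin 4, matching the 4 cycles per frame of the input cosine.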

Raises

  • ValueError: If an invalid value is provided for "mode", "scaling", "padding", or other input arguments.

  • TypeError: If the input data type is not one of "float16", "float32", or "float64".

Input Shape

A 3D tensor of shape (batch_size, time_length, input_channels), if data_format=="channels_last", and of shape (batch_size, input_channels, time_length) if data_format=="channels_first", where time_length is the length of the input signal, and input_channels is the number of input channels. The same kernels are applied to each channel independently.

Output Shape

The output shape depends on data_format and expand_dims:

  • If data_format=="channels_first" && !expand_dims, a 3D tensor: (batch_size, input_channels * freq_channels, new_time_length)

  • If data_format=="channels_last" && !expand_dims, a 3D tensor: (batch_size, new_time_length, input_channels * freq_channels)

  • If data_format=="channels_first" && expand_dims, a 4D tensor: (batch_size, input_channels, new_time_length, freq_channels)

  • If data_format=="channels_last" && expand_dims, a 4D tensor: (batch_size, new_time_length, freq_channels, input_channels)

where new_time_length depends on the padding, and freq_channels is the number of FFT bins (fft_length %/% 2 + 1).
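The shape rules above can be checked without running the layer. The base R sketch below uses a hypothetical helper, stft_output_shape(), which is not part of keras3. The formulas for new_time_length are an assumption: (time_length - frame_length) %/% frame_step + 1 for "valid" padding and ceiling(time_length / frame_step) for "same" padding. They reproduce the shapes printed in the Examples section but are not stated on this page.

```r
# Hypothetical helper (not a keras3 function) that predicts the output shape.
stft_output_shape <- function(batch_size, time_length, input_channels,
                              frame_length = 256L,
                              frame_step = frame_length %/% 2L,
                              fft_length = 256L,
                              padding = "valid",
                              expand_dims = FALSE,
                              data_format = "channels_last") {
  # Number of one-sided FFT bins, as documented.
  freq_channels <- fft_length %/% 2L + 1L

  # Assumed padding rules; consistent with the documented example outputs.
  new_time_length <- switch(padding,
    valid = (time_length - frame_length) %/% frame_step + 1L,
    same  = as.integer(ceiling(time_length / frame_step)),
    stop("padding must be 'valid' or 'same'")
  )

  if (data_format == "channels_last") {
    if (expand_dims) {
      c(batch_size, new_time_length, freq_channels, input_channels)
    } else {
      c(batch_size, new_time_length, input_channels * freq_channels)
    }
  } else {
    if (expand_dims) {
      c(batch_size, input_channels, new_time_length, freq_channels)
    } else {
      c(batch_size, input_channels * freq_channels, new_time_length)
    }
  }
}

# Mono example with "valid" padding.
stft_output_shape(3L, 160000L, 1L, fft_length = 512L)
# Stereo example with "same" padding.
stft_output_shape(3L, 160000L, 2L, fft_length = 512L, padding = "same")
```

Setting expand_dims = TRUE in the same helper keeps the frequency and channel axes separate, which corresponds to the 4D cases listed above.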

See also

Other audio preprocessing layers:
layer_mel_spectrogram()

Other preprocessing layers:
layer_auto_contrast()
layer_category_encoding()
layer_center_crop()
layer_discretization()
layer_equalization()
layer_feature_space()
layer_hashed_crossing()
layer_hashing()
layer_integer_lookup()
layer_max_num_bounding_boxes()
layer_mel_spectrogram()
layer_mix_up()
layer_normalization()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_rescaling()
layer_resizing()
layer_solarization()
layer_string_lookup()
layer_text_vectorization()

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_auto_contrast()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_equalization()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_num_bounding_boxes()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_mix_up()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_solarization()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()