
This layer can perform einsum calculations of arbitrary dimensionality.

Usage

layer_einsum_dense(
  object,
  equation,
  output_shape,
  activation = NULL,
  bias_axes = NULL,
  kernel_initializer = "glorot_uniform",
  bias_initializer = "zeros",
  kernel_regularizer = NULL,
  bias_regularizer = NULL,
  kernel_constraint = NULL,
  bias_constraint = NULL,
  lora_rank = NULL,
  ...
)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

equation

An equation describing the einsum to perform. This equation must be a valid einsum string of the form 'ab,bc->ac', '...ab,bc->...ac', or 'ab...,bc->ac...', where 'ab', 'bc', and 'ac' can be any valid einsum axis expression sequence.

output_shape

The expected shape of the output tensor (excluding the batch dimension and any dimensions represented by ellipses). You can specify NA or NULL for any dimension that is unknown or can be inferred from the input shape.

activation

Activation function to use. If you don't specify anything, no activation is applied (that is, a "linear" activation: a(x) = x).

bias_axes

A string containing the output dimension(s) to apply a bias to. Each character in the bias_axes string should correspond to a character in the output portion of the equation string.

kernel_initializer

Initializer for the kernel weights matrix.

bias_initializer

Initializer for the bias vector.

kernel_regularizer

Regularizer function applied to the kernel weights matrix.

bias_regularizer

Regularizer function applied to the bias vector.

kernel_constraint

Constraint function applied to the kernel weights matrix.

bias_constraint

Constraint function applied to the bias vector.

lora_rank

Optional integer. If set, the layer's forward pass will implement LoRA (Low-Rank Adaptation) with the provided rank. LoRA sets the layer's kernel to non-trainable and adds to it a trainable delta, obtained by multiplying two lower-rank trainable matrices (the factorization happens on the last dimension of the kernel). This can reduce the computational cost of fine-tuning large dense layers. You can also enable LoRA on an existing EinsumDense layer by calling layer$enable_lora(rank).

...

Base layer keyword arguments, such as name and dtype.
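To make the equation and bias_axes arguments concrete, the base-R sketch below evaluates a hypothetical equation "ab,bcd->acd" with bias_axes = "cd" using explicit loops. The helper einsum_ab_bcd_acd() and all axis sizes are made up for illustration; the sketch shows what the einsum computes, not how Keras implements it.

```r
# Hypothetical helper: evaluates "ab,bcd->acd" plus a bias on axes "cd"
# with explicit loops, so each einsum rule is visible.
einsum_ab_bcd_acd <- function(x, kernel, bias) {
  stopifnot(ncol(x) == dim(kernel)[1])  # axis b must match in both operands
  n_a <- nrow(x); n_c <- dim(kernel)[2]; n_d <- dim(kernel)[3]
  out <- array(0, dim = c(n_a, n_c, n_d))  # output axes "acd"
  for (i in seq_len(n_a)) {
    for (j in seq_len(n_c)) {
      for (k in seq_len(n_d)) {
        # b appears in both inputs but not after "->", so it is summed over;
        # the bias is indexed only by the bias axes c and d.
        out[i, j, k] <- sum(x[i, ] * kernel[, j, k]) + bias[j, k]
      }
    }
  }
  out
}

set.seed(1)
x <- matrix(rnorm(4 * 3), nrow = 4)            # axes "ab"
kernel <- array(rnorm(3 * 2 * 5), c(3, 2, 5))  # axes "bcd"
bias <- matrix(rnorm(2 * 5), nrow = 2)         # bias_axes = "cd"
y <- einsum_ab_bcd_acd(x, kernel, bias)
dim(y)  # 4 2 5

# Same result without loops: fold (c, d) into columns, multiply, add the bias.
y_fast <- sweep(array(x %*% matrix(kernel, nrow = 3), c(4, 2, 5)),
                MARGIN = c(2, 3), STATS = bias, FUN = "+")
all.equal(y, y_fast)  # TRUE
```

The bias has one entry per (c, d) pair and is broadcast over axis a, which is why every character of bias_axes must name an output axis.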

Value

The return value depends on the value provided for the first argument. If object is:

  • a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

  • a keras_input(), then the output tensor from calling layer(input) is returned.

  • NULL or missing, then a Layer instance is returned.

Examples

Biased dense layer with einsums

This example shows how to instantiate a standard Keras dense layer using einsum operations. It is equivalent to layer_dense(units = 64, use_bias = TRUE).

input <- layer_input(shape = c(32))
output <- input |>
  layer_einsum_dense("ab,bc->ac",
                     output_shape = 64,
                     bias_axes = "c")
output # shape(NA, 64)

## <KerasTensor shape=(None, 64), dtype=float32, sparse=False, name=keras_tensor_1>
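To see why this einsum layer matches a dense layer, here is a base-R sketch of the same arithmetic on a hypothetical batch of two inputs with random stand-in weights: the equation 'ab,bc->ac' is an ordinary matrix product, and bias_axes = "c" adds one bias value per output unit.

```r
set.seed(6)
x <- matrix(rnorm(2 * 32), nrow = 2)         # axes "ab": 2 samples, 32 features
kernel <- matrix(rnorm(32 * 64), nrow = 32)  # axes "bc"
bias <- rnorm(64)                            # bias_axes = "c"

# "ab,bc->ac" is x %*% kernel; the bias is repeated for every row (axis a).
y <- x %*% kernel + matrix(bias, nrow = nrow(x), ncol = 64, byrow = TRUE)
dim(y)                         # 2 64
length(kernel) + length(bias)  # 2112, the weight count of a 64-unit dense layer on 32 inputs
```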

Applying a dense layer to a sequence

This example shows how to instantiate a layer that applies the same dense operation to every element in a sequence. Here, output_shape has two values, since there are two non-batch dimensions in the output. The first value is NA, which leaves the size of the sequence dimension b unspecified; as the printed result shows, the symbolic output then reports that dimension as unknown (None), even though this input has a sequence length of 32.

input <- layer_input(shape = c(32, 128))
output <- input |>
  layer_einsum_dense("abc,cd->abd",
                     output_shape = c(NA, 64),
                     bias_axes = "d")
output  # shape(NA, NA, 64)

## <KerasTensor shape=(None, None, 64), dtype=float32, sparse=False, name=keras_tensor_3>
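A base-R sketch of what this layer computes, using random stand-in weights and a hypothetical batch of two: the same 128 x 64 kernel and the same bias are applied at each of the 32 sequence positions.

```r
set.seed(7)
x <- array(rnorm(2 * 32 * 128), dim = c(2, 32, 128))  # axes "abc"
kernel <- matrix(rnorm(128 * 64), nrow = 128)         # axes "cd"
bias <- rnorm(64)                                     # bias_axes = "d"

y <- array(0, dim = c(2, 32, 64))                     # axes "abd"
for (b in seq_len(dim(x)[2])) {
  # The same kernel and bias are applied at every sequence position b.
  y[, b, ] <- sweep(x[, b, ] %*% kernel, MARGIN = 2, STATS = bias, FUN = "+")
}
dim(y)  # 2 32 64
```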

Applying a dense layer to a sequence using ellipses

This example shows how to instantiate a layer that applies the same dense operation to every element in a sequence, but uses the ellipsis notation instead of specifying the batch and sequence dimensions.

Because we are using ellipsis notation and have specified only one output axis outside the ellipsis, output_shape is a single value. A layer instantiated this way can handle any number of sequence dimensions, including the case where no sequence dimension exists.

input <- layer_input(shape = c(32, 128))
output <- input |>
  layer_einsum_dense("...x,xy->...y",
                     output_shape = 64,
                     bias_axes = "y")

output  # shape(NA, 32, 64)

## <KerasTensor shape=(None, 32, 64), dtype=float32, sparse=False, name=keras_tensor_5>
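The ellipsis form can be mimicked in base R by folding every axis matched by '...' into the rows of a matrix, multiplying by the kernel, and restoring the leading axes afterwards. The helper einsum_ellipsis_dense() below is hypothetical and uses random weights; it only illustrates that one 128 x 64 kernel serves inputs with two, one, or zero leading axes.

```r
# Hypothetical helper: applies "...x,xy->...y" plus a bias on axis y.
einsum_ellipsis_dense <- function(x, kernel, bias) {
  d <- dim(x)
  if (is.null(d)) d <- length(x)   # a plain vector: "..." matches no axes
  n_x <- d[length(d)]
  stopifnot(n_x == nrow(kernel))   # the last input axis is contracted
  lead <- d[-length(d)]            # the axes matched by "..."
  # Fold the "..." axes into rows, contract over x, add the bias on y.
  y <- sweep(matrix(x, ncol = n_x) %*% kernel, MARGIN = 2, STATS = bias, FUN = "+")
  if (length(lead) == 0) drop(y) else array(y, dim = c(lead, ncol(kernel)))
}

set.seed(8)
kernel <- matrix(rnorm(128 * 64), nrow = 128)
bias <- rnorm(64)
x3 <- array(rnorm(2 * 32 * 128), dim = c(2, 32, 128))  # batch and sequence axes
x2 <- matrix(rnorm(2 * 128), nrow = 2)                 # batch axis only
x1 <- rnorm(128)                                       # no leading axes
dim(einsum_ellipsis_dense(x3, kernel, bias))     # 2 32 64
dim(einsum_ellipsis_dense(x2, kernel, bias))     # 2 64
length(einsum_ellipsis_dense(x1, kernel, bias))  # 64
```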

Methods

  • enable_lora(
      rank,
      a_initializer = 'he_uniform',
      b_initializer = 'zeros'
    )

  • quantize(mode, type_check = TRUE)

Readonly properties:

  • kernel
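For a two-dimensional kernel, the arithmetic behind lora_rank and enable_lora() can be sketched in base R as below. The sizes are hypothetical, lora_a is filled with rnorm() values as a stand-in for the 'he_uniform' initializer, lora_b starts at zero as with the default b_initializer = 'zeros', and the sketch omits any extra scaling a given Keras version may apply to the delta. It only shows the low-rank update and the reduction in trainable weights.

```r
d_in <- 512; d_out <- 256; lora_rank <- 8   # hypothetical sizes

set.seed(3)
kernel <- matrix(rnorm(d_in * d_out), nrow = d_in)      # frozen original kernel
lora_a <- matrix(rnorm(d_in * lora_rank), nrow = d_in)  # trainable factor, d_in x rank
lora_b <- matrix(0, nrow = lora_rank, ncol = d_out)     # trainable factor, starts at zero

# Kernel used in the forward pass: the original plus a low-rank delta.
effective_kernel <- kernel + lora_a %*% lora_b

# Since lora_b starts at zero, training starts from the original kernel.
all.equal(effective_kernel, kernel)  # TRUE

# Trainable weights with and without LoRA.
c(full = d_in * d_out, lora = d_in * lora_rank + lora_rank * d_out)
# full: 131072, lora: 6144
```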

See also

Other core layers:
layer_dense()
layer_embedding()
layer_identity()
layer_lambda()
layer_masking()

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()