Turns nonnegative integers (indexes) into dense vectors of fixed size.
Source: R/layers-core.R

e.g. rbind(4L, 20L) → rbind(c(0.25, 0.1), c(0.6, -0.2))
This layer can only be used on nonnegative integer inputs of a fixed range.
Usage
layer_embedding(
object,
input_dim,
output_dim,
embeddings_initializer = "uniform",
embeddings_regularizer = NULL,
embeddings_constraint = NULL,
mask_zero = FALSE,
weights = NULL,
lora_rank = NULL,
...
)
Arguments
- object
Object to compose the layer with. A tensor, array, or sequential model.
- input_dim
Integer. Size of the vocabulary, i.e. maximum integer index + 1.
- output_dim
Integer. Dimension of the dense embedding.
- embeddings_initializer
Initializer for the embeddings matrix (see keras3::initializer_*).
- embeddings_regularizer
Regularizer function applied to the embeddings matrix (see keras3::regularizer_*).
- embeddings_constraint
Constraint function applied to the embeddings matrix (see keras3::constraint_*).
- mask_zero
Boolean, whether or not the input value 0 is a special "padding" value that should be masked out. This is useful when using recurrent layers, which may take variable-length input. If this is TRUE, then all subsequent layers in the model need to support masking, or an exception will be raised. If mask_zero is set to TRUE, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal the vocabulary size + 1). A sketch follows this list.
- weights
Optional floating-point matrix of size (input_dim, output_dim). The initial embeddings values to use.
- lora_rank
Optional integer. If set, the layer's forward pass will implement LoRA (Low-Rank Adaptation) with the provided rank. LoRA sets the layer's embeddings matrix to non-trainable and replaces it with a delta over the original matrix, obtained via multiplying two lower-rank trainable matrices. This can be useful to reduce the computation cost of fine-tuning large embedding layers. You can also enable LoRA on an existing Embedding layer instance by calling layer$enable_lora(rank); see the sketch under Methods below.
- ...
For forward/backward compatibility.
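A minimal sketch of mask_zero with zero-padded batches (assumes keras3 is attached; the sizes and the downstream layer_lstm() are illustrative choices, not prescribed by this page):

library(keras3)

# Vocabulary of 999 real tokens; index 0 is reserved for padding,
# so input_dim = 999 + 1.
model <- keras_model_sequential(input_shape = c(4)) |>
  layer_embedding(input_dim = 1000, output_dim = 8, mask_zero = TRUE) |>
  layer_lstm(units = 16)  # recurrent layers consume the propagated mask

# Trailing zeros in each row are masked out of the LSTM computation.
padded <- rbind(c(12L, 7L, 0L, 0L),
                c(5L, 42L, 9L, 0L))
out <- model |> predict(padded, verbose = 0)
dim(out)  # (2, 16)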
Value
The return value depends on the value provided for the first argument.
If object is:
- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
- a keras_input(), then the output tensor from calling layer(input) is returned.
- NULL or missing, then a Layer instance is returned.
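A short sketch of the three call forms (assumes keras3 is attached; the names and sizes are illustrative):

library(keras3)

# 1. Compose with a sequential model (modified in place, returned invisibly):
model <- keras_model_sequential() |>
  layer_embedding(input_dim = 1000, output_dim = 64)

# 2. Call on a keras_input() tensor; the output tensor is returned:
inputs <- keras_input(shape = c(10), dtype = "int32")
outputs <- inputs |> layer_embedding(input_dim = 1000, output_dim = 64)

# 3. With no object, a Layer instance is returned for later use:
embedding_layer <- layer_embedding(input_dim = 1000, output_dim = 64)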
Example
model <- keras_model_sequential() |>
layer_embedding(1000, 64)
# The model will take as input an integer matrix of size (batch, input_length),
# and the largest integer (i.e. word index) in the input
# should be no larger than 999 (vocabulary size).
# Once the model is built on the (32, 10) input below, model$output_shape
# is (NA, 10, 64), where `NA` is the batch dimension.
input_array <- random_integer(shape = c(32, 10), minval = 0, maxval = 1000)
model |> compile('rmsprop', 'mse')
output_array <- model |> predict(input_array, verbose = 0)
dim(output_array)  # (32, 10, 64)
Methods
- enable_lora(rank, a_initializer = 'he_uniform', b_initializer = 'zeros')
- quantize(mode, type_check = TRUE)
- quantized_build(input_shape, mode)
- quantized_call(...)
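A minimal sketch of enable_lora() on an existing layer (assumes keras3 is attached; the rank and shapes are arbitrary choices, and the layer must be built before LoRA can be enabled):

library(keras3)

embedding <- layer_embedding(input_dim = 10000, output_dim = 128)
embedding$build(shape(NA, 32))  # enable_lora() requires a built layer
embedding$enable_lora(rank = 4)

# The full embeddings matrix is now frozen; only the two low-rank
# factor matrices added by enable_lora() remain trainable.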
See also
Other core layers: layer_dense() layer_einsum_dense() layer_identity() layer_lambda() layer_masking()
Other layers: Layer() layer_activation() layer_activation_elu() layer_activation_leaky_relu() layer_activation_parametric_relu() layer_activation_relu() layer_activation_softmax() layer_activity_regularization() layer_add() layer_additive_attention() layer_alpha_dropout() layer_attention() layer_aug_mix() layer_auto_contrast() layer_average() layer_average_pooling_1d() layer_average_pooling_2d() layer_average_pooling_3d() layer_batch_normalization() layer_bidirectional() layer_category_encoding() layer_center_crop() layer_concatenate() layer_conv_1d() layer_conv_1d_transpose() layer_conv_2d() layer_conv_2d_transpose() layer_conv_3d() layer_conv_3d_transpose() layer_conv_lstm_1d() layer_conv_lstm_2d() layer_conv_lstm_3d() layer_cropping_1d() layer_cropping_2d() layer_cropping_3d() layer_cut_mix() layer_dense() layer_depthwise_conv_1d() layer_depthwise_conv_2d() layer_discretization() layer_dot() layer_dropout() layer_einsum_dense() layer_equalization() layer_feature_space() layer_flatten() layer_flax_module_wrapper() layer_gaussian_dropout() layer_gaussian_noise() layer_global_average_pooling_1d() layer_global_average_pooling_2d() layer_global_average_pooling_3d() layer_global_max_pooling_1d() layer_global_max_pooling_2d() layer_global_max_pooling_3d() layer_group_normalization() layer_group_query_attention() layer_gru() layer_hashed_crossing() layer_hashing() layer_identity() layer_integer_lookup() layer_jax_model_wrapper() layer_lambda() layer_layer_normalization() layer_lstm() layer_masking() layer_max_num_bounding_boxes() layer_max_pooling_1d() layer_max_pooling_2d() layer_max_pooling_3d() layer_maximum() layer_mel_spectrogram() layer_minimum() layer_mix_up() layer_multi_head_attention() layer_multiply() layer_normalization() layer_permute() layer_rand_augment() layer_random_brightness() layer_random_color_degeneration() layer_random_color_jitter() layer_random_contrast() layer_random_crop() layer_random_erasing() layer_random_flip() layer_random_gaussian_blur() layer_random_grayscale() layer_random_hue() layer_random_invert() layer_random_perspective() layer_random_posterization() layer_random_rotation() layer_random_saturation() layer_random_sharpness() layer_random_shear() layer_random_translation() layer_random_zoom() layer_repeat_vector() layer_rescaling() layer_reshape() layer_resizing() layer_rms_normalization() layer_rnn() layer_separable_conv_1d() layer_separable_conv_2d() layer_simple_rnn() layer_solarization() layer_spatial_dropout_1d() layer_spatial_dropout_2d() layer_spatial_dropout_3d() layer_spectral_normalization() layer_stft_spectrogram() layer_string_lookup() layer_subtract() layer_text_vectorization() layer_tfsm() layer_time_distributed() layer_torch_module_wrapper() layer_unit_normalization() layer_upsampling_1d() layer_upsampling_2d() layer_upsampling_3d() layer_zero_padding_1d() layer_zero_padding_2d() layer_zero_padding_3d() rnn_cell_gru() rnn_cell_lstm() rnn_cell_simple() rnn_cells_stack()