Bidirectional wrapper for RNNs.
Usage
layer_bidirectional(
  object,
  layer,
  merge_mode = "concat",
  weights = NULL,
  backward_layer = NULL,
  ...
)
Arguments
- object
Object to compose the layer with. A tensor, array, or sequential model.
- layer
RNN instance, such as layer_lstm() or layer_gru(). It could also be a Layer() instance that meets the following criteria:
  - Be a sequence-processing layer (accepts 3D+ inputs).
  - Have go_backwards, return_sequences and return_state attributes (with the same semantics as for the RNN class).
  - Have an input_spec attribute.
  - Implement serialization via get_config() and from_config().
Note that the recommended way to create new RNN layers is to write a custom RNN cell and use it with layer_rnn(), instead of subclassing Layer() directly. When return_sequences is TRUE, the output of a masked timestep will be zero regardless of the layer's original zero_output_for_mask value.
- merge_mode
Mode by which outputs of the forward and backward RNNs will be combined. One of "sum", "mul", "concat", "ave", or NULL. If NULL, the outputs will not be combined; they will be returned as a list. Defaults to "concat".
- weights
see description
- backward_layer
Optional RNN or Layer() instance to be used to handle backward input processing. If backward_layer is not provided, the layer instance passed as the layer argument will be used to generate the backward layer automatically. Note that the provided backward_layer should have properties matching those of the layer argument; in particular, it should have the same values for stateful, return_state, return_sequences, etc. In addition, backward_layer and layer should have different go_backwards argument values. A ValueError will be raised if these requirements are not met.
- ...
For forward/backward compatibility.
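To make the merge_mode options concrete, here is a minimal base-R sketch of how two same-shaped outputs are combined. merge_bidirectional() is a hypothetical helper, not a keras3 function; it assumes the forward and backward outputs are already aligned and uses plain 2D matrices (batch x units), as produced when return_sequences = FALSE.

```r
# Hypothetical helper (not part of keras3): illustrates merge_mode semantics
# on plain R matrices of shape (batch, units).
merge_bidirectional <- function(y_fwd, y_bwd, merge_mode = "concat") {
  # Both directions must produce outputs of the same shape.
  stopifnot(identical(dim(y_fwd), dim(y_bwd)))
  if (is.null(merge_mode)) {
    # NULL: outputs are not combined, just returned as a list.
    return(list(y_fwd, y_bwd))
  }
  switch(merge_mode,
    sum    = y_fwd + y_bwd,           # element-wise sum
    mul    = y_fwd * y_bwd,           # element-wise product
    ave    = (y_fwd + y_bwd) / 2,     # element-wise average
    concat = cbind(y_fwd, y_bwd),     # join along the feature (last) axis
    stop("unknown merge_mode: ", merge_mode)
  )
}

# Usage: a batch of 2 samples, 3 units per direction.
y_fwd <- matrix(1:6, nrow = 2)
y_bwd <- matrix(6:1, nrow = 2)
concat_out <- merge_bidirectional(y_fwd, y_bwd, "concat")
sum_out    <- merge_bidirectional(y_fwd, y_bwd, "sum")
ave_out    <- merge_bidirectional(y_fwd, y_bwd, "ave")
list_out   <- merge_bidirectional(y_fwd, y_bwd, NULL)
```

With the default "concat", the feature dimension doubles, which is why each layer_bidirectional(layer_lstm(units = 10)) in the Examples produces 20 output features; "sum", "mul" and "ave" keep it at 10.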
Value
The return value depends on the value provided for the first argument. If object is:
- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
- a keras_input(), then the output tensor from calling layer(input) is returned.
- NULL or missing, then a Layer instance is returned.
Call Arguments
The call arguments for this layer are the same as those of the
wrapped RNN layer. Beware that when passing the initial_state
argument during the call of this layer, the first half of the
elements in the initial_state list will be passed to the
forward RNN call and the second half will be passed to the
backward RNN call.
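The splitting rule for initial_state can be sketched in base R. split_initial_state() is a hypothetical helper, not part of keras3, and the character strings below merely stand in for state tensors. An LSTM carries two states per direction (hidden state h and cell state c), so a bidirectional LSTM receives four.

```r
# Hypothetical helper (not part of keras3): mirrors the documented rule that
# the first half of `initial_state` goes to the forward RNN and the second
# half to the backward RNN.
split_initial_state <- function(initial_state) {
  n <- length(initial_state)
  # Sanity check added by this sketch: each direction needs the same count.
  if (n %% 2 != 0) stop("initial_state must have an even number of elements")
  half <- n %/% 2
  list(
    forward  = initial_state[seq_len(half)],         # elements 1..half
    backward = initial_state[seq_len(half) + half]   # elements half+1..n
  )
}

# Usage: placeholder names standing in for LSTM state tensors.
states <- list("h_fwd", "c_fwd", "h_bwd", "c_bwd")
parts <- split_initial_state(states)
str(parts)
```

The practical consequence is ordering: list all forward states first, then all backward states, rather than interleaving them by direction.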
Note
Instantiating a Bidirectional layer from an existing RNN layer
instance will not reuse the weights of that RNN layer instance; the
Bidirectional layer will have freshly initialized weights.
Examples
library(keras3)

# Stack two bidirectional LSTMs
model <- keras_model_sequential(input_shape = c(5, 10)) %>%
  layer_bidirectional(layer_lstm(units = 10, return_sequences = TRUE)) %>%
  layer_bidirectional(layer_lstm(units = 10)) %>%
  layer_dense(5, activation = "softmax")
model %>% compile(loss = "categorical_crossentropy",
                  optimizer = "rmsprop")

# With custom backward layer
forward_layer <- layer_lstm(units = 10, return_sequences = TRUE)
backward_layer <- layer_lstm(units = 10, activation = "relu",
                             return_sequences = TRUE, go_backwards = TRUE)
model <- keras_model_sequential(input_shape = c(5, 10)) %>%
  layer_bidirectional(forward_layer, backward_layer = backward_layer) %>%
  layer_dense(5, activation = "softmax")
model %>% compile(loss = "categorical_crossentropy",
                  optimizer = "rmsprop")
States
A Bidirectional layer instance has the property states, which you can access
with layer$states. You can also reset the states with layer$reset_state().
See also
Other rnn layers: layer_conv_lstm_1d() layer_conv_lstm_2d() layer_conv_lstm_3d() layer_gru() layer_lstm() layer_rnn() layer_simple_rnn() layer_time_distributed() rnn_cell_gru() rnn_cell_lstm() rnn_cell_simple() rnn_cells_stack()
Other layers: Layer() layer_activation() layer_activation_elu() layer_activation_leaky_relu() layer_activation_parametric_relu() layer_activation_relu() layer_activation_softmax() layer_activity_regularization() layer_add() layer_additive_attention() layer_alpha_dropout() layer_attention() layer_aug_mix() layer_auto_contrast() layer_average() layer_average_pooling_1d() layer_average_pooling_2d() layer_average_pooling_3d() layer_batch_normalization() layer_category_encoding() layer_center_crop() layer_concatenate() layer_conv_1d() layer_conv_1d_transpose() layer_conv_2d() layer_conv_2d_transpose() layer_conv_3d() layer_conv_3d_transpose() layer_conv_lstm_1d() layer_conv_lstm_2d() layer_conv_lstm_3d() layer_cropping_1d() layer_cropping_2d() layer_cropping_3d() layer_cut_mix() layer_dense() layer_depthwise_conv_1d() layer_depthwise_conv_2d() layer_discretization() layer_dot() layer_dropout() layer_einsum_dense() layer_embedding() layer_equalization() layer_feature_space() layer_flatten() layer_flax_module_wrapper() layer_gaussian_dropout() layer_gaussian_noise() layer_global_average_pooling_1d() layer_global_average_pooling_2d() layer_global_average_pooling_3d() layer_global_max_pooling_1d() layer_global_max_pooling_2d() layer_global_max_pooling_3d() layer_group_normalization() layer_group_query_attention() layer_gru() layer_hashed_crossing() layer_hashing() layer_identity() layer_integer_lookup() layer_jax_model_wrapper() layer_lambda() layer_layer_normalization() layer_lstm() layer_masking() layer_max_num_bounding_boxes() layer_max_pooling_1d() layer_max_pooling_2d() layer_max_pooling_3d() layer_maximum() layer_mel_spectrogram() layer_minimum() layer_mix_up() layer_multi_head_attention() layer_multiply() layer_normalization() layer_permute() layer_rand_augment() layer_random_brightness() layer_random_color_degeneration() layer_random_color_jitter() layer_random_contrast() layer_random_crop() layer_random_erasing() layer_random_flip() layer_random_gaussian_blur() 
layer_random_grayscale() layer_random_hue() layer_random_invert() layer_random_perspective() layer_random_posterization() layer_random_rotation() layer_random_saturation() layer_random_sharpness() layer_random_shear() layer_random_translation() layer_random_zoom() layer_repeat_vector() layer_rescaling() layer_reshape() layer_resizing() layer_rms_normalization() layer_rnn() layer_separable_conv_1d() layer_separable_conv_2d() layer_simple_rnn() layer_solarization() layer_spatial_dropout_1d() layer_spatial_dropout_2d() layer_spatial_dropout_3d() layer_spectral_normalization() layer_stft_spectrogram() layer_string_lookup() layer_subtract() layer_text_vectorization() layer_tfsm() layer_time_distributed() layer_torch_module_wrapper() layer_unit_normalization() layer_upsampling_1d() layer_upsampling_2d() layer_upsampling_3d() layer_zero_padding_1d() layer_zero_padding_2d() layer_zero_padding_3d() rnn_cell_gru() rnn_cell_lstm() rnn_cell_simple() rnn_cells_stack()