A preprocessing layer which buckets continuous features by ranges.
Source: R/layers-preprocessing.R
This layer will place each element of its input data into one of several contiguous ranges and output an integer index indicating which range each element was placed in.
Note: This layer is safe to use inside a tf.data pipeline
(independently of which backend you're using).
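To make the bucketing rule concrete, the following base R sketch mimics the index assignment with findInterval(). It only illustrates the semantics and is not how the layer is implemented; the boundaries c(0, 1, 2) and the zero-based bin indices are assumptions chosen to match the bin_boundaries description on this page.

```r
# Bin boundaries, as in bin_boundaries = c(0, 1, 2).
boundaries <- c(0, 1, 2)

# A few float values to bucket.
x <- c(-1.5, 1, 3.4, 0.5)

# findInterval() counts how many boundaries are <= each value, which yields
# 0 for (-Inf, 0), 1 for [0, 1), 2 for [1, 2), and 3 for [2, +Inf).
bin_index <- findInterval(x, boundaries)
bin_index
```

A value equal to a boundary falls into the bin to its right, which matches the half-open intervals [0, 1) and [1, 2).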
Usage
layer_discretization(
  object,
  bin_boundaries = NULL,
  num_bins = NULL,
  epsilon = 0.01,
  output_mode = "int",
  sparse = FALSE,
  dtype = NULL,
  name = NULL
)
Arguments
- object
Object to compose the layer with. A tensor, array, or sequential model.
- bin_boundaries
A list of bin boundaries. The leftmost and rightmost bins will always extend to -Inf and Inf, so bin_boundaries = c(0, 1, 2) generates bins (-Inf, 0), [0, 1), [1, 2), and [2, +Inf). If this option is set, adapt() should not be called.
- num_bins
The integer number of bins to compute. If this option is set, bin_boundaries should not be set and adapt() should be called to learn the bin boundaries.
- epsilon
Error tolerance, typically a small fraction close to zero (e.g. 0.01). Higher values of epsilon increase the quantile approximation error, and hence result in more unequal buckets, but could improve performance and resource consumption.
- output_mode
Specification for the output of the layer. Values can be "int", "one_hot", "multi_hot", or "count", configuring the layer as follows:
  - "int": Return the discretized bin indices directly.
  - "one_hot": Encodes each individual element in the input into an array the same size as num_bins, containing a 1 at the input's bin index. If the last dimension is size 1, will encode on that dimension. If the last dimension is not size 1, will append a new dimension for the encoded output.
  - "multi_hot": Encodes each sample in the input into a single array the same size as num_bins, containing a 1 for each bin index present in the sample. Treats the last dimension as the sample dimension: if the input shape is (..., sample_length), the output shape will be (..., num_bins).
  - "count": As "multi_hot", but the int array contains a count of the number of times the bin index appeared in the sample.
Defaults to "int".
- sparse
Boolean. Only applicable to "one_hot", "multi_hot", and "count" output modes. Only supported with TensorFlow backend. If TRUE, returns a SparseTensor instead of a dense Tensor. Defaults to FALSE.
- dtype
Datatype (e.g., "float32").
- name
String, name for the object
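The output_mode values other than "int" can be pictured as simple transformations of the bin indices. The base R sketch below approximates what the layer would return for a single sample. It assumes zero-based bin indices and num_bins equal to the number of boundaries plus one; the variable names are illustrative and not part of keras3.

```r
boundaries <- c(0, 1, 2)
num_bins <- length(boundaries) + 1

# One sample with four values; 0 and 0.2 land in the same bin.
sample_values <- c(0, 3, 1.3, 0.2)

# "int": the zero-based bin index of each element.
idx <- findInterval(sample_values, boundaries)

# "count": how often each bin occurs in the sample.
# tabulate() counts the integers 1..nbins, so shift the indices by one.
count <- tabulate(idx + 1, nbins = num_bins)

# "multi_hot": 1 for every bin that occurs at least once.
multi_hot <- as.integer(count > 0)

# "one_hot": one row per element, with a single 1 at the element's bin.
one_hot <- diag(num_bins)[idx + 1, , drop = FALSE]
```

Summing the rows of one_hot column by column reproduces count, which is one way to see that "count" and "multi_hot" collapse the sample dimension while "one_hot" keeps it.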
Value
The return value depends on the value provided for the first argument.
If object is:
- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
- a keras_input(), then the output tensor from calling layer(input) is returned.
- NULL or missing, then a Layer instance is returned.
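When num_bins is set instead of bin_boundaries, adapt() has to learn the boundaries from data. Conceptually this amounts to choosing num_bins - 1 quantiles as cut points. The sketch below uses exact quantiles from base R's quantile(); the layer itself computes approximate quantiles within the epsilon tolerance, so the boundaries it learns may differ from these.

```r
# The same sixteen values as the example input matrix, flattened.
x <- c(-1.5, 1, 3.4, 0.5,
       0, 3, 1.3, 0,
       -0.5, 0, 0.5, 1,
       1.5, 2, 2.5, 3)

num_bins <- 4

# Interior cut points at the 25%, 50%, and 75% quantiles.
probs <- seq_len(num_bins - 1) / num_bins
boundaries <- unname(quantile(x, probs = probs))

# Assign each value to a bin, numbered 0 to num_bins - 1.
bin_index <- findInterval(x, boundaries)

# Tied values can leave the bins unequal in size even with exact quantiles.
table(bin_index)
```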
Examples
Discretize float values based on provided buckets.
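library(keras3)

# Four samples with four float features each.
input <- op_array(rbind(c(-1.5, 1, 3.4, 0.5),
                        c(0, 3, 1.3, 0),
                        c(-.5, 0, .5, 1),
                        c(1.5, 2, 2.5, 3)))
# Fixed boundaries give the bins (-Inf, 0), [0, 1), [1, 2), and [2, +Inf).
output <- input |> layer_discretization(bin_boundaries = c(0, 1, 2))
output
Discretize float values based on a number of buckets to compute.
# With num_bins, the boundaries are learned from the data by adapt().
layer <- layer_discretization(num_bins = 4, epsilon = 0.01)
layer |> adapt(input)
layer(input)
See also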
Other numerical features preprocessing layers: layer_normalization()
Other preprocessing layers: layer_aug_mix() layer_auto_contrast() layer_category_encoding() layer_center_crop() layer_cut_mix() layer_equalization() layer_feature_space() layer_hashed_crossing() layer_hashing() layer_integer_lookup() layer_max_num_bounding_boxes() layer_mel_spectrogram() layer_mix_up() layer_normalization() layer_rand_augment() layer_random_brightness() layer_random_color_degeneration() layer_random_color_jitter() layer_random_contrast() layer_random_crop() layer_random_erasing() layer_random_flip() layer_random_gaussian_blur() layer_random_grayscale() layer_random_hue() layer_random_invert() layer_random_perspective() layer_random_posterization() layer_random_rotation() layer_random_saturation() layer_random_sharpness() layer_random_shear() layer_random_translation() layer_random_zoom() layer_rescaling() layer_resizing() layer_solarization() layer_stft_spectrogram() layer_string_lookup() layer_text_vectorization()
Other layers: Layer() layer_activation() layer_activation_elu() layer_activation_leaky_relu() layer_activation_parametric_relu() layer_activation_relu() layer_activation_softmax() layer_activity_regularization() layer_add() layer_additive_attention() layer_alpha_dropout() layer_attention() layer_aug_mix() layer_auto_contrast() layer_average() layer_average_pooling_1d() layer_average_pooling_2d() layer_average_pooling_3d() layer_batch_normalization() layer_bidirectional() layer_category_encoding() layer_center_crop() layer_concatenate() layer_conv_1d() layer_conv_1d_transpose() layer_conv_2d() layer_conv_2d_transpose() layer_conv_3d() layer_conv_3d_transpose() layer_conv_lstm_1d() layer_conv_lstm_2d() layer_conv_lstm_3d() layer_cropping_1d() layer_cropping_2d() layer_cropping_3d() layer_cut_mix() layer_dense() layer_depthwise_conv_1d() layer_depthwise_conv_2d() layer_dot() layer_dropout() layer_einsum_dense() layer_embedding() layer_equalization() layer_feature_space() layer_flatten() layer_flax_module_wrapper() layer_gaussian_dropout() layer_gaussian_noise() layer_global_average_pooling_1d() layer_global_average_pooling_2d() layer_global_average_pooling_3d() layer_global_max_pooling_1d() layer_global_max_pooling_2d() layer_global_max_pooling_3d() layer_group_normalization() layer_group_query_attention() layer_gru() layer_hashed_crossing() layer_hashing() layer_identity() layer_integer_lookup() layer_jax_model_wrapper() layer_lambda() layer_layer_normalization() layer_lstm() layer_masking() layer_max_num_bounding_boxes() layer_max_pooling_1d() layer_max_pooling_2d() layer_max_pooling_3d() layer_maximum() layer_mel_spectrogram() layer_minimum() layer_mix_up() layer_multi_head_attention() layer_multiply() layer_normalization() layer_permute() layer_rand_augment() layer_random_brightness() layer_random_color_degeneration() layer_random_color_jitter() layer_random_contrast() layer_random_crop() layer_random_erasing() layer_random_flip() layer_random_gaussian_blur() 
layer_random_grayscale() layer_random_hue() layer_random_invert() layer_random_perspective() layer_random_posterization() layer_random_rotation() layer_random_saturation() layer_random_sharpness() layer_random_shear() layer_random_translation() layer_random_zoom() layer_repeat_vector() layer_rescaling() layer_reshape() layer_resizing() layer_rms_normalization() layer_rnn() layer_separable_conv_1d() layer_separable_conv_2d() layer_simple_rnn() layer_solarization() layer_spatial_dropout_1d() layer_spatial_dropout_2d() layer_spatial_dropout_3d() layer_spectral_normalization() layer_stft_spectrogram() layer_string_lookup() layer_subtract() layer_text_vectorization() layer_tfsm() layer_time_distributed() layer_torch_module_wrapper() layer_unit_normalization() layer_upsampling_1d() layer_upsampling_2d() layer_upsampling_3d() layer_zero_padding_1d() layer_zero_padding_2d() layer_zero_padding_3d() rnn_cell_gru() rnn_cell_lstm() rnn_cell_simple() rnn_cells_stack()