This function takes as input `logits`, a 2-D tensor with shape
(batch_size, num_classes). Each row of the input represents a categorical
distribution, and each column holds the unnormalized log-probability of a
given class.
The function outputs a 2-D tensor with shape (batch_size, num_samples),
where each row contains class indices sampled from the corresponding row of
`logits`. Each column holds an independent sample drawn from that row's
distribution.
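To make the sampling semantics concrete, here is a minimal base R sketch that reproduces the described behavior without keras3. It is an assumption-laden illustration, not the keras3 implementation: `sample_categorical_sketch()` is a hypothetical name, and it converts each row of logits to probabilities with a max-shifted softmax before drawing with `sample.int()`.

```r
# Hypothetical helper (not part of keras3): draws `num_samples` class indices
# for each row of a logits matrix, using only base R.
sample_categorical_sketch <- function(logits, num_samples, zero_indexed = FALSE) {
  logits <- as.matrix(logits)
  out <- matrix(NA_integer_, nrow = nrow(logits), ncol = num_samples)
  for (i in seq_len(nrow(logits))) {
    # Softmax with the row maximum subtracted for numerical stability,
    # turning unnormalized log-probabilities into probabilities.
    w <- exp(logits[i, ] - max(logits[i, ]))
    p <- w / sum(w)
    # Independent draws with replacement; sample.int() returns one-based indices.
    out[i, ] <- sample.int(ncol(logits), size = num_samples,
                           replace = TRUE, prob = p)
  }
  # Shift to zero-based indices when requested.
  if (zero_indexed) out - 1L else out
}

# Usage: one row of logits over three classes, five draws.
set.seed(1234)
x <- matrix(c(100, .1, 99), nrow = 1)
sample_categorical_sketch(x, num_samples = 5)
```

With the logits `c(100, .1, 99)`, the softmax probabilities are roughly 0.731, effectively 0, and 0.269, so the draws are almost always class 1 or class 3. Only the relative differences between logits matter, which is why the inputs need not be normalized.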
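Examples

```r
library(keras3)

# One row of unnormalized log-probabilities over three classes
x <- matrix(c(100, .1, 99), nrow = 1)

# Five one-based class indices (the default)
random_categorical(x, num_samples = 5, seed = 1234)

# Zero-based class indices instead
random_categorical(x, num_samples = 5, seed = 1234,
                   zero_indexed = TRUE)

# Gather the logit of each sampled class
op_take(x, random_categorical(x, num_samples = 5, seed = 1234))

# With zero-based indices, op_take() must be told as well
op_take(x, random_categorical(x, num_samples = 5, seed = 1234,
                              zero_indexed = TRUE),
        zero_indexed = TRUE)
```

Arguments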
- logits
- 2-D Tensor with shape (batch_size, num_classes). Each row should define a categorical distribution with the unnormalized log-probabilities for all classes. 
- num_samples
- Int, the number of independent samples to draw for each row of the input. This will be the second dimension of the output tensor's shape. 
- dtype
- Optional dtype of the output tensor. 
- seed
- Optional R integer or instance of `random_seed_generator()`. By default, the `seed` argument is `NULL`, and an internal global `random_seed_generator()` is used. The `seed` argument can be used to ensure deterministic (repeatable) random number generation. Note that passing an integer as the `seed` value will produce the same random values for each call. To generate different random values for repeated calls, an instance of `random_seed_generator()` must be provided as the `seed` value. Remark concerning the JAX backend: when tracing functions with the JAX backend, the global `random_seed_generator()` is not supported. Therefore, during tracing the default value `seed = NULL` will produce an error, and a `seed` argument must be provided.
- zero_indexed
- If `TRUE`, the returned indices are zero-based (`0` encodes the first position); if `FALSE` (default), the returned indices are one-based (`1` encodes the first position).
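The difference between an integer `seed` and a generator instance can be illustrated with a base R analogy. This is a hypothetical sketch, not how keras3 implements seeding: `draw_with_int_seed()` and `make_seed_generator()` are made-up names, and base R's global RNG (`set.seed()`, `sample.int()`) stands in for the backend's random ops. Reseeding with the same integer on every call repeats the draws, while a generator that carries state between calls yields new draws each time.

```r
# Hypothetical analogy (not the keras3 implementation).

# Integer seed: the RNG is reset to the same state on every call,
# so repeated calls return identical draws.
draw_with_int_seed <- function(seed, n = 5) {
  set.seed(seed)  # note: this resets R's global RNG state
  sample.int(3, n, replace = TRUE)
}

# Stateful generator: each call advances an internal counter and derives
# a fresh seed from it, so successive calls return different draws while
# a new generator built from the same seed replays the same sequence.
make_seed_generator <- function(seed) {
  calls <- 0L
  function(n = 5) {
    calls <<- calls + 1L
    set.seed(seed + calls)
    sample.int(3, n, replace = TRUE)
  }
}

# Usage
draw_with_int_seed(1234)
draw_with_int_seed(1234)   # same values as the previous call
gen <- make_seed_generator(1234)
gen()
gen()                      # generally different values
```

In keras3 itself, the corresponding pattern is passing one instance created with `random_seed_generator()` as `seed` to successive `random_categorical()` calls.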
See also
Other random: random_beta() random_binomial() random_dropout() random_gamma() random_integer() random_normal() random_seed_generator() random_shuffle() random_truncated_normal() random_uniform()