This function takes as input logits, a 2-D input tensor with shape (batch_size, num_classes). Each row of the input represents a categorical distribution, with each column index containing the (possibly unnormalized) log-probability for a given class.
The function will output a 2-D tensor with shape (batch_size, num_samples), where each row contains samples from the corresponding row in logits. Each column index contains an independent sample drawn from the input distribution.
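To make the sampling semantics concrete, here is a minimal base-R sketch, assuming the behaviour described above: each row of logits is normalized with a softmax, and num_samples class indices are drawn from it independently (with replacement). The helper sample_categorical_sketch() is hypothetical and is not part of keras3. It returns 1-based indices, as is natural in base R, and makes no claim about the indexing or random-number algorithm used by random_categorical() itself.

```r
# Hypothetical helper, not part of keras3: a base-R illustration of
# drawing categorical samples from a (batch_size, num_classes) logits matrix.
sample_categorical_sketch <- function(logits, num_samples) {
  stopifnot(is.matrix(logits), num_samples >= 1)
  # Subtract each row's max before exponentiating for numerical stability;
  # this does not change the normalized probabilities.
  shifted <- logits - apply(logits, 1, max)
  e <- exp(shifted)
  probs <- e / rowSums(e)  # softmax: each row now sums to 1
  # Draw num_samples independent class indices (1-based) for each row.
  out <- vapply(seq_len(nrow(probs)), function(i)
    sample.int(ncol(probs), size = num_samples, replace = TRUE,
               prob = probs[i, ]),
    integer(num_samples))
  # vapply stores each row's samples contiguously; lay them out as
  # a (batch_size, num_samples) matrix.
  matrix(out, nrow = nrow(probs), ncol = num_samples, byrow = TRUE)
}

set.seed(1)
logits <- rbind(c(0, 0, 0),      # uniform over 3 classes
                c(10, 0, -10))   # strongly favours class 1
s <- sample_categorical_sketch(logits, num_samples = 5)
dim(s)  # 2 5
```

Because the logits are only defined up to a per-row additive constant, adding the same value to every entry of a row leaves the sampled distribution unchanged.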
Arguments
- logits
2-D Tensor with shape (batch_size, num_classes). Each row should define a categorical distribution with the unnormalized log-probabilities for all classes.
- num_samples
Int, the number of independent samples to draw for each row of the input. This will be the second dimension of the output tensor's shape.
- dtype
Optional dtype of the output tensor.
- seed
Optional R integer or instance of random_seed_generator(). By default, the seed argument is NULL, and an internal global random_seed_generator() is used. The seed argument can be used to ensure deterministic (repeatable) random number generation. Note that passing an integer as the seed value will produce the same random values for each call. To generate different random values for repeated calls, an instance of random_seed_generator() must be provided as the seed value.
Remark concerning the JAX backend: when tracing functions with the JAX backend, the global random_seed_generator() is not supported. Therefore, during tracing the default value seed = NULL will produce an error, and a seed argument must be provided.
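A short usage sketch of the seed behaviour described above, assuming keras3 is installed with a working backend. The logits values and the seed 42L are arbitrary illustrations. With an integer seed the two calls are expected to return identical samples; with a random_seed_generator() instance the generator's state advances between calls, so the samples will generally differ (a coincidental match is possible but unlikely).

```r
library(keras3)

# Arbitrary example logits (assumption): 2 rows, 3 classes.
logits <- rbind(c(0, 0, 0),
                c(5, 0, -5))

# Integer seed: repeated calls are expected to give the same samples.
a <- random_categorical(logits, num_samples = 5L, seed = 42L)
b <- random_categorical(logits, num_samples = 5L, seed = 42L)

# Seed generator: its state advances on each call, so repeated calls
# are expected to give different samples.
gen <- random_seed_generator(seed = 42L)
c1 <- random_categorical(logits, num_samples = 5L, seed = gen)
c2 <- random_categorical(logits, num_samples = 5L, seed = gen)

dim(as.array(a))                     # 2 5
identical(as.array(a), as.array(b))  # TRUE
```

Passing a generator (rather than relying on seed = NULL) is also what makes such calls usable when tracing under the JAX backend.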
See also
Other random: random_beta(), random_binomial(), random_dropout(), random_gamma(), random_integer(), random_normal(), random_seed_generator(), random_shuffle(), random_truncated_normal(), random_uniform()