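Computes the attention function on Q (query), K (key), and V (value):
$$\mathrm{attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{H}}\right) V$$
where H is the dimension of each attention head and the scaling factor can be overridden with scale. In the descriptions below, logits refers to the output of the (scaled) product Q K^T, and probs refers to the output of softmax.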
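The formula can be made concrete with a minimal base R sketch for a single head with no batch dimension, so Q is a T x H matrix and K and V are S x H matrices. The helper names `softmax_rows()` and `attention()` are hypothetical and not part of keras3. The sketch only illustrates the math that `op_dot_product_attention()` computes, not how any backend implements it.

```r
# Hypothetical helpers illustrating scaled dot-product attention in base R.
softmax_rows <- function(x) {
  # Subtract each row's maximum for numerical stability, then normalize rows.
  x <- x - apply(x, 1, max)
  e <- exp(x)
  e / rowSums(e)
}

attention <- function(Q, K, V, scale = 1 / sqrt(ncol(Q))) {
  logits <- (Q %*% t(K)) * scale  # T x S: one row of logits per query
  probs <- softmax_rows(logits)   # T x S: each row sums to 1
  probs %*% V                     # T x H: weighted average of value rows
}

set.seed(1)
Q <- matrix(rnorm(4 * 16), nrow = 4)  # T = 4, H = 16
K <- matrix(rnorm(6 * 16), nrow = 6)  # S = 6, H = 16
V <- matrix(rnorm(6 * 16), nrow = 6)
out <- attention(Q, K, V)
dim(out)  # 4 16, i.e. (T, H)
```

Because each row of probs sums to 1, every output row is a convex combination of the rows of V.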
Throughout this page, the following notation is used for array shapes:
B: batch size
S: length of the key/value
T: length of the query
N: number of attention heads
H: dimensions of each attention head
K: number of key/value heads
G: number of groups, which equals N // K
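The grouping implied by G = N // K can be sketched as a mapping from each query head to the key/value head it reads from. This assumes contiguous groups (query heads 1..G share key/value head 1, the next G share head 2, and so on), which is what repeating each key/value head G times produces; whether every keras3 backend orders heads this way is an assumption. The function name is hypothetical, and heads are numbered from 1 as in R.

```r
# Hypothetical helper: which key/value head each query head uses,
# assuming contiguous groups of G = N %/% K query heads per key/value head.
kv_head_for_query_head <- function(n_heads, n_kv_heads) {
  # GQA requires the number of query heads to be a multiple of the kv heads.
  stopifnot(n_heads %% n_kv_heads == 0)
  G <- n_heads %/% n_kv_heads
  rep(seq_len(n_kv_heads), each = G)
}

kv_head_for_query_head(8, 8)  # MHA (K == N): 1 2 3 4 5 6 7 8
kv_head_for_query_head(8, 2)  # GQA (G = 4): 1 1 1 1 2 2 2 2
kv_head_for_query_head(8, 1)  # MQA (K == 1): 1 1 1 1 1 1 1 1
```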
Usage
op_dot_product_attention(
query,
key,
value,
bias = NULL,
mask = NULL,
scale = NULL,
is_causal = FALSE,
flash_attention = NULL
)
Arguments
- query
The query array with the shape of (B, T, N, H).
- key
The key array with the shape of (B, S, K, H). When K equals N, multi-head attention (MHA) is performed. Otherwise, grouped query attention (GQA) is performed if N is a multiple of K, and multi-query attention (MQA) is performed if K == 1 (a special case of GQA).
- value
The value array with the same shape as key.
- bias
Optional bias array to be added to the logits. The shape must be broadcastable to (B, N, T, S).
- mask
Optional mask array used to filter out logits. It is a boolean mask where TRUE indicates that the element should take part in attention. For an additive mask, pass it to bias instead. The shape must be broadcastable to (B, N, T, S).
- scale
Optional scale for the logits. If NULL, the scale will be set to 1.0 / sqrt(H).
- is_causal
Whether to apply a causal mask.
- flash_attention
Whether to use flash attention. If NULL, it will attempt to use flash attention if the required conditions are met. Typically, the inputs must have a float16 or bfloat16 dtype, and the input layout requirements may vary depending on the backend.
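How scale, bias, mask, and is_causal combine can be sketched in base R for a single head with T == S. This is an illustration under assumptions, not the keras3 implementation: the helper names are hypothetical, mask is taken as a full T x S logical matrix rather than any broadcastable shape, masked logits are pushed to a large negative value (-1e9) before softmax, and the causal mask lets query i attend only to keys 1..i. How backends align the causal mask when T != S is not covered here.

```r
softmax_rows <- function(x) {
  x <- x - apply(x, 1, max)  # row-wise max subtraction for stability
  e <- exp(x)
  e / rowSums(e)
}

# Hypothetical single-head attention with bias, boolean mask, and causal mask.
# mask: T x S logical matrix, TRUE = the key position takes part in attention.
masked_attention <- function(Q, K, V, bias = NULL, mask = NULL,
                             scale = NULL, is_causal = FALSE) {
  if (is.null(scale)) scale <- 1 / sqrt(ncol(Q))  # default 1 / sqrt(H)
  logits <- (Q %*% t(K)) * scale
  if (!is.null(bias)) logits <- logits + bias      # additive bias on logits
  keep <- matrix(TRUE, nrow(Q), nrow(K))
  if (!is.null(mask)) keep <- keep & mask
  if (is_causal) keep <- keep & lower.tri(keep, diag = TRUE)  # query i sees keys 1..i
  logits[!keep] <- -1e9                            # effectively zero probability
  probs <- softmax_rows(logits)
  list(output = probs %*% V, probs = probs)
}

set.seed(42)
Q <- matrix(rnorm(3 * 4), nrow = 3)  # T = 3, H = 4
K <- matrix(rnorm(3 * 4), nrow = 3)  # S = 3
V <- matrix(rnorm(3 * 4), nrow = 3)
res <- masked_attention(Q, K, V, is_causal = TRUE)
res$probs[1, ]  # 1 0 0: the first query can only see the first key
```

Applying the boolean mask after the bias mirrors the split described above: bias shifts logits additively, while mask removes positions outright.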
Examples
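library(keras3)
query = random_normal(c(2, 4, 8, 16))  # (B, T, N, H)
key = random_normal(c(2, 6, 8, 16))    # (B, S, K, H)
value = random_normal(c(2, 6, 8, 16))  # same shape as key
op_dot_product_attention(query, key, value) |> op_shape()
# The output has the same shape as query: (2, 4, 8, 16)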
See also
Other nn ops:
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_categorical_crossentropy()
op_celu()
op_conv()
op_conv_transpose()
op_ctc_loss()
op_depthwise_conv()
op_elu()
op_gelu()
op_glu()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_leaky_relu()
op_log_sigmoid()
op_log_softmax()
op_max_pool()
op_moments()
op_multi_hot()
op_normalize()
op_one_hot()
op_psnr()
op_relu()
op_relu6()
op_selu()
op_separable_conv()
op_sigmoid()
op_silu()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_squareplus()
op_tanh_shrink()
op_threshold()
op_unravel_index()
Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_bitwise_and()
op_bitwise_invert()
op_bitwise_left_shift()
op_bitwise_not()
op_bitwise_or()
op_bitwise_right_shift()
op_bitwise_xor()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_celu()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagflat()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_exp2()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_glu()
op_greater()
op_greater_equal()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_histogram()
op_hstack()
op_identity()
op_ifft2()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inner()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_left_shift()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logdet()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_right_shift()
op_roll()
op_round()
op_rsqrt()
op_saturate_cast()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_split()
op_sqrt()
op_square()
op_squareplus()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tanh_shrink()
op_tensordot()
op_threshold()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_trunc()
op_unravel_index()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()