
Computes the attention function on Q (query), K (key), and V (value): attention(Q, K, V) = softmax(Q * K^T / sqrt(d)) * V, where d is the dimension of each attention head. In the argument descriptions below, "logits" refers to the output of Q * K^T and "probs" refers to the output of the softmax.
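
As a rough illustration of the formula, the following base R sketch computes single-head attention for one batch element. It is not keras3 code. The names `softmax_rows` and `attention_2d` are hypothetical. It assumes `q` has one row per query position and `k` and `v` have one row per key position, all with the same head dimension. The default `scale` mirrors the documented default of 1 / sqrt(H).

```r
# Row-wise softmax; subtracting each row's maximum keeps exp() from overflowing
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# Hypothetical single-head attention: softmax(q %*% t(k) * scale) %*% v
attention_2d <- function(q, k, v, scale = 1 / sqrt(ncol(q))) {
  logits <- (q %*% t(k)) * scale   # (T, S)
  probs <- softmax_rows(logits)    # each row sums to 1
  probs %*% v                      # (T, H)
}

set.seed(1)
q <- matrix(rnorm(4 * 16), nrow = 4)   # T = 4, H = 16
k <- matrix(rnorm(6 * 16), nrow = 6)   # S = 6
v <- matrix(rnorm(6 * 16), nrow = 6)
out <- attention_2d(q, k, v)
dim(out)   # 4 16
```

If every key row is identical, all logits in a row are equal, the probs are uniform, and each output row is simply the column mean of `v`. This is a quick sanity check for any implementation.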

Throughout this page, we use the following notation to describe array shapes:

  • B: batch size

  • S: length of the key/value

  • T: length of the query

  • N: number of attention heads

  • H: dimension of each attention head

  • K: number of key/value heads

  • G: number of groups, which equals N / K
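
The notation can be read directly off the input shapes. The base R sketch below is illustrative only. The helpers `attention_dims` and `kv_head_for` are hypothetical, not part of keras3. The mapping from query heads to key/value heads assumes that consecutive query heads share a key/value head, which is what you get by repeating each key/value head G times along the head axis.

```r
# Hypothetical helper: derive the notation from the input shapes
# query: (B, T, N, H); key/value: (B, S, K, H)
attention_dims <- function(query_shape, key_shape) {
  N <- query_shape[3]
  K <- key_shape[3]
  stopifnot(
    key_shape[1] == query_shape[1],  # same batch size B
    key_shape[4] == query_shape[4],  # same head dimension H
    N %% K == 0                      # N must be a multiple of K
  )
  G <- N %/% K
  variant <- if (K == N) "MHA" else if (K == 1) "MQA" else "GQA"
  list(B = query_shape[1], T = query_shape[2], S = key_shape[2],
       N = N, H = query_shape[4], K = K, G = G, variant = variant)
}

# Assumed grouping: consecutive query heads share a key/value head,
# so 1-based query head n reads key/value head (n - 1) %/% G + 1
kv_head_for <- function(n, G) (n - 1) %/% G + 1

d <- attention_dims(c(2, 4, 8, 16), c(2, 6, 2, 16))
d$G                     # 4
d$variant               # "GQA"
kv_head_for(1:8, d$G)   # 1 1 1 1 2 2 2 2
```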

Usage

op_dot_product_attention(
  query,
  key,
  value,
  bias = NULL,
  mask = NULL,
  scale = NULL,
  is_causal = FALSE,
  flash_attention = NULL
)

Arguments

query

The query array with the shape of (B, T, N, H).

key

The key array with the shape (B, S, K, H). When K equals N, multi-head attention (MHA) is performed. Otherwise, grouped query attention (GQA) is performed if N is a multiple of K, and multi-query attention (MQA) is performed if K == 1 (a special case of GQA).

value

The value array, with the same shape as key.

bias

Optional bias array to be added to logits. The shape must be broadcastable to (B, N, T, S).

mask

Optional mask array used to filter out logits. It is a boolean mask where TRUE indicates that the element should take part in attention. For an additive mask, pass it as bias instead. The shape must be broadcastable to (B, N, T, S).

scale

Optional scale for the logits. If NULL, the scale will be set to 1.0 / sqrt(H).

is_causal

Whether to apply a causal mask, so that each query position attends only to key positions at or before it.

flash_attention

Whether to use flash attention. If NULL, flash attention is attempted when the required conditions are met. Typically, the inputs must have dtype float16 or bfloat16, and the input layout requirements may vary depending on the backend.
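
To see how bias, mask, and is_causal interact on the logits, here is a base R sketch for a single (T, S) slice, that is, one batch element and one head. It is an illustration under stated assumptions, not the keras3 implementation. The helper `masked_probs` is hypothetical. The bias is added first. Positions excluded by the boolean mask or the causal mask are then replaced with a large negative constant (-1e9 here, an assumed value; backends may use a different one), so they receive numerically zero probability. The causal mask is assumed to be top-left aligned, so query i may attend to key j only when j <= i.

```r
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# Hypothetical helper: apply bias, boolean mask, and causal mask to one (T, S) slice
masked_probs <- function(logits, bias = NULL, mask = NULL, is_causal = FALSE) {
  if (!is.null(bias)) logits <- logits + bias   # additive bias on the logits
  allowed <- matrix(TRUE, nrow(logits), ncol(logits))
  if (!is.null(mask)) allowed <- allowed & mask  # TRUE = take part in attention
  if (is_causal) {
    # query i may attend to key j only when j <= i (assumed top-left alignment)
    allowed <- allowed & outer(seq_len(nrow(logits)), seq_len(ncol(logits)), ">=")
  }
  # Masked-out logits get a large negative value (assumed constant)
  softmax_rows(ifelse(allowed, logits, -1e9))
}

# With all-zero logits, a causal mask spreads attention evenly over allowed keys
probs <- masked_probs(matrix(0, nrow = 3, ncol = 3), is_causal = TRUE)
round(probs, 3)
```

Row 1 can only see key 1, row 2 splits its attention between keys 1 and 2, and row 3 spreads it over all three keys.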

Value

An array of attention outputs with the same shape as query.

Examples

library(keras3)
query = random_normal(c(2, 4, 8, 16))  # (B, T, N, H)
key = random_normal(c(2, 6, 8, 16))    # (B, S, K, H)
value = random_normal(c(2, 6, 8, 16))  # same shape as key
op_dot_product_attention(query, key, value) |> op_shape()

## shape(2, 4, 8, 16)
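
To make the shape bookkeeping concrete, here is a slow, loop-based base R reference computation using the layouts described on this page: query (B, T, N, H), key and value (B, S, K, H), and output (B, T, N, H). It is a sketch for checking shapes and semantics, not how keras3 or any backend implements the op. The name `reference_attention` is hypothetical. It assumes that consecutive query heads share a key/value head, and it omits bias, mask, and is_causal. The key and value here use K = 2 rather than the example's K = 8, so the GQA path is exercised.

```r
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

# Hypothetical loop-based reference; query: (B, T, N, H), key/value: (B, S, K, H)
reference_attention <- function(query, key, value, scale = NULL) {
  qd <- dim(query); kd <- dim(key)
  N <- qd[3]; K <- kd[3]; H <- qd[4]
  stopifnot(N %% K == 0, identical(dim(key), dim(value)))
  G <- N %/% K
  if (is.null(scale)) scale <- 1 / sqrt(H)        # documented default
  out <- array(0, dim = qd)                        # same shape as query
  for (b in seq_len(qd[1])) {
    for (n in seq_len(N)) {
      kv <- (n - 1) %/% G + 1                      # assumed shared key/value head
      q <- matrix(query[b, , n, ], nrow = qd[2])   # (T, H)
      k <- matrix(key[b, , kv, ], nrow = kd[2])    # (S, H)
      v <- matrix(value[b, , kv, ], nrow = kd[2])  # (S, H)
      probs <- softmax_rows((q %*% t(k)) * scale)  # (T, S)
      out[b, , n, ] <- probs %*% v                 # (T, H)
    }
  }
  out
}

set.seed(4)
query <- array(rnorm(2 * 4 * 8 * 16), c(2, 4, 8, 16))
key   <- array(rnorm(2 * 6 * 2 * 16), c(2, 6, 2, 16))  # K = 2, so G = 4 (GQA)
value <- array(rnorm(2 * 6 * 2 * 16), c(2, 6, 2, 16))
result <- reference_attention(query, key, value)
dim(result)   # 2 4 8 16
```

Under the assumed grouping, repeating each key/value head G times along the head axis and running plain multi-head attention gives the same result. This is a convenient way to cross-check a GQA implementation.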

See also

Other nn ops:
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_categorical_crossentropy()
op_celu()
op_conv()
op_conv_transpose()
op_ctc_loss()
op_depthwise_conv()
op_elu()
op_gelu()
op_glu()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_leaky_relu()
op_log_sigmoid()
op_log_softmax()
op_max_pool()
op_moments()
op_multi_hot()
op_normalize()
op_one_hot()
op_psnr()
op_relu()
op_relu6()
op_selu()
op_separable_conv()
op_sigmoid()
op_silu()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_squareplus()
op_tanh_shrink()
op_threshold()
op_unravel_index()

Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_bitwise_and()
op_bitwise_invert()
op_bitwise_left_shift()
op_bitwise_not()
op_bitwise_or()
op_bitwise_right_shift()
op_bitwise_xor()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_celu()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagflat()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_exp2()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_glu()
op_greater()
op_greater_equal()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_histogram()
op_hstack()
op_identity()
op_ifft2()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inner()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_left_shift()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logdet()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_right_shift()
op_roll()
op_round()
op_rsqrt()
op_saturate_cast()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_split()
op_sqrt()
op_square()
op_squareplus()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tanh_shrink()
op_tensordot()
op_threshold()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_trunc()
op_unravel_index()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()