Performs a scan with an associative binary operation, in parallel.
Source: R/ops.R
This operation is similar to op_scan(), with the key difference that op_associative_scan() is a parallel implementation with potentially significant performance benefits, especially when JIT-compiled. The catch is that it can only be used when f is an associative binary operation, i.e. it must satisfy f(a, f(b, c)) == f(f(a, b), c).
For an introduction to associative scans, refer to this paper: Blelloch, Guy E. 1990. Prefix Sums and Their Applications.
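To show why associativity is the price of parallelism, here is a minimal base R sketch of the Hillis-Steele inclusive scan. It is an illustration only: it is not the algorithm keras3 or its backends use, it is not Blelloch's work-efficient variant, and the function name hillis_steele_scan is hypothetical. Each loop iteration is one "parallel step" in which every position is updated from the previous step's values, so an n-element scan takes about log2(n) steps. The regrouping this performs is only valid when f is associative.

```r
# Hypothetical helper (not part of keras3): Hillis-Steele inclusive scan.
# `f` must be vectorized, and earlier elements always go in its first argument,
# so order is preserved for non-commutative (but associative) operations.
hillis_steele_scan <- function(f, x) {
  n <- length(x)
  d <- 1L
  while (d < n) {
    idx <- (d + 1L):n
    # The right-hand side is evaluated in full before assignment, so every
    # position in this step reads the previous step's values, as parallel
    # workers would.
    x[idx] <- f(x[idx - d], x[idx])
    d <- d * 2L
  }
  x
}

hillis_steele_scan(`+`, 1:8)              # same as cumsum(1:8)
hillis_steele_scan(paste0, letters[1:4])  # "a" "ab" "abc" "abcd"

# A non-associative operation silently gives a different answer
# from the sequential scan:
hillis_steele_scan(`-`, c(1, 2, 3, 4))          # 1 -1 2 0
Reduce(`-`, c(1, 2, 3, 4), accumulate = TRUE)   # 1 -1 -4 -8
```

String concatenation is associative but not commutative, so the paste0 case checks that argument order is kept; subtraction is neither, which is why its parallel result diverges.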
Arguments
- f
A callable implementing an associative binary operation with signature r = f(a, b). Function f must be associative, i.e., it must satisfy the equation f(a, f(b, c)) == f(f(a, b), c). The inputs and result are (possibly nested tree structures of) array(s) matching elems. Each array keeps a dimension in place of the axis dimension (holding a batch of elements to combine), so f should be applied elementwise over that dimension. The result r has the same shape (and structure) as the two inputs a and b.
- elems
A (possibly nested tree structure of) array(s), each with an axis dimension of size num_elems.
- reverse
A boolean stating whether the scan should be reversed with respect to the axis dimension.
- axis
An integer identifying the axis over which the scan should occur.
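The following base R sketch is a sequential reference model (no keras3 calls) of what reverse and axis select. It rests on two assumptions not stated above: that a reverse scan behaves like reversing along axis, scanning, and reversing back, and that axis is 1-based, as the axis = 1 in the examples suggests. The helper name reference_scan is hypothetical, and it uses the commutative op + so that argument order does not matter.

```r
# Hypothetical sequential reference model (base R, not keras3 code).
reference_scan <- function(f, x, reverse = FALSE) {
  if (reverse) x <- rev(x)                    # assumed meaning of `reverse`
  out <- Reduce(f, x, accumulate = TRUE)      # plain left-to-right scan
  if (reverse) out <- rev(out)
  out
}

reference_scan(`+`, 1:4)                  # 1 3 6 10
reference_scan(`+`, 1:4, reverse = TRUE)  # 10 9 7 4

m <- matrix(1:6, nrow = 2)  # 2 x 3
# Scan along axis 1 (down each column), each column independently:
apply(m, 2, function(col) reference_scan(`+`, col))     # rows: 1 3 5 / 3 7 11
# Scan along axis 2 (across each row); apply() returns rows as columns,
# so transpose back:
t(apply(m, 1, function(row) reference_scan(`+`, row)))  # rows: 1 4 9 / 2 6 12
```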
Value
A (possibly nested tree structure of) array(s) of the same shape and structure as elems, in which the k-th element along axis is the result of recursively applying f to combine the first k elements of elems along axis. For example, given elems = list(a, b, c, ...), the result would be list(a, f(a, b), f(f(a, b), c), ...).
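Written as a recurrence, with x_k the k-th slice of elems along axis and y_k the k-th slice of the result, the definition above reads as follows. The second line shows what associativity buys: applying f(f(a, b), c) = f(a, f(b, c)) with a = f(x_1, x_2), b = x_3, c = x_4 regroups y_4 so that its two inner calls are independent and can be evaluated at the same time.

```latex
y_1 = x_1, \qquad y_k = f(y_{k-1},\, x_k), \quad k = 2, \dots, n

y_4 = f\big(f(f(x_1, x_2), x_3), x_4\big) = f\big(f(x_1, x_2),\, f(x_3, x_4)\big)
```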
Examples
sum_fn <- function(x, y) x + y
xs <- op_arange(5L)
op_associative_scan(sum_fn, xs)
library(purrr)  # provides map2()
sum_fn <- function(x, y) {
  # print the structures f receives, then add them elementwise
  str(list(x = x, y = y))
  map2(x, y, \(.x, .y) .x + .y)
}
xs <- list(op_array(1:2),
           op_array(1:2),
           op_array(1:2))
ys <- op_associative_scan(sum_fn, xs, axis = 1)
## List of 2
## $ x:List of 3
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([1], dtype=int32)>
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([1], dtype=int32)>
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([1], dtype=int32)>
## $ y:List of 3
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([2], dtype=int32)>
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([2], dtype=int32)>
## ..$ :<tf.Tensor: shape=(1), dtype=int32, numpy=array([2], dtype=int32)>
ys
See also
Other core ops: op_cast() op_cond() op_convert_to_numpy() op_convert_to_tensor() op_custom_gradient() op_dtype() op_fori_loop() op_is_tensor() op_map() op_rearrange() op_scan() op_scatter() op_scatter_update() op_searchsorted() op_shape() op_slice() op_slice_update() op_stop_gradient() op_subset() op_switch() op_unstack() op_vectorized_map() op_while_loop()
Other ops: op_abs() op_add() op_all() op_angle() op_any() op_append() op_arange() op_arccos() op_arccosh() op_arcsin() op_arcsinh() op_arctan() op_arctan2() op_arctanh() op_argmax() op_argmin() op_argpartition() op_argsort() op_array() op_average() op_average_pool() op_bartlett() op_batch_normalization() op_binary_crossentropy() op_bincount() op_bitwise_and() op_bitwise_invert() op_bitwise_left_shift() op_bitwise_not() op_bitwise_or() op_bitwise_right_shift() op_bitwise_xor() op_blackman() op_broadcast_to() op_cast() op_categorical_crossentropy() op_cbrt() op_ceil() op_celu() op_cholesky() op_clip() op_concatenate() op_cond() op_conj() op_conv() op_conv_transpose() op_convert_to_numpy() op_convert_to_tensor() op_copy() op_corrcoef() op_correlate() op_cos() op_cosh() op_count_nonzero() op_cross() op_ctc_decode() op_ctc_loss() op_cumprod() op_cumsum() op_custom_gradient() op_deg2rad() op_depthwise_conv() op_det() op_diag() op_diagflat() op_diagonal() op_diff() op_digitize() op_divide() op_divide_no_nan() op_dot() op_dot_product_attention() op_dtype() op_eig() op_eigh() op_einsum() op_elu() op_empty() op_equal() op_erf() op_erfinv() op_exp() op_exp2() op_expand_dims() op_expm1() op_extract_sequences() op_eye() op_fft() op_fft2() op_flip() op_floor() op_floor_divide() op_fori_loop() op_full() op_full_like() op_gelu() op_get_item() op_glu() op_greater() op_greater_equal() op_hamming() op_hanning() op_hard_shrink() op_hard_sigmoid() op_hard_silu() op_hard_tanh() op_heaviside() op_histogram() op_hstack() op_identity() op_ifft2() op_imag() op_image_affine_transform() op_image_crop() op_image_extract_patches() op_image_gaussian_blur() op_image_hsv_to_rgb() op_image_map_coordinates() op_image_pad() op_image_perspective_transform() op_image_resize() op_image_rgb_to_grayscale() op_image_rgb_to_hsv() op_in_top_k() op_inner() op_inv() op_irfft() op_is_tensor() op_isclose() op_isfinite() op_isinf() op_isnan() op_istft() op_kaiser() op_layer_normalization() op_leaky_relu() 
op_left_shift() op_less() op_less_equal() op_linspace() op_log() op_log10() op_log1p() op_log2() op_log_sigmoid() op_log_softmax() op_logaddexp() op_logdet() op_logical_and() op_logical_not() op_logical_or() op_logical_xor() op_logspace() op_logsumexp() op_lstsq() op_lu_factor() op_map() op_matmul() op_max() op_max_pool() op_maximum() op_mean() op_median() op_meshgrid() op_min() op_minimum() op_mod() op_moments() op_moveaxis() op_multi_hot() op_multiply() op_nan_to_num() op_ndim() op_negative() op_nonzero() op_norm() op_normalize() op_not_equal() op_one_hot() op_ones() op_ones_like() op_outer() op_pad() op_polar() op_power() op_prod() op_psnr() op_qr() op_quantile() op_ravel() op_real() op_rearrange() op_reciprocal() op_relu() op_relu6() op_repeat() op_reshape() op_rfft() op_right_shift() op_rms_normalization() op_roll() op_rot90() op_round() op_rsqrt() op_saturate_cast() op_scan() op_scatter() op_scatter_update() op_searchsorted() op_segment_max() op_segment_sum() op_select() op_selu() op_separable_conv() op_shape() op_sigmoid() op_sign() op_signbit() op_silu() op_sin() op_sinh() op_size() op_slice() op_slice_update() op_slogdet() op_soft_shrink() op_softmax() op_softplus() op_softsign() op_solve() op_solve_triangular() op_sort() op_sparse_categorical_crossentropy() op_sparse_plus() op_sparse_sigmoid() op_sparsemax() op_split() op_sqrt() op_square() op_squareplus() op_squeeze() op_stack() op_std() op_stft() op_stop_gradient() op_subset() op_subtract() op_sum() op_svd() op_swapaxes() op_switch() op_take() op_take_along_axis() op_tan() op_tanh() op_tanh_shrink() op_tensordot() op_threshold() op_tile() op_top_k() op_trace() op_transpose() op_tri() op_tril() op_triu() op_trunc() op_unravel_index() op_unstack() op_var() op_vdot() op_vectorize() op_vectorized_map() op_view_as_complex() op_view_as_real() op_vstack() op_where() op_while_loop() op_zeros() op_zeros_like()