When the type of xs is an array type or NULL, and the type of ys is an
array type, the semantics of op_scan() are given roughly by this
implementation:
op_scan <- function(f, init, xs = NULL, length = NULL) {
  xs <- xs %||% vector("list", length)
  if (!is.list(xs))
    xs <- op_unstack(xs)
  ys <- vector("list", length(xs))
  carry <- init
  for (i in seq_along(xs)) {
    c(carry, y) %<-% f(carry, xs[[i]])
    ys[[i]] <- y
  }
  list(carry, op_stack(ys))
}
The loop-carried value carry (init) must hold a fixed shape and dtype
across all iterations.
In TensorFlow, y must match carry in shape and dtype. This is not
required in other backends.
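The reference semantics above can be tried without any backend. The following is a minimal base-R sketch, assuming plain numeric vectors stand in for tensors; scan_ref is a hypothetical helper written for illustration and is not part of keras3.

```r
# Hypothetical base-R stand-in for op_scan(); not part of keras3.
# Plain numeric vectors play the role of tensors.
scan_ref <- function(f, init, xs) {
  ys <- vector("list", length(xs))
  carry <- init
  for (i in seq_along(xs)) {
    out <- f(carry, xs[[i]])  # f returns list(new_carry, y)
    carry <- out[[1]]
    ys[[i]] <- out[[2]]
  }
  # unlist() stands in for op_stack() on scalar outputs
  list(carry = carry, ys = unlist(ys))
}

# Running sum: the carry accumulates, and each y records the carry so far.
res <- scan_ref(function(carry, x) list(carry + x, carry + x), 0, 1:5)
res$carry  # 15
res$ys     # 1 3 6 10 15
```

In this sketch the carry stays a length-1 numeric on every iteration, which mirrors the fixed shape and dtype requirement stated above.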
Arguments
- f
  Callable that defines the logic for each loop iteration. It accepts two arguments: the first is the current value of the loop carry and the second is a slice of xs along its leading axis. It returns a pair where the first element is the new value of the loop carry and the second is a slice of the output.
- init
  The initial loop carry value. This can be a scalar, tensor, or any nested structure. It must match the structure of the first element returned by f.
- xs
  Optional value to scan along its leading axis. This can be a tensor or any nested structure. If xs is not provided, you must specify length to define the number of loop iterations. Defaults to NULL.
- length
  Optional integer specifying the number of loop iterations. If length is not provided, it defaults to the size of the leading axis of the arrays in xs. Defaults to NULL.
- reverse
  Optional boolean specifying whether to run the scan iteration forward or in reverse, equivalent to reversing the leading axes of the arrays in both xs and ys.
- unroll
  Optional positive integer or boolean specifying how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll=TRUE) or left completely rolled (unroll=FALSE). Note that unrolling is only supported by the JAX and TensorFlow backends.
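The effect of length and reverse can be illustrated with a small base-R sketch. It assumes plain numeric vectors in place of tensors, and scan_sketch is a hypothetical helper, not the keras3 implementation. It follows the description above: without xs each slice passed to f is NULL, and with reverse = TRUE the iteration runs from the last slice to the first while each output keeps the position of its input slice.

```r
# Hypothetical illustration of the `length` and `reverse` arguments;
# not the keras3 implementation.
scan_sketch <- function(f, init, xs = NULL, length = NULL, reverse = FALSE) {
  if (is.null(xs)) xs <- vector("list", length)  # every slice is NULL
  n <- base::length(xs)
  idx <- if (reverse) rev(seq_len(n)) else seq_len(n)
  ys <- vector("list", n)
  carry <- init
  for (i in idx) {
    out <- f(carry, xs[[i]])  # f returns list(new_carry, y)
    carry <- out[[1]]
    ys[[i]] <- out[[2]]       # output stays aligned with input position i
  }
  list(carry = carry, ys = unlist(ys))
}

# No xs: iterate a fixed number of times, ignoring the (NULL) slice.
doubling <- scan_sketch(function(carry, x) list(carry * 2, carry),
                        init = 1, length = 4)
doubling$ys     # 1 2 4 8
doubling$carry  # 16

# reverse = TRUE: running sum accumulated from the last element backwards.
backward <- scan_sketch(function(carry, x) list(carry + x, carry + x),
                        init = 0, xs = 1:5, reverse = TRUE)
backward$ys     # 15 14 12 9 5
backward$carry  # 15
```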
Value
A pair where the first element represents the final loop carry value and
the second element represents the stacked outputs of f when scanned
over the leading axis of the inputs.
Examples
sum_fn <- function(c, x) list(c + x, c + x)
init <- op_array(0L)
xs <- op_array(1:5)
c(carry, result) %<-% op_scan(sum_fn, init, xs)
carry
result
See also
Other core ops: op_associative_scan() op_cast() op_cond() op_convert_to_numpy() op_convert_to_tensor() op_custom_gradient() op_dtype() op_fori_loop() op_is_tensor() op_map() op_rearrange() op_scatter() op_scatter_update() op_searchsorted() op_shape() op_slice() op_slice_update() op_stop_gradient() op_subset() op_switch() op_unstack() op_vectorized_map() op_while_loop()
Other ops: op_abs() op_add() op_all() op_any() op_append() op_arange() op_arccos() op_arccosh() op_arcsin() op_arcsinh() op_arctan() op_arctan2() op_arctanh() op_argmax() op_argmin() op_argpartition() op_argsort() op_array() op_associative_scan() op_average() op_average_pool() op_batch_normalization() op_binary_crossentropy() op_bincount() op_bitwise_and() op_bitwise_invert() op_bitwise_left_shift() op_bitwise_not() op_bitwise_or() op_bitwise_right_shift() op_bitwise_xor() op_broadcast_to() op_cast() op_categorical_crossentropy() op_ceil() op_celu() op_cholesky() op_clip() op_concatenate() op_cond() op_conj() op_conv() op_conv_transpose() op_convert_to_numpy() op_convert_to_tensor() op_copy() op_correlate() op_cos() op_cosh() op_count_nonzero() op_cross() op_ctc_decode() op_ctc_loss() op_cumprod() op_cumsum() op_custom_gradient() op_depthwise_conv() op_det() op_diag() op_diagflat() op_diagonal() op_diff() op_digitize() op_divide() op_divide_no_nan() op_dot() op_dot_product_attention() op_dtype() op_eig() op_eigh() op_einsum() op_elu() op_empty() op_equal() op_erf() op_erfinv() op_exp() op_exp2() op_expand_dims() op_expm1() op_extract_sequences() op_eye() op_fft() op_fft2() op_flip() op_floor() op_floor_divide() op_fori_loop() op_full() op_full_like() op_gelu() op_get_item() op_glu() op_greater() op_greater_equal() op_hard_shrink() op_hard_sigmoid() op_hard_silu() op_hard_tanh() op_histogram() op_hstack() op_identity() op_ifft2() op_imag() op_image_affine_transform() op_image_crop() op_image_extract_patches() op_image_gaussian_blur() op_image_hsv_to_rgb() op_image_map_coordinates() op_image_pad() op_image_perspective_transform() op_image_resize() op_image_rgb_to_grayscale() op_image_rgb_to_hsv() op_in_top_k() op_inner() op_inv() op_irfft() op_is_tensor() op_isclose() op_isfinite() op_isinf() op_isnan() op_istft() op_leaky_relu() op_left_shift() op_less() op_less_equal() op_linspace() op_log() op_log10() op_log1p() op_log2() op_log_sigmoid() op_log_softmax() 
op_logaddexp() op_logdet() op_logical_and() op_logical_not() op_logical_or() op_logical_xor() op_logspace() op_logsumexp() op_lstsq() op_lu_factor() op_map() op_matmul() op_max() op_max_pool() op_maximum() op_mean() op_median() op_meshgrid() op_min() op_minimum() op_mod() op_moments() op_moveaxis() op_multi_hot() op_multiply() op_nan_to_num() op_ndim() op_negative() op_nonzero() op_norm() op_normalize() op_not_equal() op_one_hot() op_ones() op_ones_like() op_outer() op_pad() op_polar() op_power() op_prod() op_psnr() op_qr() op_quantile() op_ravel() op_real() op_rearrange() op_reciprocal() op_relu() op_relu6() op_repeat() op_reshape() op_rfft() op_right_shift() op_rms_normalization() op_roll() op_rot90() op_round() op_rsqrt() op_saturate_cast() op_scatter() op_scatter_update() op_searchsorted() op_segment_max() op_segment_sum() op_select() op_selu() op_separable_conv() op_shape() op_sigmoid() op_sign() op_signbit() op_silu() op_sin() op_sinh() op_size() op_slice() op_slice_update() op_slogdet() op_soft_shrink() op_softmax() op_softplus() op_softsign() op_solve() op_solve_triangular() op_sort() op_sparse_categorical_crossentropy() op_sparse_plus() op_sparsemax() op_split() op_sqrt() op_square() op_squareplus() op_squeeze() op_stack() op_std() op_stft() op_stop_gradient() op_subset() op_subtract() op_sum() op_svd() op_swapaxes() op_switch() op_take() op_take_along_axis() op_tan() op_tanh() op_tanh_shrink() op_tensordot() op_threshold() op_tile() op_top_k() op_trace() op_transpose() op_tri() op_tril() op_triu() op_trunc() op_unravel_index() op_unstack() op_var() op_vdot() op_vectorize() op_vectorized_map() op_vstack() op_where() op_while_loop() op_zeros() op_zeros_like()