Skip to contents

Schematically, op_vectorized_map() maps over the first dimension of the provided tensors. If elements is a list of tensors, then all of the tensors are required to have the same size first dimension, and they are iterated over together.

Usage

op_vectorized_map(elements, f)

Arguments

elements

A tensor, or a list of tensors (see the description).

f

A function taking either a tensor or a list of tensors.

Value

A tensor or list of tensors, the result of mapping f across elements.

Examples

# Build a 3 x 4 integer tensor, `x`, used throughout these examples. The outer
# parentheses make R print the value of the assignment.
(x <- op_arange(12L) |> op_reshape(c(3, 4)))

## tf.Tensor(
## [[ 0  1  2  3]
##  [ 4  5  6  7]
##  [ 8  9 10 11]], shape=(3, 4), dtype=int32)

# Map over the rows of `x`: `row` is one row of `x`, and the per-row results
# are stacked back into a tensor with the same shape as `x`.
x |> op_vectorized_map(\(row) {row + 10})

## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)

# Map over several tensors together: `rows` is a list holding the matching
# row of each of the three tensors, which Reduce() sums elementwise.
list(x, x, x) |> op_vectorized_map(\(rows) Reduce(`+`, rows))

## tf.Tensor(
## [[ 0  3  6  9]
##  [12 15 18 21]
##  [24 27 30 33]], shape=(3, 4), dtype=int32)

Note that f may be traced and compiled. This means that, with the JAX or TensorFlow backends, the R function may be evaluated only once, with symbolic tensors rather than eager tensors. See the output from str() in these examples:

# Simplest case: map `f` over the rows of `x`,
# where `.x` is one row of `x`.
input <- x
output <- op_vectorized_map(input, function(.x) {
  # str() runs while `f` is being traced, so it prints a symbolic tensor
  # (a single row) rather than an eager one -- see the output below.
  str(.x)
  .x + 10
})

## <tf.Tensor 'loop_body/GatherV2:0' shape=(4) dtype=int32>

output

## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)

# Map `f` over two tensors simultaneously. Here, `.x` is a list of two
# tensors (the matching row from each input). The return values from each
# call of `f(.x)` are stacked to form the final output.
input <- list(x, x)
output <- op_vectorized_map(input, function(.x) {
  str(.x)
  .x[[1]] + 10
})

## List of 2
##  $ :<tf.Tensor 'loop_body/GatherV2:0' shape=(4) dtype=int32>
##  $ :<tf.Tensor 'loop_body/GatherV2_1:0' shape=(4) dtype=int32>

output

## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)

# Same as above, but now returning two tensors in the final output. Each
# element of the returned list is stacked separately, so `output` is a list
# of two tensors.
output <- op_vectorized_map(input, function(.x) {
  str(.x)
  # Unpack the two row slices with the multiple-assignment operator `%<-%`.
  c(.x1, .x2) %<-% .x
  list(.x1+10, .x2+20)
})

## List of 2
##  $ :<tf.Tensor 'loop_body/GatherV2:0' shape=(4) dtype=int32>
##  $ :<tf.Tensor 'loop_body/GatherV2_1:0' shape=(4) dtype=int32>

output

## [[1]]
## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)
##
## [[2]]
## tf.Tensor(
## [[20 21 22 23]
##  [24 25 26 27]
##  [28 29 30 31]], shape=(3, 4), dtype=int32)

# Passing named lists.
# WARNING: if passing a named list, the order of elements of `.x` supplied
# to `f` is not stable. Only retrieve elements by name.
# The names of the list returned by `f` (outname1, outname2) are kept in the
# final output.
input <- list(name1 = x, name2 = x)
output <- op_vectorized_map(input, function(.x) {
  str(.x)
  list(outname1 = .x$name1 + 10,
       outname2 = .x$name2 + 20)
})

## List of 2
##  $ name1:<tf.Tensor 'loop_body/GatherV2:0' shape=(4) dtype=int32>
##  $ name2:<tf.Tensor 'loop_body/GatherV2_1:0' shape=(4) dtype=int32>

output

## $outname1
## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)
##
## $outname2
## tf.Tensor(
## [[20 21 22 23]
##  [24 25 26 27]
##  [28 29 30 31]], shape=(3, 4), dtype=int32)

# Passing a tuple() is equivalent to passing an unnamed list().
input <- tuple(x, x)
output <- op_vectorized_map(input, function(.x) {
  str(.x)
  # Returning a list of one tensor yields a list of one stacked tensor.
  list(.x[[1]] + 10)
})

## List of 2
##  $ :<tf.Tensor 'loop_body/GatherV2:0' shape=(4) dtype=int32>
##  $ :<tf.Tensor 'loop_body/GatherV2_1:0' shape=(4) dtype=int32>

output

## [[1]]
## tf.Tensor(
## [[10 11 12 13]
##  [14 15 16 17]
##  [18 19 20 21]], shape=(3, 4), dtype=int32)

Debugging f

Even in eager contexts, op_vectorized_map() may trace f. In that case, if you want to debug f eagerly (e.g., with browser()), you can swap in a manual (slow) implementation of op_vectorized_map(). Note that this example debug implementation does not handle all of the edge cases that op_vectorized_map() does; in particular, it does not handle f returning a structure of multiple tensors.

op_vectorized_map_debug <- function(elements, fn) {
  # Eager (slow) stand-in for op_vectorized_map(), useful for debugging `fn`
  # interactively, e.g. with browser().
  #
  # elements: a tensor, or a (possibly named) list of tensors that all have
  #   the same size first dimension; list elements are iterated over together.
  # fn: called once per row. It receives one slice of `elements` with the
  #   first axis removed (or a list of such slices, one per tensor), matching
  #   what op_vectorized_map() passes to its `f`.
  # Returns the results of `fn` stacked along a new first axis. Because the
  # results are combined with op_stack(), `fn` must return a single tensor.

  # op_unstack() drops the sliced axis, so a (3, 4) input yields three
  # shape-(4) rows. (op_split() would instead yield (1, 4) pieces, which do
  # not match op_vectorized_map() and would stack to a (3, 1, 4) result.)
  unstack_rows <- function(x) op_unstack(x, axis = 1L)

  if (!is.list(elements)) {
    # `elements` is a single tensor
    out <- elements |>
      unstack_rows() |>
      lapply(fn) |>
      op_stack()
    return(out)
  }

  # `elements` is a list of tensors: slice each one, then regroup so that
  # each call to `fn` gets the matching row from every tensor.
  rows_per_element <- lapply(elements, unstack_rows)
  if (length(unique(lengths(rows_per_element))) > 1L) {
    stop("All tensors in `elements` must have the same size first dimension.",
         call. = FALSE)
  }

  rows_per_element |>
    zip_lists() |>
    lapply(fn) |>
    op_stack()
}

See also

Other core ops:
op_associative_scan()
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_custom_gradient()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_while_loop()

Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_greater()
op_greater_equal()
op_hard_sigmoid()
op_hard_silu()
op_hstack()
op_identity()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_roll()
op_round()
op_rsqrt()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_split()
op_sqrt()
op_square()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tensordot()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()