
Extract elements from a tensor using common R-style [ indexing idioms. This function can also be conveniently accessed via the syntax tensor@r[...].

Usage

op_subset(x, ...)

op_subset(x, ...) <- value

op_subset_set(x, ..., value)

Arguments

x

Input tensor.

...

Indices specifying elements to extract. Each argument in ... can be:

  • An integer scalar

  • A 1-d integer or logical vector

  • NULL or newaxis

  • The .. symbol

  • A slice expression using :

If only a single argument is supplied to ..., then ..1 can also be:

  • A logical array with the same shape as x

  • An integer matrix where ncol(..1) == op_rank(x)

value

New value to replace the selected subset with.
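The `..` form listed under `...` fills in every axis that the surrounding indices do not name. The following is a rough base-R analogue for an ordinary R array, not a tensor. The helper `subset_with_dots()` is hypothetical and is not part of keras3. It fills each unspecified axis with `TRUE`, which selects everything along that axis. To match the shapes in the examples, it drops only the axes indexed by a single number. Values will differ from the tensor examples because base R fills arrays column-major, so only the shapes are comparable.

```r
# Hypothetical helper, not part of keras3: emulate x@r[<before>, .., <after>]
# on a base R array. `before` and `after` are lists of per-axis indices.
subset_with_dots <- function(arr, before = list(), after = list()) {
  rank <- length(dim(arr))
  n_fill <- rank - length(before) - length(after)
  stopifnot(n_fill >= 0)
  # `..` becomes one TRUE (select everything) per unspecified axis
  idx <- c(before, rep(list(TRUE), n_fill), after)
  out <- do.call(`[`, c(list(arr), idx, list(drop = FALSE)))
  # Drop only axes indexed by a single number; filled axes keep extent 1
  scalar_axis <- vapply(idx, function(i) is.numeric(i) && length(i) == 1,
                        logical(1))
  new_dim <- dim(out)[!scalar_axis]
  if (length(new_dim) == 0) return(as.vector(out))
  array(out, dim = new_dim)
}

# Same shape as the batched image in the examples: (1, 4, 4, 3)
arr <- array(seq_len(48), dim = c(1, 4, 4, 3))

dim(subset_with_dots(arr))                             # 1 4 4 3, like x@r[..]
dim(subset_with_dots(arr, list(1)))                    # 4 4 3,   like x@r[1, ..]
dim(subset_with_dots(arr, list(1), list(1, 1)))        # 4,       like x@r[1, .., 1, 1]
length(subset_with_dots(arr, list(1, 1, 1), list(1)))  # 1,       like x@r[1, 1, 1, .., 1]
```

Filling with `TRUE` rather than empty arguments keeps the `do.call()` construction simple, and `drop = FALSE` prevents base R from discarding the leading axis of extent 1 that `x@r[..]` keeps.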

Value

A tensor containing the subset of elements.

Details

While the semantics are similar to R's [, there are some differences:

Differences from R's [:

  • Negative indices follow Python-style indexing, counting from the end of the array.

  • NULL or newaxis adds a new dimension (equivalent to op_expand_dims()).

  • If fewer indices than dimensions (op_rank(x)) are provided, missing dimensions are implicitly filled. For example, if x is a matrix, x@r[1] returns the first row.

  • .. or all_dims() expands to include all unspecified dimensions (see examples).

  • Extended slicing syntax (:) is supported, including:

    • Strided steps: x@r[start:end:step]

    • NA values for start and end: an NA start defaults to 1, and an NA end defaults to the axis size.

  • A logical array matching the shape of x selects elements in row-wise (row-major) order, whereas R's [ selects them column-wise.
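The index and slice rules above can be made concrete with two small helpers. Each helper translates the rules for a single axis of length `n` into ordinary positive R positions. This sketch assumes plain R vectors and does not show how keras3 implements op_subset(). The names `resolve_index()` and `resolve_slice()` are hypothetical. With `x <- 11:15`, the commented results agree with the outputs of the matching `x@r[...]` calls in the Examples section.

```r
# Hypothetical helpers, not part of keras3: map the documented rules for a
# single axis of length `n` onto ordinary (positive, 1-based) R indices.

# Positive indices are 1-based; negative ones count back from the end,
# so -1 is the last position and -2 the second to last.
resolve_index <- function(i, n) {
  # Rejecting 0 and out-of-range values is this sketch's own choice
  stopifnot(all(i != 0), all(abs(i) <= n))
  ifelse(i > 0, i, n + 1 + i)
}

# Slices include `end`; an NA start means 1 and an NA end means n.
# Only positive steps are handled here.
resolve_slice <- function(start = NA, end = NA, step = 1, n) {
  from <- if (is.na(start)) 1 else resolve_index(start, n)
  to <- if (is.na(end)) n else resolve_index(end, n)
  seq(from, to, by = step)
}

x <- 11:15
x[resolve_index(c(-1, -2), length(x))]           # 15 14,       like x@r[c(-1, -2)]
x[resolve_slice(1, -2, n = length(x))]           # 11 12 13 14, like x@r[1:-2]
x[resolve_slice(NA, 3, n = length(x))]           # 11 12 13,    like x@r[NA:3]
x[resolve_slice(2, NA, n = length(x))]           # 12 13 14 15, like x@r[2:NA]
x[resolve_slice(1, 5, step = 2, n = length(x))]  # 11 13 15 (every second element)
```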

Similarities to R's [ (differences from Python's [):

  • Positive indices are 1-based.

  • Slices (x@r[start:end]) include end.

  • 1-d logical or integer vectors subset along the axis they are supplied for. Supplying vectors for several axes returns the intersection of those selections, as with R's x[i, j].

  • A single integer matrix with ncol(i) == op_rank(x) selects elements by coordinates. Each row in the matrix specifies the location of one value, where each column corresponds to an axis in the tensor being subsetted. This means you use a 2-column matrix to subset a matrix, a 3-column matrix to subset a 3d array, and so on.
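These rules follow base R, so a plain R array can show the difference between supplying one vector per axis and supplying a single coordinate matrix. Under the similarities listed above, `x@r[...]` on a tensor of the same shape is expected to select the same positions. The names `arr` and `coords` are illustrative only.

```r
arr <- array(1:24, dim = c(2, 3, 4))

# One vector per axis: the result is the intersected (cross-product) subset,
# here rows 1-2, column 3, and slices 1 and 4.
dim(arr[1:2, 3, c(1, 4), drop = FALSE])   # 2 1 2

# A single 3-column matrix: each row is one (row, column, slice) coordinate,
# so the result has exactly one element per matrix row.
coords <- rbind(c(1, 1, 1),
                c(2, 3, 4))
arr[coords]                               # 1 24
```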

Examples

(x <- op_arange(5L) + 10L)

## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

# Basic example, get first element
op_subset(x, 1)

## tf.Tensor(11, shape=(), dtype=int32)

# Use `@r[` syntax
x@r[1]           # same as `op_subset(x, 1)`

## tf.Tensor(11, shape=(), dtype=int32)

x@r[1:2]         # get the first 2 elements

## tf.Tensor([11 12], shape=(2), dtype=int32)

x@r[c(1, 3)]     # first and third element

## tf.Tensor([11 13], shape=(2), dtype=int32)

# Negative indices
x@r[-1]          # last element

## tf.Tensor(15, shape=(), dtype=int32)

x@r[-2]          # second to last element

## tf.Tensor(14, shape=(), dtype=int32)

x@r[c(-1, -2)]   # last and second to last elements

## tf.Tensor([15 14], shape=(2), dtype=int32)

x@r[c(-2, -1)]   # second to last and last elements

## tf.Tensor([14 15], shape=(2), dtype=int32)

x@r[c(1, -1)]    # first and last elements

## tf.Tensor([11 15], shape=(2), dtype=int32)

# Slices
x@r[1:3]          # first 3 elements

## tf.Tensor([11 12 13], shape=(3), dtype=int32)

x@r[NA:3]         # first 3 elements

## tf.Tensor([11 12 13], shape=(3), dtype=int32)

x@r[1:5]          # all elements

## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[1:-1]         # all elements

## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[NA:NA]        # all elements

## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[]             # all elements

## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[1:-2]         # drop last element

## tf.Tensor([11 12 13 14], shape=(4), dtype=int32)

x@r[NA:-2]        # drop last element

## tf.Tensor([11 12 13 14], shape=(4), dtype=int32)

x@r[2:NA]         # drop first element

## tf.Tensor([12 13 14 15], shape=(4), dtype=int32)

# 2D array examples
xr <- array(1:12, c(3, 4))
x <- op_convert_to_tensor(xr)

# Basic subsetting
x@r[1, ]      # first row

## tf.Tensor([ 1  4  7 10], shape=(4), dtype=int32)

x@r[1]        # also first row! Missing axes are implicitly inserted

## tf.Tensor([ 1  4  7 10], shape=(4), dtype=int32)

x@r[-1]       # last row

## tf.Tensor([ 3  6  9 12], shape=(4), dtype=int32)

x@r[, 2]      # second column

## tf.Tensor([4 5 6], shape=(3), dtype=int32)

x@r[, 2:2]    # second column, but shape preserved (like [, drop=FALSE])

## tf.Tensor(
## [[4]
##  [5]
##  [6]], shape=(3, 1), dtype=int32)

# Subsetting with a boolean array
# Note: extracted elements are selected row-wise, not column-wise
mask <- x >= 6
x@r[mask]             # returns a 1D tensor

## tf.Tensor([ 7 10  8 11  6  9 12], shape=(7), dtype=int32)

x.r <- as.array(x)
mask.r <- as.array(mask)
# as.array(x)[mask] selects column-wise; aperm() reverses the axes to recover row-wise order.
all(aperm(x.r)[aperm(mask.r)] == as.array(x@r[mask]))

## [1] TRUE

# Subsetting with a matrix of index positions
indices <- rbind(c(1, 1), c(2, 2), c(3, 3))
x@r[indices] # get diagonal elements

## tf.Tensor([1 5 9], shape=(3), dtype=int32)

x.r[indices] # same as subsetting an R array

## [1] 1 5 9

# 3D array examples
# Image: 4x4 pixels, 3 colors (RGB)
# Tensor shape: (img_height, img_width, img_color_channels)
shp <- shape(4, 4, 3)
x <- op_arange(prod(shp)) |> op_reshape(shp)

# Convert to a batch of images by inserting a new axis
# New shape: (batch_size, img_height, img_width, img_color_channels)
x@r[newaxis, , , ] |> op_shape()

## shape(1, 4, 4, 3)

x@r[newaxis] |> op_shape()  # same as above

## shape(1, 4, 4, 3)

x@r[NULL] |> op_shape()     # same as above

## shape(1, 4, 4, 3)

x <- x@r[newaxis]
# Extract color channels
x@r[, , , 1]          # red channel

## tf.Tensor(
## [[[ 1.  4.  7. 10.]
##   [13. 16. 19. 22.]
##   [25. 28. 31. 34.]
##   [37. 40. 43. 46.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 1]            # red channel, same as above using .. shorthand

## tf.Tensor(
## [[[ 1.  4.  7. 10.]
##   [13. 16. 19. 22.]
##   [25. 28. 31. 34.]
##   [37. 40. 43. 46.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 2]            # green channel

## tf.Tensor(
## [[[ 2.  5.  8. 11.]
##   [14. 17. 20. 23.]
##   [26. 29. 32. 35.]
##   [38. 41. 44. 47.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 3]            # blue channel

## tf.Tensor(
## [[[ 3.  6.  9. 12.]
##   [15. 18. 21. 24.]
##   [27. 30. 33. 36.]
##   [39. 42. 45. 48.]]], shape=(1, 4, 4), dtype=float32)

# .. expands to all unspecified axes.
op_shape(x@r[])

## shape(1, 4, 4, 3)

op_shape(x@r[..])

## shape(1, 4, 4, 3)

op_shape(x@r[1, ..])

## shape(4, 4, 3)

op_shape(x@r[1, .., 1, 1])

## shape(4)

op_shape(x@r[1, 1, 1, .., 1])

## shape()

# op_subset<- uses the same semantics, but note that not all tensors
# support modification. E.g., TensorFlow constant tensors cannot be modified,
# while TensorFlow Variables can be.

(x <- tensorflow::tf$Variable(matrix(1, nrow = 2, ncol = 3)))

## <tf.Variable 'Variable:0' shape=(2, 3) dtype=float64, numpy=
## array([[1., 1., 1.],
##        [1., 1., 1.]])>

op_subset(x, 1) <- 9
x

## <tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float64, numpy=
## array([[9., 9., 9.],
##        [1., 1., 1.]])>

x@r[1,1] <- 33
x

## <tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float64, numpy=
## array([[33.,  9.,  9.],
##        [ 1.,  1.,  1.]])>
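The boolean-mask example above uses `aperm()` to compare against R. The same trick works for arrays of any rank. Calling `aperm()` without a `perm` argument reverses the axes, so R's column-major traversal of the permuted array visits the original elements in row-major order. `row_major_select()` is a hypothetical base-R helper that sketches this equivalence. It is not a keras3 function.

```r
# Hypothetical helper, not part of keras3: select the elements of `arr`
# where `mask` is TRUE, in row-major order, using only base R.
row_major_select <- function(arr, mask) {
  stopifnot(identical(dim(arr), dim(mask)))
  # Reversing all axes turns column-major traversal into row-major order
  aperm(arr)[aperm(mask)]
}

m <- array(1:12, dim = c(3, 4))
row_major_select(m, m >= 6)   # 7 10 8 11 6 9 12, the order x@r[mask] returned
m[m >= 6]                     # 6 7 8 9 10 11 12, R's column-major order

# For a 3-d array, the last axis varies fastest
a <- array(1:24, dim = c(2, 3, 4))
head(row_major_select(a, a > 0), 5)   # 1 7 13 19 3
```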

See also

Other core ops:
op_associative_scan()
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_custom_gradient()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_vectorized_map()
op_while_loop()

Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_bitwise_and()
op_bitwise_invert()
op_bitwise_left_shift()
op_bitwise_not()
op_bitwise_or()
op_bitwise_right_shift()
op_bitwise_xor()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_celu()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagflat()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dot_product_attention()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_exp2()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_glu()
op_greater()
op_greater_equal()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_histogram()
op_hstack()
op_identity()
op_ifft2()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inner()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_left_shift()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logdet()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_right_shift()
op_roll()
op_round()
op_rsqrt()
op_saturate_cast()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_split()
op_sqrt()
op_square()
op_squareplus()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tanh_shrink()
op_tensordot()
op_threshold()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_trunc()
op_unravel_index()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()
