This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.
Arguments
- f
Function f(...) that returns a tuple(output, grad_fn), where:
- ... is a sequence of unnamed arguments, each a tensor input or nested structure of tensor inputs to the function.
- output is a (potentially nested structure of) tensor outputs of applying the operations in f() (the forward function) to ....
- grad_fn is a function with the signature grad_fn(..., upstream) which returns a list of tensors the same size as (flattened) ...: the derivatives of the tensors in output with respect to the tensors in ....
- upstream is a tensor or sequence of tensors holding the initial value gradients for each tensor in output.
Value
A function h(...) which returns the same value as f(...)[[1]] and whose
gradient is determined by f(...)[[2]].
Note
The arguments that the gradient function (the second element returned by f)
must accept depend on the backend being used. With the JAX and TensorFlow
backends, it is called with only one argument (the upstream gradient),
whereas with the PyTorch backend it receives ... as well as a named upstream
argument.
When working with the TensorFlow or JAX backends, grad(upstream)
is sufficient. With PyTorch, the grad function requires
... as well as upstream, e.g. grad <- \(..., upstream).
Follow the example below to use op_custom_gradient() in
a way that is compatible with all backends.
Example
Backend-agnostic example.
# log(1 + exp(x)) with a hand-written gradient, d/dx = 1 - 1 / (1 + exp(x)).
log1pexp <- op_custom_gradient(function(x) {
  exp_x <- op_exp(x)

  # The gradient function is written to work on every backend. With
  # TensorFlow/JAX it is called with the upstream gradient as its single
  # positional argument. With PyTorch it also receives `upstream` by name.
  grad_fn <- \(..., upstream = NULL) {
    if (is.null(upstream)) {
      upstream <- ..1
    }
    op_multiply(upstream, 1.0 - 1.0 / op_add(1, exp_x))
  }

  # First element: forward value. Second element: gradient function.
  tuple(op_log(1 + exp_x), grad_fn)
})
# TensorFlow-only check: record the gradient of log1pexp() at a large input
# with a GradientTape and confirm that the custom gradient evaluates to 1.
if (config_backend() == "tensorflow") {
  tf <- tensorflow::tf
  x <- op_convert_to_tensor(100.0)

  # The tape must watch `x` explicitly because it is a plain tensor, not a
  # trainable variable.
  with(tf$GradientTape() %as% tape, {
    tape$watch(x)
    y <- log1pexp(x)
  })
  dy_dx <- tape$gradient(y, x)

  # Compare with a tolerance instead of exact `==` on a floating-point value.
  # all.equal() still rejects NaN or Inf, so a numerically broken gradient
  # fails this check.
  stopifnot(isTRUE(all.equal(as.numeric(dy_dx), 1)))
}
See also
Other core ops: op_associative_scan() op_cast() op_cond() op_convert_to_numpy() op_convert_to_tensor() op_dtype() op_fori_loop() op_is_tensor() op_map() op_rearrange() op_scan() op_scatter() op_scatter_update() op_searchsorted() op_shape() op_slice() op_slice_update() op_stop_gradient() op_subset() op_switch() op_unstack() op_vectorized_map() op_while_loop()
Other ops: op_abs() op_add() op_all() op_any() op_append() op_arange() op_arccos() op_arccosh() op_arcsin() op_arcsinh() op_arctan() op_arctan2() op_arctanh() op_argmax() op_argmin() op_argpartition() op_argsort() op_array() op_associative_scan() op_average() op_average_pool() op_batch_normalization() op_binary_crossentropy() op_bincount() op_bitwise_and() op_bitwise_invert() op_bitwise_left_shift() op_bitwise_not() op_bitwise_or() op_bitwise_right_shift() op_bitwise_xor() op_broadcast_to() op_cast() op_categorical_crossentropy() op_ceil() op_celu() op_cholesky() op_clip() op_concatenate() op_cond() op_conj() op_conv() op_conv_transpose() op_convert_to_numpy() op_convert_to_tensor() op_copy() op_correlate() op_cos() op_cosh() op_count_nonzero() op_cross() op_ctc_decode() op_ctc_loss() op_cumprod() op_cumsum() op_depthwise_conv() op_det() op_diag() op_diagflat() op_diagonal() op_diff() op_digitize() op_divide() op_divide_no_nan() op_dot() op_dot_product_attention() op_dtype() op_eig() op_eigh() op_einsum() op_elu() op_empty() op_equal() op_erf() op_erfinv() op_exp() op_exp2() op_expand_dims() op_expm1() op_extract_sequences() op_eye() op_fft() op_fft2() op_flip() op_floor() op_floor_divide() op_fori_loop() op_full() op_full_like() op_gelu() op_get_item() op_glu() op_greater() op_greater_equal() op_hard_shrink() op_hard_sigmoid() op_hard_silu() op_hard_tanh() op_histogram() op_hstack() op_identity() op_ifft2() op_imag() op_image_affine_transform() op_image_crop() op_image_extract_patches() op_image_gaussian_blur() op_image_hsv_to_rgb() op_image_map_coordinates() op_image_pad() op_image_perspective_transform() op_image_resize() op_image_rgb_to_grayscale() op_image_rgb_to_hsv() op_in_top_k() op_inner() op_inv() op_irfft() op_is_tensor() op_isclose() op_isfinite() op_isinf() op_isnan() op_istft() op_leaky_relu() op_left_shift() op_less() op_less_equal() op_linspace() op_log() op_log10() op_log1p() op_log2() op_log_sigmoid() op_log_softmax() op_logaddexp() op_logdet() 
op_logical_and() op_logical_not() op_logical_or() op_logical_xor() op_logspace() op_logsumexp() op_lstsq() op_lu_factor() op_map() op_matmul() op_max() op_max_pool() op_maximum() op_mean() op_median() op_meshgrid() op_min() op_minimum() op_mod() op_moments() op_moveaxis() op_multi_hot() op_multiply() op_nan_to_num() op_ndim() op_negative() op_nonzero() op_norm() op_normalize() op_not_equal() op_one_hot() op_ones() op_ones_like() op_outer() op_pad() op_polar() op_power() op_prod() op_psnr() op_qr() op_quantile() op_ravel() op_real() op_rearrange() op_reciprocal() op_relu() op_relu6() op_repeat() op_reshape() op_rfft() op_right_shift() op_rms_normalization() op_roll() op_rot90() op_round() op_rsqrt() op_saturate_cast() op_scan() op_scatter() op_scatter_update() op_searchsorted() op_segment_max() op_segment_sum() op_select() op_selu() op_separable_conv() op_shape() op_sigmoid() op_sign() op_signbit() op_silu() op_sin() op_sinh() op_size() op_slice() op_slice_update() op_slogdet() op_soft_shrink() op_softmax() op_softplus() op_softsign() op_solve() op_solve_triangular() op_sort() op_sparse_categorical_crossentropy() op_sparse_plus() op_sparsemax() op_split() op_sqrt() op_square() op_squareplus() op_squeeze() op_stack() op_std() op_stft() op_stop_gradient() op_subset() op_subtract() op_sum() op_svd() op_swapaxes() op_switch() op_take() op_take_along_axis() op_tan() op_tanh() op_tanh_shrink() op_tensordot() op_threshold() op_tile() op_top_k() op_trace() op_transpose() op_tri() op_tril() op_triu() op_trunc() op_unravel_index() op_unstack() op_var() op_vdot() op_vectorize() op_vectorized_map() op_vstack() op_where() op_while_loop() op_zeros() op_zeros_like()