
Introduction

The Keras distribution API is a new interface designed to facilitate distributed deep learning across a variety of backends such as JAX, TensorFlow and PyTorch (for now, only the JAX backend implements it). It provides tools for data and model parallelism, allowing deep learning models to scale efficiently across multiple accelerators and hosts. Whether you are using GPUs or TPUs, the API offers a streamlined way to initialize distributed environments, define device meshes, and control how tensors are laid out across computational resources. Through classes like DataParallel and ModelParallel, it abstracts away much of the complexity of parallel computation, making it easier to accelerate machine learning workflows.

How it works

The Keras distribution API provides a global programming model that allows developers to compose applications that operate on tensors in a global context (as if working with a single device) while automatically managing distribution across many devices. The API leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion.

By decoupling the application from sharding directives, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.

Setup

# This guide assumes there are 8 GPUs available for testing. If you don't have
# 8 GPUs available locally, you can set the following environment variables to
# make XLA initialize the CPU as 8 devices, which enables local testing.
# These must be set before JAX is initialized.
Sys.setenv("CUDA_VISIBLE_DEVICES" = "")
Sys.setenv("XLA_FLAGS" = "--xla_force_host_platform_device_count=8")
library(keras3)

# The distribution API is only implemented for the JAX backend for now.
use_backend("jax")
jax <- reticulate::import("jax")

library(tfdatasets, exclude = "shape") # For dataset input.

DeviceMesh and TensorLayout

The keras$distribution$DeviceMesh class in the Keras distribution API represents a cluster of computational devices configured for distributed computation. It aligns with similar concepts in jax.sharding.Mesh and tf.dtensor.Mesh, where it is used to map physical devices to a logical mesh structure.

The TensorLayout class then specifies how tensors are distributed across the DeviceMesh, detailing the sharding of tensors along specified axes that correspond to the names of the axes in the DeviceMesh.

You can find more detailed concept explainers in the TensorFlow DTensor guide.

# Retrieve the locally available devices. (On a machine with GPUs,
# jax$devices("gpu") would return only the GPU devices.)
devices <- jax$devices()
str(devices)
## List of 8
##  $ :TFRT_CPU_0
##  $ :TFRT_CPU_1
##  $ :TFRT_CPU_2
##  $ :TFRT_CPU_3
##  $ :TFRT_CPU_4
##  $ :TFRT_CPU_5
##  $ :TFRT_CPU_6
##  $ :TFRT_CPU_7
# Define a 2x4 device mesh with data and model parallel axes
mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when mapped to the physical
# devices on the mesh.
layout_2d <- keras$distribution$TensorLayout(
  axes = c("model", "data"),
  device_mesh = mesh
)

# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d <- keras$distribution$TensorLayout(
  axes = list("data", NULL, NULL, NULL),
  device_mesh = mesh
)
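To make the layout semantics concrete, here is a small base-R sketch that computes the per-device shard shape a tensor would get under a layout. The helper `shard_shape()` and the example tensor shapes are hypothetical illustrations, not part of Keras. It assumes each named tensor dimension is split evenly across the corresponding mesh axis, while unnamed dimensions (`NA` here, `NULL` in the layouts above) are kept whole on every device.

```r
# Hypothetical helper (not a Keras function): per-device shard shape of a
# tensor under a layout. `axes` names one mesh axis per tensor dimension;
# NA plays the role of NULL (dimension not sharded).
shard_shape <- function(tensor_shape, axes, mesh_shape) {
  stopifnot(length(tensor_shape) == length(axes))
  vapply(seq_along(tensor_shape), function(i) {
    if (is.na(axes[[i]])) return(tensor_shape[[i]])  # kept whole on each device
    n <- mesh_shape[[axes[[i]]]]                      # devices along that mesh axis
    stopifnot(tensor_shape[[i]] %% n == 0)            # assume an even split
    tensor_shape[[i]] %/% n
  }, numeric(1))
}

# Same axis sizes as `mesh` above: 2 along "data", 4 along "model".
mesh_shape <- c(data = 2, model = 4)

# Like layout_2d: dim 1 split 4 ways ("model"), dim 2 split 2 ways ("data").
shard_shape(c(16, 8), c("model", "data"), mesh_shape)
# returns c(4, 4)

# Like replicated_layout_4d: only the batch dimension is split ("data"),
# so each shard is also replicated across the 4 "model" devices.
shard_shape(c(16, 28, 28, 1), c("data", NA, NA, NA), mesh_shape)
# returns c(8, 28, 28, 1)
```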

Distribution

The Distribution class in Keras serves as a foundational abstract class designed for developing custom distribution strategies. It encapsulates the core logic needed to distribute a model’s variables, input data, and intermediate computations across a device mesh. As an end user, you won’t interact with this class directly; instead, you will use its subclasses, such as DataParallel or ModelParallel.

DataParallel

The DataParallel class in the Keras distribution API is designed for the data parallelism strategy in distributed training, where the model weights are replicated across all devices in the DeviceMesh, and each device processes a portion of the input data.

Here is a sample usage of this class.

# Create DataParallel with a list of devices.
# As a shortcut, the devices can be omitted,
# and Keras will detect all locally available devices.
# E.g. data_parallel <- keras$distribution$DataParallel()
data_parallel <- keras$distribution$DataParallel(devices = devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d <- keras$distribution$DeviceMesh(
  shape = shape(8),
  axis_names = list("data"),
  devices = devices
)
data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d)

inputs <- random_normal(c(128, 28, 28, 1))
labels <- random_normal(c(128, 10))
dataset <- tensor_slices_dataset(list(inputs, labels)) |>
  dataset_batch(16)

# Set the global distribution.
keras$distribution$set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model |> fit()` or
# `model |> evaluate()` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE, activation = "relu") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10, activation = "softmax")

model <- keras_model(inputs = inputs, outputs = outputs)

model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
## Epoch 1/3
## 8/8 - 0s - 38ms/step - loss: 1.0768
## Epoch 2/3
## 8/8 - 0s - 6ms/step - loss: 0.9754
## Epoch 3/3
## 8/8 - 0s - 5ms/step - loss: 0.9347
model |> evaluate(dataset)
## 8/8 - 0s - 7ms/step - loss: 0.8936
## $loss
## [1] 0.8935966
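Conceptually, data parallelism yields the same gradients as single-device training because the gradient of a mean loss over the batch equals the mean of the per-shard gradients when the shards are equal in size. The base-R sketch below checks this for a linear model with an MSE loss. It only illustrates the arithmetic, assumes 8 hypothetical devices receiving 2 examples each, and does not reproduce what Keras or JAX do internally.

```r
# Conceptual sketch in base R (no Keras): data parallelism with replicated
# weights. Each hypothetical device computes the MSE gradient on its own
# slice of the batch; averaging those gradients gives the full-batch gradient.
set.seed(42)
n_devices <- 8
x <- matrix(rnorm(16 * 3), nrow = 16)  # global batch: 16 examples, 3 features
y <- rnorm(16)
w <- c(0.5, -1, 2)                     # the same (replicated) weights on every device

# Gradient of mean((x %*% w - y)^2) with respect to w
mse_grad <- function(x, y, w) {
  as.vector(2 * crossprod(x, x %*% w - y) / nrow(x))
}

# Single-device view: one gradient over the whole batch
full_grad <- mse_grad(x, y, w)

# Data-parallel view: split the batch dimension evenly (2 rows per device)
shards <- split(seq_len(nrow(x)), rep(seq_len(n_devices), each = nrow(x) / n_devices))
per_device <- sapply(shards, function(rows) {
  mse_grad(x[rows, , drop = FALSE], y[rows], w)
})
dp_grad <- rowMeans(per_device)        # mean across devices (an all-reduce)

all.equal(full_grad, dp_grad)
# TRUE
```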

ModelParallel and LayoutMap

ModelParallel is mostly useful when model weights are too large to fit on a single accelerator. This setting allows you to split your model weights or activation tensors across all the devices on the DeviceMesh, enabling horizontal scaling of large models.
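Splitting a weight across devices works because many layer computations decompose along the sharded dimension. As a rough base-R illustration (the shapes are hypothetical and mirror the `d1` layer used in the example below; this is not Keras code), splitting a dense kernel by columns across 4 devices and concatenating the partial outputs reproduces the unsharded result:

```r
# Conceptual base-R sketch (no Keras): shard a dense kernel by columns across
# 4 hypothetical devices along a "model" axis, then rebuild the layer output.
set.seed(7)
x <- matrix(rnorm(16 * 784), nrow = 16)         # a batch of flattened 28x28 inputs
kernel <- matrix(rnorm(784 * 200), nrow = 784)  # full 784 x 200 kernel
n_model <- 4

# Each device keeps only a 784 x 50 slice of the kernel ...
col_groups <- split(seq_len(ncol(kernel)),
                    rep(seq_len(n_model), each = ncol(kernel) / n_model))
shards <- lapply(col_groups, function(cols) kernel[, cols, drop = FALSE])

# ... computes its slice of the activations, and the slices are concatenated.
partial_outputs <- lapply(shards, function(k) x %*% k)
sharded_out <- do.call(cbind, unname(partial_outputs))

all.equal(sharded_out, x %*% kernel)  # TRUE
dim(shards[[1]])                      # 784 50
```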

Unlike the DataParallel model, where all weights are fully replicated, the weight layout under ModelParallel usually needs some customization for best performance. LayoutMap lets you specify the TensorLayout for any weights and intermediate tensors from a global perspective.

LayoutMap is a dict-like object that maps strings to TensorLayout instances. It differs from a normal dict in that the string key is treated as a regex when retrieving a value. This lets you define a naming scheme for layouts and then retrieve the corresponding TensorLayout instance. Typically, the key used for the query is the variable$path attribute, which identifies the variable. As a shortcut, a tuple or list of axis names can also be given when inserting a value, and it will be converted to a TensorLayout.

The LayoutMap can also optionally contain a DeviceMesh, which is used to populate TensorLayout$device_mesh if it is not set. When retrieving a layout by key, if there is no exact match, all existing keys in the layout map are treated as regexes and matched against the input key. If there are multiple matches, a ValueError is raised; if no match is found, NULL is returned.
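The lookup rule described above can be sketched in base R. `lookup_layout()` and the `rules` list are hypothetical stand-ins for a LayoutMap (plain character vectors replace TensorLayout objects, with `NA` standing in for `NULL`), and the sketch assumes search-style (not full-string) regex matching, approximated here with `grepl(perl = TRUE)`.

```r
# Hypothetical sketch of the LayoutMap lookup rule (not the Keras code).
lookup_layout <- function(rules, key) {
  # 1. An exact key match wins.
  if (key %in% names(rules)) return(rules[[key]])
  # 2. Otherwise, treat every stored key as a regex and search the input key.
  matched <- vapply(names(rules), grepl, logical(1), x = key, perl = TRUE)
  hits <- names(rules)[matched]
  if (length(hits) > 1) {
    stop("Multiple layout rules match '", key, "': ", paste(hits, collapse = ", "))
  }
  if (length(hits) == 0) return(NULL)  # no rule: the caller replicates the tensor
  rules[[hits]]
}

rules <- list(
  "d1/kernel" = c(NA, "model"),
  "d1/bias"   = "model",
  "d2/output" = c("data", NA),
  "kernel$"   = c("model", NA)  # a broad, hypothetical catch-all rule
)

lookup_layout(rules, "d1/kernel")   # exact match: c(NA, "model")
lookup_layout(rules, "d2/kernel")   # only "kernel$" matches: c("model", NA)
lookup_layout(rules, "d2/bias")     # no match: NULL
# lookup_layout(rules, "block/d1/kernel") would raise an error:
# both "d1/kernel" and "kernel$" match it.
```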

mesh_2d <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = c("data", "model"),
  devices = devices
)
layout_map <- keras$distribution$LayoutMap(mesh_2d)

# The rules below mean that any weight matching "d1/kernel" will be sharded
# along the "model" mesh dimension (4 devices); the same applies to "d1/bias"
# (d1 below is built with use_bias = FALSE, so that rule matches nothing in
# this model). All other weights will be fully replicated.
layout_map["d1/kernel"] <- tuple(NULL, "model")
layout_map["d1/bias"] <- tuple("model")

# You can also set the layout of a layer's output, e.g.:
layout_map["d2/output"] <- tuple("data", NULL)

model_parallel <- keras$distribution$ModelParallel(
  layout_map = layout_map, batch_dim_name = "data"
)

keras$distribution$set_distribution(model_parallel)

inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE,
              activation = "relu", name = "d1") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10,
              activation = "softmax",
              name = "d2")

model <- keras_model(inputs = inputs, outputs = outputs)

# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
## Epoch 1/3
## 8/8 - 0s - 29ms/step - loss: 1.0836
## Epoch 2/3
## 8/8 - 0s - 4ms/step - loss: 1.0192
## Epoch 3/3
## 8/8 - 0s - 4ms/step - loss: 0.9821
model |> evaluate(dataset)
## 8/8 - 0s - 8ms/step - loss: 0.9576
## $loss
## [1] 0.9576273
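A rough count of what each device stores helps show why this layout saves memory. The arithmetic below is a sketch: it assumes the shapes implied by the model definition (a flattened 28 x 28 x 1 input gives a 784 x 200 kernel for d1, which has no bias), and it ignores optimizer state and activations.

```r
# Sketch: parameters stored per device under the layout_map above.
# Shapes are derived from the model definition, not queried from Keras.
model_axis <- 4                     # size of the "model" axis in mesh_2d
d1_kernel  <- c(28 * 28 * 1, 200)   # flattened input -> 200 units (use_bias = FALSE)
d2_kernel  <- c(200, 10)            # replicated on every device
d2_bias    <- 10                    # replicated on every device

total      <- prod(d1_kernel) + prod(d2_kernel) + d2_bias
per_device <- prod(d1_kernel) / model_axis + prod(d2_kernel) + d2_bias

c(total = total, per_device = per_device)
# total: 158810, per_device: 41210
```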

You can also shift the balance between data parallelism and model parallelism by adjusting the shape of the mesh. Apart from building the LayoutMap with the new mesh, no other code needs to change.

full_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(8, 1),
  axis_names = list("data", "model"),
  devices = devices
)
more_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(4, 2),
  axis_names = list("data", "model"),
  devices = devices
)
more_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)
full_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(1, 8),
  axis_names = list("data", "model"),
  devices = devices
)
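To see what each mesh shape changes, the sketch below tabulates how many examples from a global batch of 16, and how many columns of the 784 x 200 d1 kernel, each device would handle. It assumes the layout_map rules from the previous example are reused with each mesh; the numbers come from simple division, not from measurements.

```r
# Sketch: per-device work for each candidate mesh shape, assuming "d1/kernel"
# stays sharded on "model" and the batch is sharded on "data".
meshes <- list(
  full_data_parallel  = c(data = 8, model = 1),
  more_data_parallel  = c(data = 4, model = 2),
  more_model_parallel = c(data = 2, model = 4),
  full_model_parallel = c(data = 1, model = 8)
)
batch_size <- 16
d1_units   <- 200

tradeoffs <- data.frame(
  mesh = names(meshes),
  examples_per_device = vapply(meshes, function(m) batch_size / m[["data"]], numeric(1)),
  d1_kernel_cols_per_device = vapply(meshes, function(m) d1_units / m[["model"]], numeric(1)),
  row.names = NULL
)
tradeoffs
```

Moving along this table trades smaller per-device batches (more data parallelism) for smaller per-device weight shards (more model parallelism).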