In this example, we’ll build a sequence-to-sequence Transformer model, which we’ll train on an English-to-Spanish machine translation task.

You’ll learn how to:

  • Vectorize text using layer_text_vectorization().
  • Implement a layer_transformer_encoder(), a layer_transformer_decoder(), and a layer_positional_embedding().
  • Prepare data for training a sequence-to-sequence model.
  • Use the trained model to generate translations of never-seen-before input sentences (sequence-to-sequence inference).

The code featured here is adapted from the book Deep Learning with R, Second Edition (chapter 11: Deep learning for text). The present example is fairly barebones, so for detailed explanations of how each building block works, as well as the theory behind Transformers, I recommend reading the book.



library(keras3)
library(tensorflow, exclude = c("set_random_seed", "shape"))
library(magrittr) # provides the %<>% assignment pipe used below

Downloading the data

We’ll be working with an English-to-Spanish translation dataset provided by Anki. Let’s download it:

zipfile <- get_file("", origin =

zip::zip_list(zipfile) # See what's in the zipfile
##             filename compressed_size uncompressed_size           timestamp
## 1           spa-eng/               0                 0 2018-06-13 12:21:20
## 2 spa-eng/_about.txt             685              1441 2018-05-13 19:50:34
## 3    spa-eng/spa.txt         2637619           8042772 2018-05-13 19:50:34
##   permissions    crc32 offset
## 1         755 00000000      0
## 2         644 4b18349d     38
## 3         644 63c5c4a2    771
zip::unzip(zipfile, exdir = ".") # unzip into the current directory

text_file <- fs::path("./spa-eng/spa.txt")

Parsing the data

Each line of spa.txt contains an English sentence and its corresponding Spanish sentence, separated by a tab. The English sentence is the source sequence and the Spanish one is the target sequence. We prepend the token "[start]" to the Spanish sentence and append the token "[end]".

text_file <- "spa-eng/spa.txt"
text_pairs <- text_file %>%
  readr::read_tsv(col_names = c("english", "spanish"),
                  col_types = c("cc")) %>%
  within(spanish %<>% paste("[start]", ., "[end]"))
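As a dependency-free illustration of what this parsing step does (base R only, no readr), the sketch below splits each line on the tab and wraps the Spanish side in the two tokens. The two input lines reuse sentence pairs that appear in the sample output below. parse_pairs() is a hypothetical helper, not part of any package.

```r
# Two lines in the same layout as spa.txt: English, a tab, then Spanish
lines <- c("Take precautions.\tTen cuidado.",
           "I'm staying in Italy.\tMe estoy quedando en Italia.")

parse_pairs <- function(lines) {
  fields <- strsplit(lines, "\t", fixed = TRUE)
  english <- vapply(fields, function(f) f[[1]], character(1))
  spanish <- vapply(fields, function(f) f[[2]], character(1))
  data.frame(english = english,
             # wrap the target sentence in the start and end tokens
             spanish = paste("[start]", spanish, "[end]"),
             stringsAsFactors = FALSE)
}

pairs <- parse_pairs(lines)
pairs$spanish
# "[start] Ten cuidado. [end]", "[start] Me estoy quedando en Italia. [end]"
```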

Here’s what our sentence pairs look like:

df <- text_pairs[sample(nrow(text_pairs), 5), ]
glue::glue_data(df, r"(
  english: {english}
  spanish: {spanish}
)") |> cat(sep = "\n\n")
## english: I'm staying in Italy.
## spanish: [start] Me estoy quedando en Italia. [end]
## english: What's so strange about that?
## spanish: [start] ¿Qué es tan extraño acerca de eso? [end]
## english: All of the buses are full.
## spanish: [start] Todos los bondis están llenos. [end]
## english: Is this where your mother works?
## spanish: [start] ¿Es aquí donde trabaja tu madre? [end]
## english: Take precautions.
## spanish: [start] Ten cuidado. [end]

Now, let’s split the sentence pairs into a training set, a validation set, and a test set.

num_test_samples <- num_val_samples <-
  round(0.15 * nrow(text_pairs))
num_train_samples <- nrow(text_pairs) - num_val_samples - num_test_samples

pair_group <- sample(c(
  rep("train", num_train_samples),
  rep("test", num_test_samples),
  rep("val", num_val_samples)
))

train_pairs <- text_pairs[pair_group == "train", ]
test_pairs <- text_pairs[pair_group == "test", ]
val_pairs <- text_pairs[pair_group == "val", ]

glue::glue(r"(
  {nrow(text_pairs)} total pairs
  {nrow(train_pairs)} training pairs
  {nrow(val_pairs)} validation pairs
  {nrow(test_pairs)} test pairs
)", .transformer = function(text, envir) {
  val <- eval(str2lang(text), envir)
  prettyNum(val, big.mark = ",")
})
## 118,493 total pairs
## 82,945 training pairs
## 17,774 validation pairs
## 17,774 test pairs
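The split logic can be checked on its own with base R. This sketch reuses the same arithmetic (15% each for validation and test, with the remainder going to training) and the same shuffled-label approach. split_counts() and assign_groups() are hypothetical helpers. The fixed seed is an addition for reproducibility; the code above does not set one.

```r
split_counts <- function(n, frac = 0.15) {
  n_val <- round(frac * n)
  n_test <- n_val
  c(train = n - n_val - n_test, val = n_val, test = n_test)
}

# Shuffle a vector of group labels, one label per row index.
# set.seed() is not in the original code; it makes the split repeatable.
assign_groups <- function(n, frac = 0.15, seed = 123) {
  counts <- split_counts(n, frac)
  set.seed(seed)
  sample(rep(names(counts), times = counts))
}

split_counts(118493)
# train = 82945, val = 17774, test = 17774

groups <- assign_groups(1000)
table(groups)
```

Shuffling labels rather than rows keeps the group sizes exact while still assigning rows to groups at random.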

Vectorizing the text data

We’ll use two instances of layer_text_vectorization() to vectorize the text data (one for English and one for Spanish), that is to say, to turn the original strings into integer sequences where each integer represents the index of a word in a vocabulary.

The English layer will use the default string standardization (lowercase the text and strip punctuation characters) and the default splitting scheme (split on whitespace). The Spanish layer will use a custom standardization, which also strips the inverted punctuation marks "¡" and "¿" but keeps the square brackets of the "[start]" and "[end]" tokens.

Note: in a production-grade machine translation model, I would not recommend stripping the punctuation characters in either language. Instead, I would recommend turning each punctuation character into its own token, which you could achieve by providing a custom split function to layer_text_vectorization().
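Below is one possible shape for the punctuation-as-token idea, sketched with base R regular expressions instead of TensorFlow string ops. split_keep_punct() is a hypothetical helper. It treats each punctuation character as its own token and every other run of non-space characters as a word. To use the same rule inside layer_text_vectorization(), you would need to rewrite it with TensorFlow string operations, because the layer applies its split function to tensors. The inverted marks are listed explicitly so the sketch does not depend on whether [:punct:] covers non-ASCII characters in a given R build.

```r
split_keep_punct <- function(x) {
  # a token is an inverted mark, a run of non-space non-punctuation
  # characters, or a single punctuation character
  pattern <- "[\u00a1\u00bf]|[^[:space:][:punct:]\u00a1\u00bf]+|[[:punct:]]"
  regmatches(x, gregexpr(pattern, x, perl = TRUE))
}

split_keep_punct(c("Take precautions.", "I'm staying in Italy."))
# "Take" "precautions" "."
# "I" "'" "m" "staying" "in" "Italy" "."
```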

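punctuation_regex <- "[¡¿]|[^[:^punct:][\\]]"
# The regex explained: match ¡, or ¿, or any punctuation character except [ and ].
# [:^punct:] is a negated POSIX character class.
# [:punct:] matches any punctuation character, so [:^punct:] matches any
# character that is not a punctuation character.
# [^...] negates the whole character class,
# so [^[:^punct:]] would match any character that is a punctuation character.
# Adding [ and ] (the ] escaped as \\] in the R string) to the negated class
# excludes them, so [^[:^punct:][\\]] matches any punctuation character
# except [ and ]. This keeps the "[start]" and "[end]" tokens intact.
# ¡ and ¿ are listed explicitly because [:punct:] in RE2, the regex engine
# behind tf$strings$regex_replace(), only covers ASCII punctuation.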

custom_standardization <- function(input_string) {
  input_string %>%
    tf$strings$lower() %>%
    tf$strings$regex_replace(punctuation_regex, "")
}

input_string <- as_tensor("[start] ¡corre! [end]")
custom_standardization(input_string)
## tf.Tensor(b'[start] corre [end]', shape=(), dtype=string)
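To see what the Spanish standardization does without TensorFlow, you can apply the same pattern with base R's PCRE engine (gsub(perl = TRUE)), which also understands the negated [:^punct:] class. standardize_spa() is a hypothetical stand-in for custom_standardization(). It strips punctuation before lowercasing, which gives the same result here because lowercasing leaves punctuation unchanged. The inverted marks are written as Unicode escapes so the code does not depend on the source file's encoding.

```r
punctuation_regex_pcre <- "[\u00a1\u00bf]|[^[:^punct:][\\]]"

standardize_spa <- function(x) {
  # strip ¡, ¿ and punctuation other than [ and ], then lowercase
  tolower(gsub(punctuation_regex_pcre, "", x, perl = TRUE))
}

standardize_spa("[start] \u00a1Corre! [end]")
# "[start] corre [end]"
```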
vocab_size <- 15000
sequence_length <- 20

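# English: default standardization (lowercase, strip punctuation)
eng_vectorization <- layer_text_vectorization(
  max_tokens = vocab_size,
  output_mode = "int",
  output_sequence_length = sequence_length
)

# Spanish: one extra token, because the target sequence is later
# offset by one step to build the decoder inputs and the targets
spa_vectorization <- layer_text_vectorization(
  max_tokens = vocab_size,
  output_mode = "int",
  output_sequence_length = sequence_length + 1,
  standardize = custom_standardization
)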
adapt(eng_vectorization, train_pairs$english)
adapt(spa_vectorization, train_pairs$spanish)
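For intuition about what adapt() and the "int" output mode produce, here is a simplified base-R model of the same steps. It counts tokens in the training text and keeps the most frequent ones. Index 0 is reserved for padding ("") and index 1 for out-of-vocabulary tokens ("[UNK]"). Each token is then mapped to its index, and the result is padded or truncated to output_sequence_length. The helpers are hypothetical. The alphabetical tie-breaking between equally frequent tokens is an assumption of this sketch, not documented behavior of layer_text_vectorization().

```r
build_vocabulary <- function(token_lists, max_tokens) {
  counts <- table(unlist(token_lists))
  nm <- names(counts)
  freq <- as.vector(counts)
  # most frequent first; ties broken alphabetically in the C locale (assumption)
  ranked <- nm[order(-freq, nm, method = "radix")]
  c("", "[UNK]", head(ranked, max_tokens - 2))  # two reserved entries count toward max_tokens
}

encode <- function(tokens, vocab, output_sequence_length) {
  ids <- match(tokens, vocab) - 1L          # 0-based ids, like the layer's output
  ids[is.na(ids)] <- 1L                     # out-of-vocabulary -> [UNK]
  ids <- head(ids, output_sequence_length)  # truncate long sequences
  c(ids, rep(0L, output_sequence_length - length(ids)))  # pad short ones with 0
}

sentences <- c("[start] ten cuidado [end]",
               "[start] me estoy quedando en italia [end]")
tokens <- strsplit(sentences, " ", fixed = TRUE)
vocab <- build_vocabulary(tokens, max_tokens = 8)
encode(tokens[[1]], vocab, output_sequence_length = 6)
# 3 1 4 2 0 0  ("ten" fell outside the 8-entry vocabulary, so it maps to [UNK])
```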

Next, we’ll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond) using the source sentence and the target words from 1 to N.

As such, the training dataset will yield a tuple (inputs, targets), where:

  • inputs is a dictionary (named list) with the keys (names) encoder_inputs and decoder_inputs. encoder_inputs is the vectorized source sentence and decoder_inputs is the target sentence “so far”, that is to say, the words 1 to N used to predict word N+1 (and beyond) in the target sentence.
  • targets is the target sentence offset by one step: it provides the next words in the target sentence, which is what the model will try to predict.

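With plain integer vectors, the offset works as follows. The sketch assumes a Spanish sequence that is one token longer than the decoder input, which is why spa_vectorization uses sequence_length + 1. make_teacher_forcing_pair() and the token ids are hypothetical.

```r
make_teacher_forcing_pair <- function(spa_ids) {
  n <- length(spa_ids)
  list(decoder_inputs = spa_ids[-n],  # every token except the last: the sentence "so far"
       targets = spa_ids[-1])         # every token except the first: the token to predict next
}

# hypothetical ids for "[start] ten cuidado [end]", padded with 0 to length 6
spa_ids <- c(3L, 12L, 40L, 2L, 0L, 0L)
pair <- make_teacher_forcing_pair(spa_ids)
pair$decoder_inputs  # 3 12 40 2 0
pair$targets         # 12 40 2 0 0
```

At position k, the decoder may look at decoder_inputs[1..k] and is scored on targets[k], which is the next token of the sentence.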
format_pair <- function(pair) {
  # the vectorization layers require batched inputs,
  # so reshape each scalar string tensor to add a batch dim
  pair %<>% lapply(op_expand_dims, 1)

  # vectorize
  eng <- eng_vectorization(pair$english)
  spa <- spa_vectorization(pair$spanish)

  # drop the batch dim
  eng %<>% tf$ensure_shape(shape(1, sequence_length)) %>% op_squeeze(1)
  spa %<>% tf$ensure_shape(shape(1, sequence_length+1)) %>% op_squeeze(1)

  # decoder_inputs: all Spanish tokens but the last;
  # targets: all Spanish tokens but the first (offset by one step)
  inputs <- list(encoder_inputs = eng,
                 decoder_inputs = spa[NA:-2])
  targets <- spa[2:NA]
  list(inputs, targets)
}

batch_size <- 64

library(tfdatasets, exclude = "shape")
make_dataset <- function(pairs) {
  tensor_slices_dataset(pairs) %>%
    dataset_map(format_pair, num_parallel_calls = 4) %>%
    dataset_cache() %>%
    dataset_shuffle(2048) %>%
    dataset_batch(batch_size)
}

train_ds <- make_dataset(train_pairs)
## Warning: Negative numbers are interpreted python-style when subsetting tensorflow tensors.
## See: ?`[.tensorflow.tensor` for details.
## To turn off this warning, set `options(tensorflow.extract.warn_negatives_pythonic = FALSE)`
val_ds <- make_dataset(val_pairs)

Let’s take a quick look at the sequence shapes (we have batches of 64 pairs, and all sequences are 20 steps long):

c(inputs, targets) %<-% iter_next(as_iterator(train_ds))
str(inputs)
str(targets)
## List of 2
##  $ encoder_inputs:<tf.Tensor: shape=(64, 20), dtype=int64, numpy=…>
##  $ decoder_inputs:<tf.Tensor: shape=(64, 20), dtype=int64, numpy=…>
## <tf.Tensor: shape=(64, 20), dtype=int64, numpy=…>

Building the model

Our sequence-to-sequence Transformer consists of a TransformerEncoder and a TransformerDecoder chained together. To make the model aware of word order, we also use a PositionalEmbedding layer.

The source sequence will be passed to the TransformerEncoder, which will produce a new representation of it. This new representation will then be passed to the TransformerDecoder, together with the target sequence so far (target words 1 to N). The TransformerDecoder will then seek to predict the next words in the target sequence (N+1 and beyond).

A key detail that makes this possible is causal masking (see the get_causal_attention_mask() method of the TransformerDecoder). The TransformerDecoder sees the entire target sequence at once, so we must make sure that it only uses information from target tokens 1 to N when predicting token N+1. Otherwise it could use information from the future, and the resulting model could not be used at inference time.
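The causal mask is easiest to see as a small matrix. This base-R sketch builds the same lower-triangular pattern as get_causal_attention_mask(), where entry [i, j] is 1 when j <= i. It then stacks the pattern along a batch dimension the way op_tile() does, and combines it with a padding mask through an elementwise minimum, mirroring op_minimum() in the decoder's call(). All function names here are hypothetical.

```r
causal_mask <- function(n) {
  pos <- seq_len(n)
  # [i, j] is 1 when query position i may attend to key position j (j <= i)
  1L * outer(pos, pos, ">=")
}

tile_over_batch <- function(mask, batch_size) {
  n <- nrow(mask)
  # batch_size copies, arranged as (batch_size, n, n)
  aperm(array(mask, dim = c(n, n, batch_size)), c(3, 1, 2))
}

combine_with_padding <- function(causal, key_padding) {
  # broadcast the per-key padding mask across query rows, then take the minimum
  pmin(causal, matrix(key_padding, nrow(causal), ncol(causal), byrow = TRUE))
}

m <- causal_mask(4)
m                                            # 1s on and below the diagonal
dim(tile_over_batch(m, 2))                   # 2 4 4
combine_with_padding(m, c(1L, 1L, 1L, 0L))   # last key position masked everywhere
```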

layer_transformer_encoder <- Layer(
  classname = "TransformerEncoder",
  initialize = function(embed_dim, dense_dim, num_heads, ...) {
    super$initialize(...)
    self$embed_dim <- embed_dim
    self$dense_dim <- dense_dim
    self$num_heads <- num_heads
    self$attention <-
      layer_multi_head_attention(num_heads = num_heads,
                                 key_dim = embed_dim)

    # project back to embed_dim so the residual sum below is well-defined
    self$dense_proj <- keras_model_sequential() %>%
      layer_dense(dense_dim, activation = "relu") %>%
      layer_dense(embed_dim)

    self$layernorm_1 <- layer_layer_normalization()
    self$layernorm_2 <- layer_layer_normalization()
    self$supports_masking <- TRUE
  },

  call = function(inputs, mask = NULL) {
    # (batch, seq_len) padding mask -> (batch, 1, seq_len) attention mask
    if (!is.null(mask))
      mask <- mask[, NULL, ] |> op_cast("int32")

    inputs %>%
      { self$attention(., ., attention_mask = mask) + . } %>%
      self$layernorm_1() %>%
      { self$dense_proj(.) + . } %>%
      self$layernorm_2()
  },

  get_config = function() {
    config <- super$get_config()
    for(name in c("embed_dim", "num_heads", "dense_dim"))
      config[[name]] <- self[[name]]
    config
  }
)
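To make the data flow of layer_transformer_encoder() concrete, here is a toy forward pass in base R. It keeps the structure of call(): attention plus a residual connection, layer normalization, a feed-forward projection plus a residual connection, and layer normalization again. It simplifies what Keras handles for you. There is a single attention head instead of num_heads, no bias terms, and random weights, and layer normalization runs without its learned scale and offset. The epsilon of 1e-3 matches the Keras default for layer_layer_normalization(). All names are hypothetical.

```r
layer_norm <- function(x, eps = 1e-3) {
  mu <- rowMeans(x)                # one mean per token (row)
  v <- rowMeans((x - mu)^2)        # one (population) variance per token
  (x - mu) / sqrt(v + eps)         # learned scale/offset left at 1 and 0
}

softmax_rows <- function(s) {
  e <- exp(s - apply(s, 1, max))   # subtract each row's max for stability
  e / rowSums(e)
}

# single-head scaled dot-product self-attention with an optional 0/1 mask
self_attention <- function(x, p, mask = NULL) {
  q <- x %*% p$Wq
  k <- x %*% p$Wk
  v <- x %*% p$Wv
  scores <- q %*% t(k) / sqrt(ncol(k))   # (seq_len, seq_len)
  if (!is.null(mask)) scores[mask == 0] <- -1e9
  softmax_rows(scores) %*% v %*% p$Wo
}

encoder_block <- function(x, p, mask = NULL) {
  h <- layer_norm(x + self_attention(x, p, mask))  # attention + residual, then layernorm_1
  ffn <- pmax(h %*% p$W1, 0) %*% p$W2              # relu dense, then linear dense
  layer_norm(h + ffn)                              # feed-forward + residual, then layernorm_2
}

set.seed(1)
rand <- function(nr, nc) matrix(rnorm(nr * nc, sd = 0.1), nr, nc)
toy_embed_dim <- 8
toy_dense_dim <- 32
p <- list(Wq = rand(toy_embed_dim, toy_embed_dim), Wk = rand(toy_embed_dim, toy_embed_dim),
          Wv = rand(toy_embed_dim, toy_embed_dim), Wo = rand(toy_embed_dim, toy_embed_dim),
          W1 = rand(toy_embed_dim, toy_dense_dim), W2 = rand(toy_dense_dim, toy_embed_dim))
x <- rand(5, toy_embed_dim)        # stand-in for 5 embedded tokens
out <- encoder_block(x, p)
dim(out)                           # 5 8: same shape as the input
```

The output has the same shape as the input, which is what allows encoder blocks to be stacked.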

layer_transformer_decoder <- Layer(
  classname = "TransformerDecoder",

  initialize = function(embed_dim, latent_dim, num_heads, ...) {
    super$initialize(...)
    self$embed_dim <- embed_dim
    self$latent_dim <- latent_dim
    self$num_heads <- num_heads
    self$attention_1 <- layer_multi_head_attention(num_heads = num_heads,
                                                   key_dim = embed_dim)
    self$attention_2 <- layer_multi_head_attention(num_heads = num_heads,
                                                   key_dim = embed_dim)
    self$dense_proj <- keras_model_sequential() %>%
      layer_dense(latent_dim, activation = "relu") %>%
      layer_dense(embed_dim)

    self$layernorm_1 <- layer_layer_normalization()
    self$layernorm_2 <- layer_layer_normalization()
    self$layernorm_3 <- layer_layer_normalization()
    self$supports_masking <- TRUE
  },

  get_config = function() {
    config <- super$get_config()
    for (name in c("embed_dim", "num_heads", "latent_dim"))
      config[[name]] <- self[[name]]
    config
  },

  # lower-triangular (batch_size, seq_len, seq_len) mask:
  # position i may only attend to positions j <= i
  get_causal_attention_mask = function(inputs) {
    c(batch_size, sequence_length, encoding_length) %<-% op_shape(inputs)

    x <- op_arange(sequence_length)
    i <- x[, NULL]
    j <- x[NULL, ]
    mask <- op_cast(i >= j, "int32")

    repeats <- op_stack(c(batch_size, 1L, 1L))
    op_tile(mask[NULL, , ], repeats)
  },

  call = function(inputs, encoder_outputs, mask = NULL) {
    causal_mask <- self$get_causal_attention_mask(inputs)

    if (is.null(mask))
      mask <- causal_mask
    else
      mask %<>% { op_minimum(op_cast(.[, NULL, ], "int32"),
                             causal_mask) }

    inputs %>%
      # causal self-attention over the target sequence so far
      { self$attention_1(query = ., value = ., key = .,
                         attention_mask = causal_mask) + . } %>%
      self$layernorm_1() %>%
      # cross-attention: target positions attend to the encoded source
      { self$attention_2(query = .,
                         value = encoder_outputs,
                         key = encoder_outputs,
                         attention_mask = mask) + . } %>%
      self$layernorm_2() %>%
      # position-wise feed-forward projection
      { self$dense_proj(.) + . } %>%
      self$layernorm_3()
  }
)

layer_positional_embedding <- Layer(
  classname = "PositionalEmbedding",

  initialize = function(sequence_length, vocab_size, embed_dim, ...) {
    super$initialize(...)
    self$token_embeddings <- layer_embedding(
      input_dim = vocab_size, output_dim = embed_dim
    )
    self$position_embeddings <- layer_embedding(
      input_dim = sequence_length, output_dim = embed_dim
    )
    self$sequence_length <- sequence_length
    self$vocab_size <- vocab_size
    self$embed_dim <- embed_dim
  },

  call = function(inputs) {
    c(., len) %<-% op_shape(inputs) # (batch_size, seq_len)
    positions <- op_arange(0, len, dtype = "int32")
    embedded_tokens <- self$token_embeddings(inputs)
    embedded_positions <- self$position_embeddings(positions)
    # (seq_len, embed_dim) position vectors broadcast over the batch
    embedded_tokens + embedded_positions
  },

  compute_mask = function(inputs, mask = NULL) {
    if (is.null(mask)) return (NULL)
    inputs != 0L
  },

  get_config = function() {
    config <- super$get_config()
    for(name in c("sequence_length", "vocab_size", "embed_dim"))
      config[[name]] <- self[[name]]
    config
  }
)
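The arithmetic in PositionalEmbedding fits in a few lines of base R: look up one vector per token id and one per position, then add them. The two tables here are random stand-ins for the learned layer_embedding() weight matrices. positional_embed() and compute_padding_mask() are hypothetical helpers, and the padding rule (id 0 means padding) mirrors the inputs != 0L test in compute_mask(). Because the position table has only sequence_length rows, a longer input fails in this sketch with a subscript error.

```r
positional_embed <- function(token_ids, token_table, position_table) {
  # token ids are 0-based (0 = padding) but R rows are 1-based, hence + 1L
  n <- length(token_ids)
  token_table[token_ids + 1L, , drop = FALSE] +
    position_table[seq_len(n), , drop = FALSE]  # positions 0..n-1 -> rows 1..n
}

compute_padding_mask <- function(token_ids) token_ids != 0L

set.seed(2)
toy_vocab_size <- 10
toy_seq_len <- 6
toy_embed_dim <- 4
token_table <- matrix(rnorm(toy_vocab_size * toy_embed_dim), toy_vocab_size, toy_embed_dim)
position_table <- matrix(rnorm(toy_seq_len * toy_embed_dim), toy_seq_len, toy_embed_dim)

ids <- c(3L, 5L, 3L, 0L)   # hypothetical token ids for one short sentence
emb <- positional_embed(ids, token_table, position_table)
dim(emb)                    # 4 4: one embed_dim vector per token
compute_padding_mask(ids)   # TRUE TRUE TRUE FALSE
```

Token 3 appears at positions 1 and 3 but gets two different vectors. This is how the model learns about word order.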

Next, we assemble the end-to-end model.

embed_dim <- 256
latent_dim <- 2048
num_heads <- 8

encoder_inputs <- layer_input(shape(NA), dtype = "int64",
                              name = "encoder_inputs")
encoder_outputs <- encoder_inputs %>%
  layer_positional_embedding(sequence_length, vocab_size, embed_dim) %>%
  layer_transformer_encoder(embed_dim, latent_dim, num_heads)

encoder <- keras_model(encoder_inputs, encoder_outputs)

decoder_inputs <-  layer_input(shape(NA), dtype = "int64",
                               name = "decoder_inputs")
encoded_seq_inputs <- layer_input(shape(NA, embed_dim),
                                  name = "decoder_state_inputs")

transformer_decoder <- layer_transformer_decoder(NULL,
  embed_dim, latent_dim, num_heads)

decoder_outputs <- decoder_inputs %>%
  layer_positional_embedding(sequence_length, vocab_size, embed_dim) %>%
  transformer_decoder(., encoded_seq_inputs) %>%
  layer_dropout(0.5) %>%
  layer_dense(vocab_size, activation="softmax")

decoder <- keras_model(inputs = list(decoder_inputs, encoded_seq_inputs),
                       outputs = decoder_outputs)

decoder_outputs <- decoder(list(decoder_inputs, encoder_outputs))

transformer <- keras_model(list(encoder_inputs, decoder_inputs),
                           decoder_outputs,
                           name = "transformer")
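As a sanity check on the assembled model, the parameter counts can be worked out by hand. This base-R sketch assumes the standard parameterization of the Keras layers used here. Each dense or attention projection has a kernel plus a bias, with key_dim = embed_dim as passed above, and each layer normalization has a scale plus an offset. Dropout has no parameters. The helper functions are hypothetical. The results agree with the model summary printed in the next section.

```r
dense_params <- function(n_in, n_out) n_in * n_out + n_out   # kernel + bias
layernorm_params <- function(d) 2 * d                        # scale + offset
mha_params <- function(d, num_heads, key_dim) {
  # query, key and value projections: kernel (d, heads, key_dim) + bias (heads, key_dim)
  qkv <- 3 * (d * num_heads * key_dim + num_heads * key_dim)
  # output projection: kernel (heads, key_dim, d) + bias (d)
  qkv + num_heads * key_dim * d + d
}

param_counts <- function(vocab_size, sequence_length, embed_dim, latent_dim, num_heads) {
  pos_emb <- vocab_size * embed_dim + sequence_length * embed_dim  # token + position tables
  ffn <- dense_params(embed_dim, latent_dim) + dense_params(latent_dim, embed_dim)
  mha <- mha_params(embed_dim, num_heads, key_dim = embed_dim)
  encoder <- mha + ffn + 2 * layernorm_params(embed_dim)
  decoder <- 2 * mha + ffn + 3 * layernorm_params(embed_dim)
  # the decoder model also contains its own positional embedding and the softmax layer
  decoder_model <- pos_emb + decoder + dense_params(embed_dim, vocab_size)
  c(positional_embedding = pos_emb, transformer_encoder = encoder,
    decoder_model = decoder_model, total = pos_emb + encoder + decoder_model)
}

counts <- param_counts(vocab_size = 15000, sequence_length = 20,
                       embed_dim = 256, latent_dim = 2048, num_heads = 8)
# positional_embedding 3,845,120; transformer_encoder 3,155,456;
# decoder_model 12,959,640; total 19,960,216
counts[["total"]] * 4 / 1024^2   # float32 storage: about 76.14 (MiB)
```

Most of the parameters sit in the two vocabulary-sized matrices: the token embedding tables and the final softmax layer.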

Training our model

We’ll use accuracy as a quick way to monitor training progress on the validation data. Note that this accuracy is measured per predicted token; machine translation is typically evaluated with BLEU scores and other metrics rather than with accuracy.

Here we only train for 1 epoch, but to get the model to actually converge you should train for at least 30 epochs.

epochs <- 1  # This should be at least 30 for convergence

summary(transformer)

## Model: "transformer"
## ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
## ┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
## ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
## │ encoder_inputs      │ (None, None)      │          0 │ -                 │
## │ (InputLayer)        │                   │            │                   │
## ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
## │ positional_embeddi… │ (None, None, 256) │  3,845,120 │ encoder_inputs[0… │
## │ (PositionalEmbeddi… │                   │            │                   │
## ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
## │ decoder_inputs      │ (None, None)      │          0 │ -                 │
## │ (InputLayer)        │                   │            │                   │
## ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
## │ transformer_encoder │ (None, None, 256) │  3,155,456 │ positional_embed… │
## │ (TransformerEncode… │                   │            │                   │
## ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
## │ functional_3        │ (None, None,      │ 12,959,640 │ decoder_inputs[0… │
## │ (Functional)        │ 15000)            │            │ transformer_enco… │
## └─────────────────────┴───────────────────┴────────────┴───────────────────┘
##  Total params: 19,960,216 (76.14 MB)
##  Trainable params: 19,960,216 (76.14 MB)
##  Non-trainable params: 0 (0.00 B)
transformer |> compile(
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)
transformer |> fit(train_ds, epochs = epochs,
                   validation_data = val_ds)
## 1297/1297 - 48s - 37ms/step - accuracy: 0.7463 - loss: 1.8009 - val_accuracy: 0.7657 - val_loss: 1.5766

Decoding test sentences

Finally, let’s demonstrate how to translate brand new English sentences. We feed the model the vectorized English sentence together with the target token "[start]", then repeatedly generate the next token until we hit the token "[end]".
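Without the TensorFlow details, the decoding procedure is a greedy loop. At each step it scores every vocabulary entry, keeps the best one, and appends it. It stops at "[end]" or after a maximum number of steps. In this base-R sketch, toy_model() is a hypothetical stand-in for the trained transformer that looks only at the last decoded token. The real loop below instead re-vectorizes the whole partial translation at every step and reads the prediction at position i.

```r
greedy_decode <- function(next_token_scores, vocab, max_len = 20) {
  decoded <- "[start]"
  for (step in seq_len(max_len)) {
    scores <- next_token_scores(decoded)   # one score per vocabulary entry
    token <- vocab[which.max(scores)]      # greedy choice: highest score
    decoded <- c(decoded, token)
    if (token == "[end]") break            # stop once the end token appears
  }
  paste(decoded, collapse = " ")
}

vocab <- c("", "[UNK]", "[start]", "[end]", "corre", "ya")

# hypothetical stand-in for the trained model: a fixed lookup table that
# scores the next token from the last decoded token only
toy_model <- function(decoded) {
  next_token <- c("[start]" = "corre", "corre" = "ya", "ya" = "[end]")[[tail(decoded, 1)]]
  as.numeric(vocab == next_token)
}

greedy_decode(toy_model, vocab)
# "[start] corre ya [end]"
```

The max_len cap matters. A model that never emits "[end]" would otherwise loop forever, which is why the real function also bounds the loop with max_decoded_sentence_length.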

spa_vocab <- spa_vectorization |> get_vocabulary()
max_decoded_sentence_length <- 20
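tf_decode_sequence <- tf_function(function(input_sentence) {
  # use Python-style (0-based, end-exclusive) tensor indexing in this function
  withr::local_options(tensorflow.extract.style = "python")

  tokenized_input_sentence <- input_sentence %>%
    as_tensor(shape = c(1, 1)) %>%
    eng_vectorization()
  spa_vocab <- as_tensor(spa_vocab)
  decoded_sentence <- as_tensor("[start]", shape = c(1, 1))

  for (i in tf$range(as.integer(max_decoded_sentence_length))) {

    # vectorize the translation so far, dropping the extra last position
    tokenized_target_sentence <-
      spa_vectorization(decoded_sentence)[, NA:-1]

    next_token_predictions <-
      transformer(list(tokenized_input_sentence,
                       tokenized_target_sentence))

    # greedy decoding: pick the most likely token at position i
    sampled_token_index <- tf$argmax(next_token_predictions[0, i, ])
    sampled_token <- spa_vocab[sampled_token_index]
    decoded_sentence <-
      tf$strings$join(c(decoded_sentence, sampled_token),
                      separator = " ")

    if (sampled_token == "[end]")
      break
  }

  decoded_sentence
})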

for (i in seq(20)) {
  c(input_sentence, correct_translation) %<-%
    test_pairs[sample(nrow(test_pairs), 1), ]
  cat("English:", input_sentence, "\n")
  cat("Correct Translation:", tolower(correct_translation), "\n")
  cat("  Model Translation:", input_sentence %>% as_tensor() %>%
        tf_decode_sequence() %>% as.character(), "\n")
}

After 30 epochs, we get results such as:

English: I'm sure everything will be fine.
Correct Translation: [start] estoy segura de que todo irá bien. [end]
  Model Translation: [start] estoy seguro de que todo va bien [end]