Japanese-English Machine Translation Model with Transformer & PyTorch
A tutorial using Jupyter Notebook, PyTorch, Torchtext, and SentencePiece
This article shows how to train a Japanese to English machine translation model using a Sequence to Sequence Transformer model. This tutorial is heavily based on a tutorial by PyTorch, with some modifications added for using a custom dataset (JParaCrawl), handling Japanese text with the SentencePiece tokenizer, and saving the model.
If you want to know more about Transformers and Sequence to Sequence Machine Learning, I personally recommend this Medium article, or you can also read the paper.
Import required packages
First, let's make sure the packages below are installed on our system; if any are missing, install them before continuing. Here are some links that might be useful for installation: SentencePiece, Torch, and TorchText. (Note: if you are using Google Colab, you will most likely only need to install SentencePiece as an additional package.)
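For example, in a notebook or Colab cell, the missing packages can be installed like this (a minimal sketch, assuming pip and the standard PyPI package names; drop the leading "!" if you run it in a terminal instead):
# Install the additional packages used in this tutorial (run once per environment).
!pip install sentencepiece torch torchtext pandas numpy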
import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import sentencepiece as spm

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
Get the parallel dataset
In this tutorial, we will use the Japanese-English parallel dataset downloaded from JParaCrawl which is described as the “largest publicly available English-Japanese parallel corpus created by NTT. It was created by largely crawling the web and automatically aligning parallel sentences.” You can also see the paper here.
df = pd.read_csv('en-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)

trainen = df[2].values.tolist()
trainja = df[3].values.tolist()

trainen.pop(5973072)
trainja.pop(5973072)
After importing all the Japanese sentences and their English counterparts, I removed the last entry in the dataset because it has a missing value. This leaves 5,973,071 sentences in both trainen and trainja. For learning purposes, however, it is often a good idea to first sample the data and confirm that everything works as intended before training on all of it at once, to save time.
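If you want to do that, a quick way is to slice the two lists before the later steps (a small sketch; the subset size here is an arbitrary choice, not something the tutorial requires):
# Optional: work with a smaller subset first to verify the whole pipeline.
SAMPLE_SIZE = 100_000  # arbitrary; adjust to your machine and patience
trainen = trainen[:SAMPLE_SIZE]
trainja = trainja[:SAMPLE_SIZE]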
Here is an example of a sentence pair contained in the dataset.
print(trainen[100005])
print(trainja[100005])
All residents aged 20 to 59 years who live in Japan must enroll in public pension system.
年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。
We can also use a different parallel dataset to follow along with this article; just make sure that we can process the data into two lists of strings as shown above, one containing the Japanese sentences and the other containing the English sentences.
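As a minimal sketch, assuming a corpus distributed as two aligned plain-text files with one sentence per line (the filenames corpus.ja and corpus.en below are hypothetical), the two lists could be built like this:
# Hypothetical aligned files: line i of corpus.ja translates line i of corpus.en.
with io.open('corpus.ja', encoding='utf-8') as f:
    trainja = [line.rstrip('\n') for line in f]
with io.open('corpus.en', encoding='utf-8') as f:
    trainen = [line.rstrip('\n') for line in f]
assert len(trainja) == len(trainen)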
Prepare the tokenizers
Unlike English and other alphabetical languages, a Japanese sentence does not contain whitespace to separate the words. We can use the tokenizers provided by JParaCrawl, which were created with SentencePiece for both Japanese and English; you can visit the JParaCrawl website to download them, or click here.
en_tokenizer = spm.SentencePieceProcessor(model_file='spm.en.nopretok.model')
ja_tokenizer = spm.SentencePieceProcessor(model_file='spm.ja.nopretok.model')
After the tokenizers are loaded, you can test them, for example, by executing the code below.
en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type='str')
['▁All', '▁residents', '▁aged', '▁20', '▁to', '▁59', '▁years', '▁who', '▁live', '▁in', '▁Japan', '▁must', '▁enroll', '▁in', '▁public', '▁pension', '▁system', '.']
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')
['▁', '年', '金', '▁日本', 'に住んでいる', '20', '歳', '~', '60', '歳の', '全ての', '人は', '、', '公的', '年', '金', '制度', 'に', '加入', 'しなければなりません', '。']
Build the TorchText Vocab objects and convert the sentences into Torch tensors
Using the tokenizers and the raw sentences, we then build the Vocab object imported from TorchText. This process can take a few seconds or minutes, depending on the size of our dataset and our computing power. The choice of tokenizer can also affect the time needed to build the vocab; I tried several other tokenizers for Japanese, but SentencePiece seems to work well and fast enough for me.
def build_vocab(sentences, tokenizer):
    counter = Counter()
    for sentence in sentences:
        counter.update(tokenizer.encode(sentence, out_type=str))
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

ja_vocab = build_vocab(trainja, ja_tokenizer)
en_vocab = build_vocab(trainen, en_tokenizer)
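As a quick optional sanity check (using the same legacy torchtext Vocab interface the rest of the tutorial relies on), you can inspect the vocabulary sizes and the indices of the special tokens:
# Optional sanity check on the vocabularies.
print(len(ja_vocab), len(en_vocab))  # vocabulary sizes
print(ja_vocab['<pad>'], ja_vocab['<bos>'], ja_vocab['<eos>'])  # special token indices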
After we have the vocabulary objects, we can then use the vocab and the tokenizer objects to build the tensors for our training data.
def data_process(ja, en):
    data = []
    for (raw_ja, raw_en) in zip(ja, en):
        ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
                                  dtype=torch.long)
        data.append((ja_tensor_, en_tensor_))
    return data

train_data = data_process(trainja, trainen)
Create the DataLoader object to be iterated during training
Here, I set BATCH_SIZE to 16 to prevent "CUDA out of memory" errors, but this depends on various things such as your machine's memory capacity, the size of the data, and so on, so feel free to change the batch size according to your needs. (Note: the PyTorch tutorial sets the batch size to 128 using the Multi30k German-English dataset.)
BATCH_SIZE = 16
PAD_IDX = ja_vocab['<pad>']
BOS_IDX = ja_vocab['<bos>']
EOS_IDX = ja_vocab['<eos>']

def generate_batch(data_batch):
    ja_batch, en_batch = [], []
    for (ja_item, en_item) in data_batch:
        ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return ja_batch, en_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
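To confirm that batching works as expected, you can peek at one batch (an optional sketch; since pad_sequence is used with its default settings, the tensors have shape (sequence length, batch size)):
# Optional: inspect the shape of one batch -> (seq_len, batch_size).
src_sample, tgt_sample = next(iter(train_iter))
print(src_sample.shape, tgt_sample.shape)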
Sequence-to-sequence Transformer
The next few code blocks and text explanations (written in italics) are taken from the original PyTorch tutorial. I did not make any changes except for the BATCH_SIZE and renaming de_vocab to ja_vocab.
Transformer is a Seq2Seq model introduced in the “Attention is all you need” paper for solving machine translation tasks. The Transformer model consists of an encoder and a decoder block, each containing a fixed number of layers.
The encoder processes the input sequence by propagating it through a series of Multi-head Attention and Feed-forward network layers. The output from the encoder, referred to as memory, is fed to the decoder along with the target tensors. The encoder and decoder are trained in an end-to-end fashion using the teacher forcing technique.
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer_encoder(self.positional_encoding(
            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
            self.tgt_tok_emb(tgt)), memory,
            tgt_mask)
Text tokens are represented by using token embeddings. Positional encoding is added to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0), :])
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
We create a subsequent word mask to stop a target word from attending to its subsequent words. We also create masks for masking the source and target padding tokens.
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
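As a quick illustration (not part of the original tutorial), the subsequent-word mask for a length-4 target is zero on and below the diagonal and -inf above it, so position i can only attend to positions up to i:
# Optional: print the causal mask for a short target sequence.
print(generate_square_subsequent_mask(4))
# Expected values (device annotation omitted):
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]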
Define model parameters and instantiate model.
SRC_VOCAB_SIZE = len(ja_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 16

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(device)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)


def train_epoch(model, train_iter, optimizer):
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)
def evaluate(model, val_iter):
    model.eval()
    losses = 0
    for idx, (src, tgt) in enumerate(val_iter):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(val_iter)
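Note that evaluate expects a validation DataLoader, which this tutorial never builds because we train on the full JParaCrawl data. If you want to use it, a minimal sketch for a held-out split could look like this (the hold-out size is arbitrary; for a clean split, remove these pairs from train_data before building train_iter):
# Optional: build a small validation iterator from a held-out slice of the pairs.
val_data = train_data[-10000:]   # arbitrary hold-out size
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
# Usage: val_loss = evaluate(transformer, valid_iter)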
Start training
Finally, after preparing the necessary classes and functions, we are ready to train our model. It goes without saying that the time needed to finish training can vary greatly depending on many things, such as computing power, hyperparameters, and the size of the dataset.
When I trained the model using the complete list of sentences from JParaCrawl which has around 5.9 million sentences for each language, it took around 5 hours per epoch using a single NVIDIA GeForce RTX 3070 GPU.
Here is the code:
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer)
    end_time = time.time()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
           f"Epoch time = {(end_time - start_time):.3f}s"))
When running this script, every time an epoch finishes it prints the epoch number, the training loss, and the time needed to finish that epoch, something like this:
Epoch: 1, Train loss: 5.xxx, Epoch time = xxxs
Epoch: 2, Train loss: 4.xxx, Epoch time = xxxs
Epoch: 3, Train loss: 4.xxx, Epoch time = xxxs
...
Epoch: 15, Train loss: 2.xxx, Epoch time = xxxs
Epoch: 16, Train loss: 2.xxx, Epoch time = xxxs
Try translating a Japanese sentence using the trained model
First, we create the functions to translate a new sentence. The steps include getting the Japanese sentence, tokenizing it, converting it to tensors, running inference, and then decoding the result back into a sentence, this time in English.
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len - 1):
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    model.eval()
    tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)] + [EOS_IDX]
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1))
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
Then, we can just call the translate function and pass the required parameters.
translate(transformer, "嘘は続くでしょう。", ja_vocab, en_vocab, ja_tokenizer)
Output: ' ▁The ▁lies ▁will ▁continue . '
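The output still contains SentencePiece's "▁" word-boundary markers. If you prefer plain text, a small post-processing step (a sketch, not part of the original code) can remove them:
# Optional: strip SentencePiece markers from the translated string.
raw = translate(transformer, "嘘は続くでしょう。", ja_vocab, en_vocab, ja_tokenizer)
print(raw.replace(" ", "").replace("▁", " ").strip())
# Should print something like: The lies will continue.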
Save the Vocab objects and trained model
Finally, after the training has finished, we will save the Vocab objects (en_vocab and ja_vocab) first, using Pickle.
import pickle

# open a file, where you want to store the data
file = open('en_vocab.pkl', 'wb')
# dump information to that file
pickle.dump(en_vocab, file)
file.close()

file = open('ja_vocab.pkl', 'wb')
pickle.dump(ja_vocab, file)
file.close()
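Later, the saved vocabularies can be loaded back with Pickle in the same way (a minimal sketch):
# Load the saved Vocab objects back from disk.
with open('en_vocab.pkl', 'rb') as f:
    en_vocab = pickle.load(f)
with open('ja_vocab.pkl', 'rb') as f:
    ja_vocab = pickle.load(f)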
Lastly, we can also save the model for later use with PyTorch's save and load functions. Generally, there are two ways to save the model, depending on what we want to use it for later.
The first one is for inference only: we can load the model later and use it to translate from Japanese to English.
# save model for inference
torch.save(transformer.state_dict(), 'inference_model')
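To use the saved weights later, recreate the model with the same hyperparameters and load the state dict (a sketch of the standard PyTorch loading pattern; the file name matches the one saved above):
# Reload the weights for inference.
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                           EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                           FFN_HID_DIM)
model.load_state_dict(torch.load('inference_model', map_location=device))
model = model.to(device)
model.eval()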
The second one is also for inference, but it additionally lets us load the model later and resume training.
# save model + checkpoint to resume training later
torch.save({
    'epoch': NUM_EPOCHS,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}, 'model_checkpoint.tar')
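To resume training from this checkpoint, the usual PyTorch pattern is to restore both state dicts into a freshly created model and optimizer (a sketch, assuming they were instantiated exactly as above):
# Restore the checkpoint and continue training from where we stopped.
checkpoint = torch.load('model_checkpoint.tar', map_location=device)
transformer.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
transformer.train()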
Conclusion
That’s it!
I hope that after reaching the end of this article, you are able to train a Transformer model for Japanese-English neural machine translation using PyTorch. If something is missing or unclear, or if you know a better way to do any of this, please let me know in the comment section. And of course, any constructive feedback is welcome!