
Text Generation using FNet

Transformer-based models excel at understanding and processing sequences because of a mechanism known as self-attention, in which each token is compared with every other token in the sequence to determine how they relate. Despite its effectiveness, self-attention is computationally expensive: for a sequence of length N it requires on the order of N^2 operations, i.e., quadratic scaling. This becomes costly and slow for long inputs and imposes limits on sequence length, such as the 512-token constraint in the standard BERT model.
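For intuition, here is a minimal sketch (illustrative only, with arbitrary dimensions) of where the quadratic term comes from: the attention score matrix holds one score per pair of tokens, so its size grows as N^2.

import torch

N, d = 512, 64                  # sequence length and per-head dimension (arbitrary)
Q = torch.randn(N, d)           # queries
K = torch.randn(N, d)           # keys
scores = Q @ K.transpose(0, 1)  # one score for every (query, key) pair
print(scores.shape)             # torch.Size([512, 512]) -> N^2 entries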

Numerous methods have emerged to address the computational inefficiency of quadratic scaling. A recent innovation tackling this challenge is FNet, which replaces the self-attention layer entirely. FNet introduces an alternative mixing mechanism, diverging from the traditional self-attention paradigm while aiming for comparable or better performance on sequence tasks. In this article, we focus on implementing the FNet architecture for text generation in Python using PyTorch.



FNet

The Transformer architecture is renowned for its dominance in natural language processing (NLP). It uses a core component, the attention mechanism, which connects input tokens by weighing their relevance to each other. While various studies have probed the Transformer and its attention sublayers, the computational cost of self-attention remains a challenge, particularly for long sequences.

In response to this challenge, a recent innovation, FNet, introduces a novel approach by replacing the self-attention layer entirely. Instead of self-attention, FNet utilizes simpler token mixing mechanisms, such as parameterized matrix multiplications and, remarkably, the Fourier transform. Unlike traditional self-attention, the Fourier transform has no parameters yet achieves comparable performance, scaling efficiently to long sequences due to the Fast Fourier transform (FFT) algorithm.
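As a quick illustration of this idea (a standalone sketch with arbitrary shapes, separate from the full model built later in this article), the mixing step amounts to a 2D FFT over the sequence and hidden dimensions, keeping only the real part:

import torch

x = torch.randn(2, 128, 256)    # (batch, sequence length, hidden size), arbitrary
mixed = torch.fft.fft2(x).real  # FFT over the last two dimensions, real part only
print(mixed.shape)              # torch.Size([2, 128, 256]) -- same shape, no learned weights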



Text Generation using FNet

Step 1: Install and Import Libraries

Install the libraries below if they are not already available in your environment:

!pip install datasets
!pip install torch transformers

Declare a device variable so that computation runs on the GPU if one is available:




import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

Output:

cuda

Step 2: Load the Data

Here we will use the WikiText-2 corpus as our training data. We load it with the datasets library.




from datasets import load_dataset
datasets = load_dataset('wikitext','wikitext-2-raw-v1')

Step 3: Data Preprocessing

We will clean up the text: lowercase it, keep only letters and basic punctuation, and drop very short lines.




import re


def preprocess_text(sentence):
    # lowercase the sentence and store it in the text variable
    text = sentence['text'].lower()
    # keep only lowercase letters and basic punctuation
    text = re.sub(r'[^a-z?!.,]', ' ', text)
    # collapse repeated whitespace into a single space
    text = re.sub(r'\s\s+', ' ', text)
    sentence['text'] = text
    return sentence


datasets['train'] = datasets['train'].map(preprocess_text)
datasets['test'] = datasets['test'].map(preprocess_text)
datasets['validation'] = datasets['validation'].map(preprocess_text)

# keep only examples longer than 20 characters
datasets['train'] = datasets['train'].filter(lambda x: len(x['text']) > 20)
datasets['test'] = datasets['test'].filter(lambda x: len(x['text']) > 20)
datasets['validation'] = datasets['validation'].filter(
    lambda x: len(x['text']) > 20)

Step 4: Tokenization

We tokenize the text with a pretrained tokenizer and build a DataLoader that pads each batch to a common length.




from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer
 
# we only need this checkpoint's tokenizer and vocabulary
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 
 
# Tokenizer
def tokenize(sentence):
    sentence = tokenizer(sentence['text'], truncation=True)
    return sentence
 
 
# tokenize the test split and drop the raw text column
tokenized_inputs = datasets['test'].map(tokenize)
tokenized_inputs = tokenized_inputs.remove_columns(['text'])
 
 
# DataCollator
batch = 16
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding=True, return_tensors="pt")
dataloader = DataLoader(
    tokenized_inputs, batch_size=batch, collate_fn=data_collator)

Step 5: Embedding and Positional Encoding

We combine learned token embeddings with fixed sinusoidal positional encodings; both are defined below.




import torch
import torch.nn as nn
import torch.fft as fft
import numpy as np
 
class PositionalEncoding(torch.nn.Module):
 
 
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.d_model = d_model
        self.max_sequence_length = max_sequence_length
        self.positional_encoding = self.create_positional_encoding().to(device)
 
    def create_positional_encoding(self):
 
        # Initialize positional encoding matrix
        positional_encoding = np.zeros((self.max_sequence_length, self.d_model))
 
        # Calculate positional encoding for each position and each dimension
        for pos in range(self.max_sequence_length):
            for i in range(0, self.d_model, 2):
                # Apply sin to even dimensions; i is already the even dimension index here.
                positional_encoding[pos, i] = np.sin(pos / (10000 ** (i / self.d_model)))

                if i + 1 < self.d_model:
                    # Apply cos to the following odd dimension, using the same frequency.
                    positional_encoding[pos, i + 1] = np.cos(pos / (10000 ** (i / self.d_model)))
 
        # Convert numpy array to PyTorch tensor and return it
        return torch.from_numpy(positional_encoding).float()
 
    def forward(self, x):
        expanded_tensor = torch.unsqueeze(self.positional_encoding, 0).expand(x.size(0), -1, -1).to(device)
 
        return x.to(device) + expanded_tensor[:,:x.size(1), :]
 
class PositionalEmbedding(nn.Module):
  def __init__(self, sequence_length, vocab_size, embed_dim):
    super(PositionalEmbedding, self).__init__()
    self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
    self.position_embeddings = PositionalEncoding(embed_dim,sequence_length)
 
  def forward(self, inputs):
    embedded_tokens = self.token_embeddings(inputs).to(device)
    embedded_positions = self.position_embeddings(embedded_tokens).to(device)
    return embedded_positions.to(device)

Step 6: Create the FNet Encoder

The encoder block replaces self-attention with a two-dimensional Fourier transform over the sequence and hidden dimensions, keeping only the real part of the result.




class FNetEncoder(nn.Module):
 
  def __init__(self,embed_dim, dense_dim):
    super(FNetEncoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.dense_proj = nn.Sequential(nn.Linear(self.embed_dim,self.dense_dim), nn.ReLU(), nn.Linear(self.dense_dim,self.embed_dim))
 
    self.layernorm_1 = nn.LayerNorm(self.embed_dim)
    self.layernorm_2 = nn.LayerNorm(self.embed_dim)
 
  def forward(self, inputs):
    # parameter-free token mixing: 2D FFT over the sequence and hidden dimensions
    fft_result = fft.fft2(inputs)

    # keep only the real part of the complex output
    fft_real = fft_result.real.float()

    proj_input = self.layernorm_1(inputs + fft_real)
    proj_output = self.dense_proj(proj_input)
    return self.layernorm_2(proj_input + proj_output)

Step 7: Create the FNet Decoder

The decoder uses standard masked self-attention over its own inputs and cross-attention over the encoder output, followed by a feed-forward projection.




class FNetDecoder(nn.Module):
 
  def __init__(self,embed_dim,dense_dim,num_heads):
    super(FNetDecoder,self).__init__()
    self.embed_dim = embed_dim
    self.dense_dim = dense_dim
    self.num_heads = num_heads
 
    self.attention_1 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)
    self.attention_2 = nn.MultiheadAttention(embed_dim,num_heads,batch_first=True)
 
    self.dense_proj = nn.Sequential(nn.Linear(embed_dim, dense_dim),nn.ReLU(),nn.Linear(dense_dim, embed_dim))
 
    self.layernorm_1 = nn.LayerNorm(embed_dim)
    self.layernorm_2 = nn.LayerNorm(embed_dim)
    self.layernorm_3 = nn.LayerNorm(embed_dim)
 
  def forward(self, inputs, encoder_outputs, mask=None):
    # causal mask so that each position attends only to earlier positions
    causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(1)).to(device)

    attention_output_1, _ = self.attention_1(inputs, inputs, inputs, attn_mask=causal_mask)
    out_1 = self.layernorm_1(inputs + attention_output_1)

    if mask is not None:
      # key_padding_mask marks positions to ignore with True, so invert the attention mask
      attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs, key_padding_mask=~mask)
    else:
      attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs)
    out_2 = self.layernorm_2(out_1 + attention_output_2)

    proj_output = self.dense_proj(out_2)
    return self.layernorm_3(out_2 + proj_output)

Step 8: FNet Model

We stack four encoder and four decoder blocks and project the decoder output onto the vocabulary.




class FNetModel(nn.Module):
    def __init__(self, max_length, vocab_size, embed_dim, latent_dim, num_heads):
        super(FNetModel, self).__init__()
 
        self.encoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.encoder1 = FNetEncoder(embed_dim, latent_dim)
        self.encoder2 = FNetEncoder(embed_dim, latent_dim)
        self.encoder3 = FNetEncoder(embed_dim, latent_dim)
        self.encoder4 = FNetEncoder(embed_dim, latent_dim)
 
 
        self.decoder_inputs = PositionalEmbedding(max_length,vocab_size, embed_dim)
        self.decoder1 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder2 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder3 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder4 = FNetDecoder(embed_dim, latent_dim, num_heads)
 
 
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(embed_dim, vocab_size)
 
    def encoder(self,encoder_inputs):
        x_encoder = self.encoder_inputs(encoder_inputs)
        x_encoder = self.encoder1(x_encoder)
        x_encoder = self.encoder2(x_encoder)
        x_encoder = self.encoder3(x_encoder)
        x_encoder = self.encoder4(x_encoder)
        return x_encoder
 
    def decoder(self, decoder_inputs, encoder_output, att_mask):
        x_decoder = self.decoder_inputs(decoder_inputs)
        x_decoder = self.decoder1(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder2(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder3(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder4(x_decoder, encoder_output, att_mask)
        decoder_outputs = self.dense(x_decoder)

        return decoder_outputs

    def forward(self, encoder_inputs, decoder_inputs, att_mask=None):
        encoder_output = self.encoder(encoder_inputs)
        # pass the padding mask through to the decoder's cross-attention
        decoder_output = self.decoder(decoder_inputs, encoder_output, att_mask)
        return decoder_output

Step 9: Initialize the Model

We declare the hyperparameters and initialize the model:




# Hyperparameters
MAX_LENGTH = 512
VOCAB_SIZE = len(tokenizer.vocab)
EMBED_DIM = 256
LATENT_DIM = 100
NUM_HEADS = 4
 
# Create an instance of the model
fnet_model = FNetModel(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM, LATENT_DIM, NUM_HEADS).to(device)
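As an optional sanity check (not part of the original walkthrough, using dummy token ids), we can pass a small random batch through the untrained model and confirm that it returns one logit per vocabulary entry for every position:

# dummy (batch, seq_len) tensors of token ids, purely for a shape check
dummy_src = torch.randint(0, VOCAB_SIZE, (2, 10)).to(device)
dummy_tgt = torch.randint(0, VOCAB_SIZE, (2, 10)).to(device)
with torch.no_grad():
    logits = fnet_model(dummy_src, dummy_tgt)
print(logits.shape)  # torch.Size([2, 10, VOCAB_SIZE])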

Step 10: Train the Model

We train with the Adam optimizer and cross-entropy loss, ignoring padded positions (token id 0).




# Define the optimizer and loss function; token id 0 ([PAD]) is ignored by the loss
optimizer = torch.optim.Adam(fnet_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

epochs = 100
for epoch in range(epochs):
    train_loss = 0
    for batch in dataloader:
        # the decoder input is the same sequence shifted one token to the right
        encoder_inputs_tensor = batch['input_ids'][:, :-1].to(device)
        decoder_inputs_tensor = batch['input_ids'][:, 1:].to(device)

        att_mask = batch['attention_mask'][:, :-1].to(device).to(dtype=bool)
        optimizer.zero_grad()
        outputs = fnet_model(encoder_inputs_tensor, decoder_inputs_tensor, att_mask)

        # set padded target positions to 0 so that CrossEntropyLoss ignores them
        targets = decoder_inputs_tensor.masked_fill(
            batch['attention_mask'][:, 1:].ne(1).to(device), 0)

        loss = criterion(outputs.view(-1, VOCAB_SIZE), targets.view(-1))
        train_loss = train_loss + loss.item()
        loss.backward()
        optimizer.step()
    print(f" epoch: {epoch}, train_loss : {train_loss}")

Output:

 epoch: 0, train_loss : 13.495175334392115
epoch: 1, train_loss : 0.9018354846921284
epoch: 2, train_loss : 0.3800733484386001
epoch: 3, train_loss : 0.626482578649302
epoch: 4, train_loss : 460.4480260747587

Step 11: Use the Model for Text Generation

To generate text with a Transformer-style decoder we use autoregressive decoding: at each step we pick the most probable next token (greedy decoding via argmax over the model's output distribution) and feed it back as input for the following step. The encoder part of the model produces the context representation for the given input sentence.




MAX_LENGTH = 100  # maximum number of tokens to generate

def decode_sentence(input_sentence, fnet_model):
    fnet_model.eval()

    with torch.no_grad():
        tokenized_input_sentence = torch.tensor(
            tokenizer(preprocess_text(input_sentence)['text'])['input_ids']).to(device)
        tokenized_target_sentence = torch.tensor([101]).to(device)  # start with the '[CLS]' token
        current_text = preprocess_text(input_sentence)['text']
        for i in range(MAX_LENGTH):
            predictions = fnet_model(tokenized_input_sentence[:-1].unsqueeze(0),
                                     tokenized_target_sentence.unsqueeze(0))
            # greedy decoding: pick the most probable next token
            predicted_index = torch.argmax(predictions[0, -1, :]).item()
            predicted_token = tokenizer.decode(predicted_index)
            if predicted_token == "[SEP]":  # '[SEP]' marks the end of the sequence
                break
            current_text += " " + predicted_token
            tokenized_target_sentence = torch.cat(
                [tokenized_target_sentence, torch.tensor([predicted_index]).to(device)], 0)
            tokenized_input_sentence = torch.tensor(tokenizer(current_text)['input_ids']).to(device)
        return current_text

decode_sentence({'text': 'How are you ?'}, fnet_model)

Output:

'how are you ? mort ##ries ke ke ke writing ke ##ries writing h h writing ke writing writing ke writing h h h writing h

To obtain better output, the model needs to be trained on a much larger amount of data and for a significantly longer time, which requires GPUs.

Conclusion

This article walked through the implementation of the FNet architecture for text generation using PyTorch in Python. The step-by-step guide covered data loading, preprocessing, tokenization, positional embedding, and the creation of the FNet encoder and decoder classes. A complete FNet model was then constructed and trained on a dataset, demonstrating the training process and offering insight into model performance through the training loss.

