
Wav2Vec2: A Self-Supervised Learning Technique for Speech Representations

Last Updated : 17 Dec, 2023

In the ever-evolving landscape of artificial intelligence, the quest for efficient and versatile models has led researchers to explore innovative training paradigms. Among these, self-supervised learning has emerged as a frontrunner, offering a promising solution to the perennial challenge of acquiring labelled data for diverse tasks. One remarkable stride in this direction comes with Wav2Vec2, a groundbreaking model designed for self-supervised speech representation learning.

What is the Wav2Vec2 Model?

Wav2Vec2 stands as a testament to the transformative potential of self-supervised training, particularly in the realm of Natural Language Processing (NLP). Its architecture is tailored to harness vast amounts of unlabeled speech data, distilling intricate patterns and nuances to create a rich and generalized understanding of spoken language. This self-supervised pre-training phase sets the stage for subsequent fine-tuning, where the model refines its knowledge on specific downstream tasks using limited labeled datasets.

  1. Pre-training: The model is first trained on large amounts of unlabeled data to develop broad, general representations. This step is believed to enhance performance on downstream tasks where labeled data is limited.
  2. Fine-tuning: The pre-trained model is then fine-tuned on a small labelled dataset. Fine-tuning can be done for a variety of downstream tasks.

This means that the model is a general model trained to learn a discretized representation of speech audio. It is trained on a large amount of unlabeled data to represent the raw audio as encodings in a discretized vector space. These discrete encodings can be thought of as speech units.

Why discretized?

Since the speech signal is continuous, the focus is on discretizing it so that transformer architectures developed for text processing, such as BERT, which take discrete inputs, can be reused for further processing. The BERT architecture takes token embeddings as input and learns to predict the embeddings at masked positions; these embeddings are discrete because they are derived from words. This step is called pre-training.

After the pre-training stage, the model can be fine-tuned for various downstream tasks using a very small amount of labeled data. In the paper, the model was fine-tuned on small labeled datasets with CTC loss for the ASR (automatic speech recognition) task.
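
To make the fine-tuning objective concrete, here is a minimal sketch of how a CTC loss is computed in PyTorch over frame-level logits. The shapes, vocabulary size and label lengths are invented purely for illustration and are not taken from the paper.

Python3

import torch
import torch.nn as nn

# Hypothetical shapes: 49 output frames (about 1 second of audio), batch of 2, 32-character vocabulary
T, N, V = 49, 2, 32                       # time steps, batch size, vocab size (index 0 = CTC blank)
logits = torch.randn(T, N, V)             # frame-level logits, e.g. the output of the fine-tuned model
log_probs = logits.log_softmax(dim=-1)    # CTCLoss expects log-probabilities

targets = torch.randint(1, V, (N, 12))    # dummy character-label sequences (12 labels each)
input_lengths = torch.full((N,), T)       # all 49 frames are valid
target_lengths = torch.full((N,), 12)     # length of each transcript

ctc_loss = nn.CTCLoss(blank=0, reduction="mean")
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())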

Let us understand the architecture and training process of the Wav2Vec2 model.

Architecture of Wav2Vec2 Model

Wav2Vec2 Model

Let’s take a closer look at each component.

Feature encoder

The input to the feature encoder is a raw sound waveform sampled at 16 kHz. The feature encoder has seven blocks, and each block's temporal convolution has 512 channels, with strides (5, 2, 2, 2, 2, 2, 2) and kernel widths (10, 3, 3, 3, 3, 2, 2). This yields an encoder output frequency of 49 Hz, i.e. a stride of around 20 ms between frames, with a receptive field of 400 input samples or 25 ms of audio. The detailed calculation is shown below.

Layer | Channel x Input Dim | Kernel/Filter Width | Stride | Channel x Output Dim
------|---------------------|---------------------|--------|----------------------
1     | 1 x 16000           | 10                  | 5      | 512 x 3199
2     | 512 x 3199          | 3                   | 2      | 512 x 1599
3     | 512 x 1599          | 3                   | 2      | 512 x 799
4     | 512 x 799           | 3                   | 2      | 512 x 399
5     | 512 x 399           | 3                   | 2      | 512 x 199
6     | 512 x 199           | 2                   | 2      | 512 x 99
7     | 512 x 99            | 2                   | 2      | 512 x 49

Total stride = 16000 / 49 ≈ 320 samples. Since 16,000 samples correspond to 1 second of audio, 320 samples correspond to roughly 20 ms.

Feature Encoder of Wav2Vec2
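
The 49 Hz output rate can be reproduced from standard convolution arithmetic. The short sketch below (a minimal illustration, not code from the paper) applies out = (in - kernel) // stride + 1 layer by layer to one second of 16 kHz audio and reproduces the output dimensions in the table.

Python3

# Strides and kernel widths of the seven temporal convolution blocks
strides = [5, 2, 2, 2, 2, 2, 2]
kernels = [10, 3, 3, 3, 3, 2, 2]

n = 16000  # samples in one second of 16 kHz audio
for layer, (k, s) in enumerate(zip(kernels, strides), start=1):
    n = (n - k) // s + 1      # output length of an unpadded 1-D convolution
    print(f"layer {layer}: 512 x {n}")

# The last line prints "512 x 49": about 49 frames per second,
# i.e. a hop of roughly 20 ms with a 400-sample (25 ms) receptive field.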

Contextualized representations with Transformers

The core of wav2vec 2.0 is its Transformer encoder, which takes the latent feature vectors produced by the feature encoder and processes them through a stack of transformer blocks. The input sequence first passes through a feature projection layer, which increases the dimension from 512 (the feature encoder output) to 768 for the BASE variant or 1,024 for the LARGE variant, matching the model dimension of the Transformer encoder.

BASE contains 12 transformer blocks, model dimension 768, inner dimension (FFN) 3,072 and 8 attention heads. The LARGE model is made up of 24 transformer blocks with model dimensions of 1,024, inner dimensions of 4,096 and 16 attention heads.

One difference with respect to the BERT architecture is how positional information is incorporated. Instead of fixed positional embeddings, which encode absolute positions, wav2vec 2.0 uses a grouped convolution layer to learn relative positional embeddings on its own.
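
A minimal PyTorch sketch of this part of the pipeline for the BASE configuration is given below. It uses the dimensions quoted above (512 → 768 projection, 12 blocks, FFN size 3,072, 8 attention heads) and the kernel size of 128 with 16 groups that the paper reports for the convolutional positional embedding; it is only an illustration, not the actual fairseq or Hugging Face implementation.

Python3

import torch
import torch.nn as nn

d_model = 768  # BASE model dimension

# Project the 512-dim feature encoder output to the transformer dimension
feature_projection = nn.Linear(512, d_model)

# Grouped convolution that learns relative positional information (kernel 128, 16 groups)
pos_conv = nn.Conv1d(d_model, d_model, kernel_size=128, groups=16, padding=64)

encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=8, dim_feedforward=3072, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)  # 12 blocks in BASE

z = torch.randn(1, 49, 512)            # one second of latent features from the feature encoder
x = feature_projection(z)
pos = pos_conv(x.transpose(1, 2)).transpose(1, 2)[:, : x.size(1), :]  # drop the extra frame from even padding
context = encoder(x + pos)             # contextualised representations, shape (1, 49, 768)
print(context.shape)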

The output of the transformer is a sequence of context vectors. The transformer builds context representations over the continuous speech representations, and these are compared against the output of the quantization module. The quantized vectors produced by the quantization module are the discrete targets to be learnt by the transformer encoder. Both the quantized vectors and the context vectors are learnt jointly using a contrastive loss; more details are given in the training section.

Quantization module

The quantization module of Wav2Vec2 is adopted from the vq-wav2vec architecture. The diagram below shows the overall quantization process.

Wav2Vec2 Quantization Process

The output of the feature encoder (rather than that of the context transformer) is discretized in parallel by a quantization module based on product quantization. Quantization is the process of mapping continuous values to a finite set of discrete ones. A codebook in product quantization is a set of representative vectors used for this discretization; these representative vectors can be thought of as speech units. Quantization proceeds in the following steps:

  1. For one second of audio we get a 512 x 49 dimensional output from the feature encoder, i.e. 49 latent features, each of size 512.
  2. A linear layer projects each feature from 512 dimensions to 640 logits, which are divided into two groups (G = 2) of V = 320 logits each. Each group of 320 logits indexes a codebook of 320 discrete vectors. The codebooks are randomly initialized, and their representations are learnt during training via the contrastive loss. Since each feature vector is mapped to two groups, the total number of possible combinations is 320 * 320 = 102,400 speech units.
  3. Using Gumbel-Softmax, a one-hot vector is produced for each group. We therefore get two one-hot vectors, each selecting one of the 320 discrete vectors in its codebook.
  4. Gumbel-Softmax is a popular technique for sampling from a discrete space. It introduces stochasticity (via Gumbel noise) into the discrete decision-making process and uses a differentiable approximation (softmax) of the argmax operation, which makes it possible to backpropagate through samples of discrete variables. The Gumbel-Max trick is very similar to the reparameterization trick: it combines a deterministic part (the model logits) with a stochastic part (the Gumbel noise). During the forward pass or inference, the largest index is picked and the corresponding codebook vector is used; during the backward pass, the soft logits are used for backpropagation.
  5. Each vector in a codebook has size d/2. We obtain two codebook vectors (e1 and e2) for each latent feature vector z, concatenate them into a d-dimensional vector, and pass the result through a linear transformation R^d -> R^f to obtain the quantized vector q ∈ R^f. This transformation matches the dimension of the transformer output. A minimal sketch of this quantization is given after this list.
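
The following is a simplified, self-contained sketch of this product quantization with Gumbel-Softmax, using G = 2 groups and V = 320 entries per codebook; the dimensions d and f are illustrative choices, and the real fairseq module differs in details such as temperature annealing.

Python3

import torch
import torch.nn as nn
import torch.nn.functional as F

G, V = 2, 320                               # groups and codebook entries per group
d, f = 512, 256                             # illustrative concatenated and final quantized dimensions

to_logits = nn.Linear(512, G * V)           # project each 512-dim latent feature to G*V = 640 logits
codebook = nn.Parameter(torch.randn(G, V, d // G))   # randomly initialised entries of size d/2
out_proj = nn.Linear(d, f)                  # linear map R^d -> R^f

z = torch.randn(49, 512)                    # 49 latent features for one second of audio
logits = to_logits(z).view(49, G, V)

# Gumbel-Softmax: hard one-hot selection in the forward pass,
# differentiable (straight-through) gradients in the backward pass
one_hot = F.gumbel_softmax(logits, tau=2.0, hard=True)               # shape (49, G, V)

# Pick one d/2-dim entry per group and concatenate the two halves
q = torch.einsum("tgv,gvd->tgd", one_hot, codebook).reshape(49, d)
q = out_proj(q)                             # quantized vectors q, shape (49, f)
print(q.shape)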

Training Process

To understand the training procedure of the wav2vec 2.0 model, let us first understand what a contrastive score and a contrastive loss are.

Contrastive Score typically involves computing a similarity metric between pairs of samples. Commonly used similarity metrics include cosine similarity or dot product. The idea is to compare the representations of two instances in the embedding space. For positive pairs (examples that should be similar), the contrastive score should be high, indicating high similarity. For negative pairs (examples that should be dissimilar), the contrastive score should be low, indicating low similarity.

Contrastive loss is used as part of the training objective. It encourages the model to bring positive pairs closer together in the embedding space while pushing negative pairs further apart.

L(i,j) = -\log\left(\frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k} \exp(\mathrm{sim}(z_i, z_k)/\tau)}\right)

where

  • L(i,j) is the contrastive loss for samples i and j
  • sim(z_i, z_j) is the similarity score between the representations of samples i and j (e.g. cosine similarity or dot product)
  • \tau is a temperature hyperparameter that scales the similarities
  • the sum in the denominator runs over the samples k in the batch

One important thing to note is that positive pairs are pulled together so that they become more similar, while negative pairs are pushed apart so that they become more dissimilar.

In the context of Wav2Vec2

  1. The output of the feature encoder is passed through the quantization module to obtain a quantized representation from the codebook. This is the positive sample.
  2. The same feature encoder output is also passed through the transformer encoder, but before doing so a proportion of the time steps (roughly 50%) is masked. The objective is to learn the representation of the speech at the masked positions by comparing the transformer output with the true quantized latent speech representation. For each masked position, 100 negative distractors (negative samples) are uniformly sampled from other masked positions in the same utterance, excluding the positive target.
  3. The model compares the similarities using the contrastive loss equation shown above (see the sketch after this list).
  4. The loss is then backpropagated through the transformer as well as the quantization module, making the transformer encoder output more similar to the positive codebook sample and more dissimilar to the negative samples.
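
As a rough illustration, the sketch below computes this masked contrastive objective for a single masked position with cosine similarity, a temperature, and randomly generated vectors standing in for the positive target and the 100 distractors; it shows only the loss computation, not the full training loop.

Python3

import torch
import torch.nn.functional as F

def contrastive_loss(context, positive, distractors, temperature=0.1):
    """context: (f,) transformer output at a masked step,
    positive: (f,) the true quantized vector for that step,
    distractors: (K, f) quantized vectors sampled from other masked steps."""
    candidates = torch.cat([positive.unsqueeze(0), distractors], dim=0)          # (K + 1, f)
    sims = F.cosine_similarity(context.unsqueeze(0), candidates, dim=-1) / temperature
    # Cross-entropy with target index 0 equals -log softmax(sims)[positive]
    return F.cross_entropy(sims.unsqueeze(0), torch.zeros(1, dtype=torch.long))

f, K = 256, 100                    # feature dimension (illustrative) and number of distractors
c = torch.randn(f)                 # context vector at one masked time step
q_pos = torch.randn(f)             # positive quantized target
q_neg = torch.randn(K, f)          # 100 negative distractors
print(contrastive_loss(c, q_pos, q_neg).item())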

Diversity loss: to encourage equal use of all the codebook entries when representing both positive and negative samples, a diversity loss is added during training. It works by maximizing the entropy of the averaged softmax distribution over the codebook entries, preventing the model from always choosing from a small subset of the available entries.
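
A minimal sketch of such a penalty is shown below: it averages the softmax distribution over codebook entries across a batch of frames and penalises low entropy, which pushes the averaged distribution towards uniform use of the entries. The exact formulation in fairseq (based on the perplexity of this distribution) differs slightly.

Python3

import torch
import torch.nn.functional as F

def diversity_loss(logits):
    """logits: (frames, G, V) codebook logits before Gumbel sampling."""
    probs = F.softmax(logits, dim=-1).mean(dim=0)             # averaged distribution per group, shape (G, V)
    entropy = -(probs * torch.log(probs + 1e-7)).sum(dim=-1)  # entropy of each group's averaged distribution
    # Minimising the negative entropy maximises entropy,
    # encouraging all codebook entries to be used.
    return -entropy.mean()

logits = torch.randn(49, 2, 320)   # 49 frames, G = 2 groups, V = 320 entries (as above)
print(diversity_loss(logits).item())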

Wav2Vec2 Model Implementation

Install Libraries

Install the below libraries if not available in your environment. These are required to run the subsequent code.

!pip install datasets
!pip install transformers
!pip install torch
!pip install evaluate
!pip install transformers[torch]

Import Libraries

And then import the libraries into your notebook. Required libraries include numpy, transformers and pytorch.

Python3

# Imports required
import numpy as np
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import evaluate
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import TrainingArguments, Trainer


Loading Dataset and Preprocessing

Load the first 80 examples of the MINDS-14 dataset and split them into train and test sets in an 80:20 ratio.

Python3

# Load the first 80 examples of the PolyAI MINDS-14 dataset
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train[:80]")
 
# Remove unnecessary columns
dataset = dataset.remove_columns(['path','english_transcription','intent_class'])
 
# Split the dataset into train and test
dataset = dataset.train_test_split(test_size = 0.2, shuffle=False)


Resampling data

We need to resample the data to 16 kHz, since the Wav2Vec2 model is trained on 16 kHz audio while the dataset is sampled at 8 kHz. For this we use the Audio feature of the datasets library.

Python3

# Declare device variable
device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
# Resample the dataset to 16 kHz as the Wav2Vec2 model is trained on 16 kHz audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
 
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)


Drawing Inferences

We format an input and use the base model to infer its transcription. The model produces output in logits, and we decode it by selecting the maximum value among the logits. The use of ‘torch.no_grad()’ ensures that these operations do not contribute to gradient computation, which is particularly helpful when there’s no need to update the model weights.

Python3

# Let's process one example (index 3) from the train dataset
 
inputs = processor(dataset['train'][3]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
 
# getting the predictions (moving the inputs to the same device as the model)
 
with torch.no_grad():
    logits = model(**inputs.to(device)).logits
 
 
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
transcription


Output:

['HOW DO I FURN A JOINA COUT']

The actual text of the audio is ‘how do I start a joint account’.

Fine Tuning the model

We prepare our data to match the format expected by the Wav2Vec2 model using the Dataset map function. We create two columns: ‘input_values’, which holds the raw sound wave array resampled to 16 kHz, and ‘labels’, which holds the transcription encoded in the format expected by the tokenizer. To achieve this, each example is passed through the processor inside the function defined below.

Python3

# Preparing a function to process the entire dataset.
# We need to create two columns: 'input_values'
# (the raw sound wave array) and 'labels' (the encoded transcription)
 
 
def prepare_dataset(batch):
 
    audio = batch["audio"]
 
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"].upper()).input_ids
 
    return batch
 
 
encoded_dataset = dataset.map(prepare_dataset, num_proc=1)



Creating a specialized class for Data

We’re crafting a DataCollator class specifically designed for fine-tuning Wav2Vec2. Unlike common text tasks, ASR does not come with a built-in data collator in Transformers, so we adapt the idea behind the DataCollatorWithPadding class to create batches of examples that match the elements found in the training or evaluation datasets.

It’s worth highlighting that ‘input_values’ and ‘labels’ need different padding strategies since they can have varying lengths. In ASR tasks with potentially large input sizes, it’s more efficient to dynamically pad training batches. This means each training sample only gets padded to match the length of the longest sample within its batch, rather than padding to the overall longest sample.

So, in essence, for fine-tuning Wav2Vec2, we’re crafting a specialized padding data collator, and we’ll define it below:


Python3

@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"
 
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_values = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
 
        batch = self.processor.pad(input_values, padding=self.padding, return_tensors="pt")
 
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(label_features, padding=self.padding, return_tensors="pt")
 
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
 
        batch["labels"] = labels
 
        return batch
       
  
data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")



Evaluation Metric

For our task, we’ll be using the word error rate metric. To measure this, we need to define a ‘compute_metrics’ function. Each logit vector has a length equal to the configured vocabulary size, which is noted as ‘config.vocab_size.’ Our main focus is on figuring out the model’s prediction, and we do this by calculating the argmax(…) of the logits.

To make sense of the predictions, we convert the encoded labels back into their original string form. This involves a couple of steps. First, we replace instances of -100 with the ‘pad_token_id.’ Then, we decode the IDs while making sure that consecutive tokens are not incorrectly grouped together. This decoding process aligns with the CTC (Connectionist Temporal Classification) style, ensuring accuracy in the representation of the original string.


Python3

wer_metric = evaluate.load("wer")
 
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
 
    # replace -100 with the pad token id so the labels can be decoded
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
 
    pred_str = processor.batch_decode(pred_ids)
    # do not group tokens when decoding the reference labels
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
 
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
 
    return {"wer": wer}



Model Training

Wav2Vec2 is a sizable model that demands a significant amount of memory, making GPU training a necessity. If your system lacks sufficient memory, there’s a risk of encountering out-of-memory issues. The learning rate has been fine-tuned through heuristic methods to ensure stable fine-tuning. It’s crucial to note that these parameters are highly dependent on the dataset, so experimenting with various values is essential.

To initiate the training process, pass these training arguments, along with the dataset, model, tokenizer, and data collator, to the Trainer. Once set up, call the ‘.train()’ method to kickstart the training.


Python3

del model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id)
model.to(device)
# defining training arguments and trainer
training_args = TrainingArguments(
    output_dir="wav2vec2_finetuned",
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    learning_rate=1e-5,
    warmup_steps=2,
    max_steps=2000,
    fp16=True,
    optim='adafactor',
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=1,
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # a lower word error rate is better
)
 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
 
trainer.train()


Output:

Step    Training Loss    Validation Loss    Wer
100     No log           1.422148           0.354839
200     No log           1.584326           0.379032
300     No log           1.595137           0.346774
400     No log           1.534755           0.314516
500     1.022900         1.548012           0.322581
600     1.022900         1.525821           0.322581

Getting Prediction from the Fine-tuned model

Python3

## getting test data
i2 = processor(dataset['test'][6]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
print(f"The input test audio is: {dataset['test'][6]['transcription']}")
 
# prediction for test data
with torch.no_grad():
    logits = model(**i2.to(device)).logits
 
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(f'The output prediction is : {transcription[0]}')


Output :

The input test audio is: so you spent the money I'd like to see my new account balance
The output prediction is : SO JUS SPEND SOME MONEY I'D LIKE TO SEE MY NEW ACCOUNT BALANCE

The output is better this time.

Conclusion

Self-supervised learning, exemplified by models like Wav2Vec2, offers a robust approach for representation learning in domains with limited labeled data. Fine-tuning on specific tasks further refines the model’s performance, showcasing the adaptability and effectiveness of this training methodology.


