
Fine Tuning Large Language Model (LLM)

Large Language Models (LLMs) have revolutionized natural language processing, excelling at tasks such as text generation, translation, summarization and question answering. Despite their impressive capabilities, these general-purpose models may not always perform well on a specific task or domain. Fine-tuning addresses this: it allows users to customize pre-trained language models for specialized tasks by further training the model on a comparatively small, task-specific dataset, enhancing its performance on that task while retaining its overall language proficiency.

What is Fine Tuning?

Fine-tuning LLMs means we take a pre-trained model and further train it on a specific data set. It is a form of transfer learning where a pre-trained model trained on a large dataset is adapted to work for a specific task. The dataset required for fine-tuning is very small compared to the dataset required for pre-training.

The benefits of fine-tuning are:

Improved performance on the specific task or domain.

Far less data and compute required than pre-training a model from scratch.

The model retains its general language ability while adapting to specialized needs.



Why Fine-tune?

The pre-trained model, or foundation model, is trained on a large amount of data using self-supervised training, so it learns the statistical representation of language very well and becomes a good general-purpose reasoning engine. With fine-tuning we adapt this general capability to our specific task or domain, teaching the model the behavior, output format and vocabulary we actually need.

Types of Fine Tuning

Let us explore various types of fine-tuning methods.

Supervised fine-tuning – the model is further trained on a labeled dataset of input-output pairs for the target task.

Instruction fine-tuning – the model is trained on examples phrased as instructions together with the expected responses, improving its ability to follow instructions.

PEFT methods – parameter-efficient fine-tuning methods that update only a small fraction of the model's parameters, discussed in detail below.

Training a full model is generally challenging. During training, in addition to the weights themselves we need extra memory for gradients, optimizer states, activations, and temporary variables, which can add up to roughly 16 bytes or more per parameter. Hence a model of about 1 billion parameters is roughly the largest that can fit in 16 GB of memory; models beyond this size need more memory, resulting in high compute cost and other training challenges. A rough, illustrative estimate is shown below.
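Assuming fp32 weights and an Adam-style optimizer, and ignoring activations and temporary buffers, the back-of-the-envelope arithmetic looks like this:

params = 1_000_000_000             # 1 B parameters

bytes_weights   = params * 4       # fp32 weights:             ~4 GB
bytes_gradients = params * 4       # one gradient per weight:  ~4 GB
bytes_optimizer = params * 8       # Adam momentum + variance: ~8 GB

total_gb = (bytes_weights + bytes_gradients + bytes_optimizer) / 1e9
print(f"~{total_gb:.0f} GB before activations and temporary buffers")
# ~16 GB, which already saturates a 16 GB GPU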

To efficiently train large models on limited compute resources we have Parameter-Efficient Fine-Tuning (PEFT) methods. These methods do not update all the weights of the model, thereby reducing memory requirements significantly. PEFT methods can broadly be classified as follows.

1. Selective Method

In the selective method, we freeze most of the model's layers and unfreeze only a few selected layers. We train and modify the weights of these selected layers to adapt the model to our specific task. This method is generally not used in practice.
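A minimal PyTorch-style sketch of the idea, assuming a generic model whose last decoder block we want to adapt (the attribute names are illustrative and follow the T5 architecture used later in this article):

# Freeze every parameter in the model...
for param in model.parameters():
    param.requires_grad = False

# ...then unfreeze only the layers we want to adapt,
# e.g. the final decoder block (attribute path depends on the architecture).
for param in model.decoder.block[-1].parameters():
    param.requires_grad = True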

2. Reparameterization Method

This is the most common method. It reparameterizes the weight updates with low-rank matrices and is known as LoRA (Low-Rank Adaptation). We keep the original model weights frozen and instead inject small new trainable parameters in the form of low-rank matrices.
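The core idea can be sketched in a few lines of PyTorch: the frozen weight W is augmented with the product of two small matrices B and A, and only B and A are trained (the dimensions below are illustrative):

import torch

d, k, r = 512, 512, 8            # original weight is d x k, rank r << d, k
W = torch.randn(d, k)            # pre-trained weight, kept frozen
A = torch.randn(r, k) * 0.01     # small trainable matrix (r x k)
B = torch.zeros(d, r)            # small trainable matrix (d x r), initialized to zero

# Effective weight during fine-tuning: B @ A is a rank-r update to W,
# and only A and B receive gradients.
W_effective = W + B @ A

print(W.numel(), A.numel() + B.numel())   # 262144 frozen vs 8192 trainable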

Example

QLoRA – a further extension of the LoRA method. Here we further reduce memory requirements by quantizing the weights. Model weights are normally stored as 32-bit floating-point numbers during training; with quantization we can store them in 16-bit, 8-bit, or even 4-bit precision (as QLoRA does). This results in some loss of precision but considerably reduces memory usage.
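One way 4-bit loading could look with the Hugging Face Transformers and bitsandbytes libraries (the configuration values are illustrative, and bitsandbytes must be installed):

import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig

# 4-bit quantization config: weights are stored in 4-bit NF4 format,
# while computation is carried out in a higher precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

quantized_model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",
    quantization_config=bnb_config,
)
# LoRA adapters can then be attached to this quantized model as usual.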

3. Additive Method

Adapters – in the adapter method we add small new layers on the encoder or decoder side of the model and train only these new layers for our specific task.
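Conceptually, an adapter is a small bottleneck layer with a residual connection inserted after an existing sub-layer. A simplified PyTorch sketch (the dimensions are illustrative):

import torch.nn as nn

class Adapter(nn.Module):
    # A small bottleneck layer inserted after an existing transformer sub-layer.
    def __init__(self, hidden_size=768, bottleneck=64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck)   # project down
        self.up = nn.Linear(bottleneck, hidden_size)     # project back up
        self.act = nn.ReLU()

    def forward(self, hidden_states):
        # The residual connection keeps the original representation intact.
        return hidden_states + self.up(self.act(self.down(hidden_states)))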

Soft prompting – also called prompt tuning, where we add new trainable tokens (virtual tokens) to the model prompt. These new tokens are trained while all other tokens and all model weights are kept frozen; only the newly added token embeddings are updated.
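With the peft library, prompt tuning can be configured in much the same way as LoRA. A minimal sketch, assuming a pre-trained base_model such as the flan-t5-base model used in the implementation section below (the number of virtual tokens is illustrative):

from peft import PromptTuningConfig, TaskType, get_peft_model

prompt_config = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=20,   # trainable "soft" tokens prepended to every input
)
prompt_model = get_peft_model(base_model, prompt_config)
prompt_model.print_trainable_parameters()   # only the virtual token embeddings are trainable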

RLHF

RLHF stands for Reinforcement Learning from Human Feedback. It is used to align a model so that it generates output that humans prefer.

RLHF is generally applied after fine-tuning. It takes a fine-tuned model and aligns its outputs with human preferences, using the machinery of reinforcement learning to do so.

RLHF typically has the following steps:

Collect human feedback – annotators rank or rate several model responses to the same prompt.

Train a reward model – a separate model learns from these rankings to predict how strongly a human would prefer a given response.

Optimize the LLM with reinforcement learning – the fine-tuned model is further updated (commonly with an algorithm such as PPO) to maximize the reward model's score.

Prompt Engineering vs RAG vs Fine-Tuning

Let us explore the difference between prompt engineering, RAG, and fine-tuning.

| Criteria | Prompt Engineering | RAG | Fine-Tuning |
| --- | --- | --- | --- |
| Purpose | Focuses on writing an effective prompt that maximizes the quality of the generated output for a given task. | Retrieves relevant information for a given prompt from an external database and supplies it as context. | Trains and adapts the model itself for a specific task. |
| Model | Model weights are not updated; the effort goes into crafting an effective prompt. | Model weights are not updated; the effort goes into building context for a given prompt. | Model weights are updated. |
| Complexity | No technical knowledge required. | Less complex than fine-tuning; requires skills related to vector databases and retrieval mechanisms only. | Technical knowledge of model training required. |
| Compute Cost | Very low cost; only the cost of API calls. | Cost-effective compared to fine-tuning. | May require specialized hardware depending on model size and dataset size. |
| Knowledge | The model does not learn new data. | The prompt is augmented with new data in the form of context. | The model learns new data. |

When to use fine-tuning?

When we build an LLM application, the first step is to select an appropriate pre-trained or foundation model suitable for our use case. Once the base model is selected, we should try prompt engineering to quickly check whether the model realistically fits our use case and to evaluate the base model's performance on it.

If we cannot achieve a reasonable level of performance with prompt engineering, we should proceed with fine-tuning. Fine-tuning should be done when we want the model to specialize in a particular task or set of tasks and we have a labeled, unbiased, diverse dataset available. It is also advisable for domain-specific adaptation, such as learning medical, legal, or financial language.

How is fine-tuning performed?

There is no single standard way of fine-tuning, as there are many methods and the exact steps depend on the task objective at hand. However, the process can be generalized into the steps below:

  1. Base Model – Select the base pre-trained model that is suitable according to our task and fits in our compute budget.
  2. Method – Select the fine-tuning method suitable for the use case considering the compute budget, dataset, and model size.
  3. Prepare Dataset – We need to prepare the dataset as required by the task and the selected fine-tuning method, taking into account the expected input and output format of the chosen base model.
  4. Training – Many libraries are available for training the model. PyTorch and TensorFlow provide the lowest level of abstraction; the Hugging Face Transformers library provides a higher level of abstraction on top of them; and newer libraries such as Lamini aim to provide an even higher level of abstraction.
  5. Evaluate and iterate – Evaluate the model based on evaluation criteria. Iterate if required.

Fine Tuning Large Language Model Implementation

Let us fine-tune a model using the PEFT LoRA method. We will use the flan-t5-base model and the DialogSum dataset. Flan-T5 is the instruction fine-tuned version of T5 released by Google. DialogSum is a large-scale dialogue summarization dataset consisting of 13,460 dialogues (plus 100 holdout dialogues for topic generation) with corresponding manually labeled summaries and topics.

1. Install necessary libraries

!pip install datasets
!pip install transformers
!pip install evaluate
!pip install accelerate -U
!pip install transformers[torch]
!pip install peft

2. Import the libraries




from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer,GenerationConfig
import torch
device ='cuda' if torch.cuda.is_available() else 'cpu'
 
import evaluate
 
import pandas as pd
import numpy as np

3. Load the dataset and model from Hugging Face




huggingface_dataset_name = "knkarthick/dialogsum"
dataset =load_dataset(huggingface_dataset_name)
 
model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

4. Define a function to check the number of model parameters

The function defined below reports the number of trainable parameters, the total number of parameters, and the percentage that is trainable. We will use it during PEFT training to see how much the trainable-parameter count, and hence the resource requirement, is reduced.




def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
 
 
print(print_number_of_trainable_model_parameters(base_model))

Output:

trainable model parameters: 247577856
all model parameters: 247577856
percentage of trainable model parameters: 100.00%

5. Base model output

Let us take a sample from the test dataset and generate its summary with the base model. Before generating the output, we prepare a simple prompt template as shown below.




i= 20
dialogue = dataset['test'][i]['dialogue']
summary = dataset['test'][i]['summary']
 
 
prompt = f"Summarize the following dialogue  {dialogue}  Summary:"
 
 
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = tokenizer.decode(base_model.generate(input_ids, max_new_tokens=200)[0],skip_special_tokens=True)
 
 
print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)

Output:

Input Prompt : Summarize the following dialogue  #Person1#: What's wrong with you? Why are you scratching so much?
#Person2#: I feel itchy! I can't stand it anymore! I think I may be coming down with something. I feel lightheaded and weak.
#Person1#: Let me have a look. Whoa! Get away from me!
#Person2#: What's wrong?
#Person1#: I think you have chicken pox! You are contagious! Get away! Don't breathe on me!
#Person2#: Maybe it's just a rash or an allergy! We can't be sure until I see a doctor.
#Person1#: Well in the meantime you are a biohazard! I didn't get it when I was a kid and I've heard that you can even die if you get it as an adult!
#Person2#: Are you serious? You always blow things out of proportion. In any case, I think I'll go take an oatmeal bath. Summary:
--------------------------------------------------------------------
Human evaluated summary ---->
#Person1# thinks #Person2# has chicken pox and warns #Person2# about the possible hazards but #Person2# thinks it will be fine.
---------------------------------------------------------------------
Baseline model generated summary : ---->
Person1 is scratching so much that he can't stand it anymore.

6. Tokenize and prepare the dataset




def tokenize_function(example):
    # Wrap each dialogue in a simple summarization prompt and tokenize it;
    # the reference summaries become the labels.
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids

    return example

# Tokenize the full dataset and drop the raw text columns.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary'])

# Keep every 100th example so that training finishes quickly for this demo.
tokenized_datasets = tokenized_datasets.filter(lambda example, index: index % 100 == 0, with_indices=True)
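An optional sanity check to confirm how small the filtered splits are and that the prompt template looks as expected (the exact row counts depend on the DialogSum split sizes):

print(tokenized_datasets)   # shows the remaining number of rows per split
# Decode one training example to verify the prompt format.
print(tokenizer.decode(tokenized_datasets['train'][0]['input_ids'], skip_special_tokens=True)[:200])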

7. Define the LoRA config, PEFT model, training arguments and PEFT trainer

Let us use a LoRA rank of 32. We see that, compared to the full model, only 1.41% of the parameters need to be trained.




from peft import LoraConfig, get_peft_model, TaskType
 
 
lora_config = LoraConfig(r=32,lora_alpha = 32, target_modules=["q","v"],
                         lora_dropout = 0.5, bias ="none", task_type  =TaskType.SEQ_2_SEQ_LM)
 
output_dir = f"./peft-dialogue-summary-training"
 
peft_model_train = get_peft_model(base_model, lora_config)
print(print_number_of_trainable_model_parameters(peft_model_train))

Output:

trainable model parameters: 3538944
all model parameters: 251116800
percentage of trainable model parameters: 1.41%

Let us define the training arguments and train the PEFT model.




peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,  # Higher learning rate than full fine-tuning.
    num_train_epochs=5,
)
peft_trainer = Trainer(
    model=peft_model_train,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)

peft_trainer.train()

Output:

TrainOutput(global_step=160, training_loss=3.6751495361328126, metrics={'train_runtime': 155.9012, 'train_samples_per_second': 4.009, 'train_steps_per_second': 1.026, 'total_flos': 434768117760000.0, 'train_loss': 3.6751495361328126, 'epoch': 5.0})

8. Save our model and load it for inference




peft_model_path="./peft-dialogue-summary-checkpoint-local"
 
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)
 
 
from peft import PeftModel, PeftConfig
 
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
 
peft_model = PeftModel.from_pretrained(peft_model_base,
                                       './peft-dialogue-summary-checkpoint-local',
                                       is_trainable=False)

9. Generate output




peft_model_outputs = peft_model.generate(input_ids=input_ids, max_new_tokens=200)
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)
 
 
print(f"Input Prompt : {prompt}")
print("--------------------------------------------------------------------")
print("Human evaluated summary ---->")
print(summary)
print("---------------------------------------------------------------------")
print("Baseline model generated summary : ---->")
print(output)
print("---------------------------------------------------------------------")
print("Peft model generated summary : ---->")
print(peft_model_text_output)

Input Prompt : Summarize the following dialogue  #Person1#: What's wrong with you? Why are you scratching so much?
#Person2#: I feel itchy! I can't stand it anymore! I think I may be coming down with something. I feel lightheaded and weak.
#Person1#: Let me have a look. Whoa! Get away from me!
#Person2#: What's wrong?
#Person1#: I think you have chicken pox! You are contagious! Get away! Don't breathe on me!
#Person2#: Maybe it's just a rash or an allergy! We can't be sure until I see a doctor.
#Person1#: Well in the meantime you are a biohazard! I didn't get it when I was a kid and I've heard that you can even die if you get it as an adult!
#Person2#: Are you serious? You always blow things out of proportion. In any case, I think I'll go take an oatmeal bath. Summary:
--------------------------------------------------------------------
Human evaluated summary ---->
#Person1# thinks #Person2# has chicken pox and warns #Person2# about the possible hazards but #Person2# thinks it will be fine.
---------------------------------------------------------------------
Baseline model generated summary : ---->
Person1 is scratching so much that he can't stand it anymore.
---------------------------------------------------------------------
Peft model generated summary : ---->
#Person1# thinks he may be coming down with chicken pox. #Person2# thinks he has chicken pox.

We see there is an improvement in the output. With further training and a larger dataset we can achieve better performance.
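The evaluate library imported earlier can also be used to quantify this improvement. A minimal sketch that scores the single example above with ROUGE (for a meaningful evaluation, the whole test split should be scored rather than one sample):

import evaluate

rouge = evaluate.load('rouge')

# Compare each generated summary against the human-written reference.
baseline_scores = rouge.compute(predictions=[output], references=[summary])
peft_scores = rouge.compute(predictions=[peft_model_text_output], references=[summary])

print("Baseline model ROUGE:", baseline_scores)
print("PEFT model ROUGE    :", peft_scores)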

Conclusion

In this article, we got an overview of the various fine-tuning methods available, the benefits of fine-tuning, when fine-tuning should be used, and how fine-tuning is generally performed. We then saw a Python implementation of LoRA training.

