Fine-tuning is a technique used in machine learning, particularly in natural language processing (NLP), where a pre-trained model is further trained on a specific task or dataset. Fine-tuning a pre-trained model is typically faster than training a model from scratch, because the model has already learned valuable features and representations from a large and diverse dataset during pre-training. Fine-tuning often leads to better performance on specific tasks: for example, you can take a model pre-trained on a general text corpus and fine-tune it for a domain-specific task, such as medical text analysis or legal document classification. Furthermore, you can adjust the model's behavior and adapt it to the specific needs of your target task. Fine-tuning also helps with costs and the prompt-length limit: once a model has been fine-tuned on many more examples than can fit in the prompt, you no longer need to provide as many examples in the prompt, which saves costs and enables lower-latency requests.
However, it's important to note that fine-tuning comes with its own set of limitations and considerations. In practice, fine-tuning still means training a model; the key distinction is that the weights already contain valuable information. Consequently, fine-tuning shares some of the same needs and challenges as training from scratch: it requires task-specific labeled data, it can inherit biases from the pre-trained model, and it risks overfitting when the target dataset is small. There are also costs related to processing requirements, since most large language models (LLMs) demand significant GPU resources for fine-tuning. However, with techniques like quantization and LoRA, plus a bit of Python code, it is possible to reduce these requirements to the point where the Llama2 7B model can be fine-tuned on just a single GPU. For more information about these techniques, check this other article.
Code example
Here is a code example demonstrating how to fine-tune the Llama2-7b-chat model in Python. We will use the BitsAndBytes library for quantization, PEFT for LoRA to fine-tune a small subset of the model's parameters, and Transformers to load and train the model. Thanks to quantization and LoRA, we were able to train the model on a single AWS SageMaker ml.g5.xlarge instance, which has only 1 GPU with 24 GB of GPU memory (training the full model would normally require several GPUs with at least 112 GB of memory in total).
The goal of this fine-tuning process is to enable the model to answer questions about the company Rootstrap without requiring additional context in the prompt. To achieve this, we incorporate Rootstrap-related questions and answers into the training dataset. Additionally, we include unrelated questions in the dataset, all with the same answer: "I am only limited to answer Rootstrap information." This approach ensures that the model specializes in responding to Rootstrap-related queries. To perform the fine-tuning, let's go through the code step by step:
Import dependencies
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=dependencies.py
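For reference, a minimal set of imports covering the libraries used throughout this article might look like the following sketch (the exact imports in the gist may differ):

```python
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
```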
Load the Rootstrap info dataset
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=load_dataset.py
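A sketch of what loading the dataset could look like with the datasets library; the JSON filename here is an assumption, since the original gist loads Rootstrap's own data file:

```python
# Load the Q&A dataset from a local JSON file (filename is an assumption)
dataset = load_dataset("json", data_files="rootstrap_dataset.json", split="train")
```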
Let’s take a look at the first 5 elements of the dataset
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=print_dataset.py
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=dataset.json
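For illustration, printing the first few examples could be done like this, assuming each record has "question" and "answer" fields as described above:

```python
# Inspect the structure of the first 5 examples
# (field names are an assumption based on the article's description)
for example in dataset.select(range(5)):
    print(example["question"], "->", example["answer"])
```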
Create the bitsandbytes config for 4-bit quantization and load the model together with the tokenizer, which will transform the input text into numeric form:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=load_model.py
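As a reference, the 4-bit quantization config and the model/tokenizer loading could look like the sketch below; the model id and the exact BitsAndBytes options are assumptions based on common QLoRA settings, not necessarily those in the gist:

```python
model_name = "meta-llama/Llama-2-7b-chat-hf"  # assumed model id

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bfloat16
    bnb_4bit_use_double_quant=True,         # nested quantization saves extra memory
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token   # Llama2 has no pad token by default
```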
Let’s try the vanilla model without fine-tuning. We will create an inference function for that:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=vanilla_inference.py
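A minimal version of such an inference function, assuming the standard Llama2 chat [INST] prompt format, might look like this:

```python
def generate_answer(model, tokenizer, question, max_new_tokens=128):
    # Wrap the question in the Llama2 chat instruction format (assumed prompt structure)
    prompt = f"[INST] {question} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=True)

print(generate_answer(model, tokenizer, "What is Rootstrap?"))
```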
Surprisingly, Llama2 already knows what Rootstrap is, but it doesn't know about its projects.
It’s time for fine-tuning! First, preprocess the dataset.
Create the preprocessing functions (a sketch of these steps follows the code link below):
- Format the dataset with the right prompt structure
- Shuffle and remove unused columns
- Tokenize the text inputs to convert them into numeric form
- Filter out inputs that are longer than the maximum allowed input length
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=preprocessing.py
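A sketch of these preprocessing functions, assuming "question"/"answer" fields and an illustrative maximum length:

```python
MAX_LENGTH = 512  # assumed maximum input length

def format_prompt(example):
    # Wrap each Q&A pair in the Llama2 chat prompt structure
    example["text"] = f"[INST] {example['question']} [/INST] {example['answer']}"
    return example

def tokenize(example):
    # Convert the formatted text to token ids, appending the end-of-sequence token
    return tokenizer(example["text"] + tokenizer.eos_token)

def is_short_enough(example):
    # Keep only examples that fit within the maximum input length
    return len(example["input_ids"]) <= MAX_LENGTH
```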
Apply the functions to the dataset:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=apply_preprocessing.py
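Applying them with the datasets API might look like this sketch (column names follow the assumptions above):

```python
dataset = dataset.map(format_prompt)                                  # add the prompt-formatted text
dataset = dataset.shuffle(seed=42).remove_columns(["question", "answer"])
dataset = dataset.map(tokenize, remove_columns=["text"])              # keep only token ids
dataset = dataset.filter(is_short_enough)                             # drop over-long inputs
```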
We have our dataset prepared. Now, let’s create the function that will help us with the LoRA configuration; we will use the peft library for that.
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=lora_config.py
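For reference, a minimal LoRA setup with peft could look like the following; the rank, alpha, dropout, and target modules are illustrative values, not necessarily those used in the gist:

```python
# Prepare the quantized model for training (casts norms, enables input grads)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,                                 # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling factor
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt (assumption)
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # only a small fraction of weights is trainable
```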
We can finally create and execute our training code:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=train.py
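As an illustration, a training loop built on the Transformers Trainer might look like this sketch; the hyperparameters and output directory name are assumptions:

```python
training_args = TrainingArguments(
    output_dir="llama2-7b-rootstrap",     # assumed output directory name
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # Causal LM collator: labels are the input ids shifted by the model internally
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
trainer.save_model(training_args.output_dir)
```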
The model was saved in the output_dir.
We can now load it and ask it questions about Rootstrap to see if it has improved with fine-tuning:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=fine_tuned_inference.py
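Loading the fine-tuned adapter on top of the 4-bit base model could be sketched as follows, reusing the names from the earlier snippets (paths and ids are assumptions):

```python
from peft import PeftModel

# Reload the quantized base model and attach the saved LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
fine_tuned_model = PeftModel.from_pretrained(base_model, "llama2-7b-rootstrap")

print(generate_answer(fine_tuned_model, tokenizer, "What projects has Rootstrap worked on?"))
```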
Not ideal (Rootstrap has never worked with Uber), but it has indeed learned something. Maybe with more training, or with less aggressive quantization, better results could be achieved.
Let’s verify that the model refuses to answer questions not related to Rootstrap:
CODE: https://gist.github.com/santit96/2db79e1b2dc3a1fb9de03d94c4ab0552.js?file=fine_tuned_inference_2.py
A little verbose, but it correctly refuses to answer the question.
Now we can use this fine-tuned model, for example, to build a simple Python Q&A chatbot that answers questions about Rootstrap. Although it may still hallucinate a little without context, we no longer have to pass all of the company context in every question, since some of that information is already embedded in the model and it has learned not to answer questions unrelated to Rootstrap.
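As a rough illustration, such a chatbot could be as simple as a loop around the inference helper sketched earlier:

```python
# Minimal command-line Q&A loop on top of the fine-tuned model
while True:
    question = input("Ask something about Rootstrap (or 'quit'): ")
    if question.strip().lower() == "quit":
        break
    print(generate_answer(fine_tuned_model, tokenizer, question))
```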