🏋️‍♂️ Fine-Tune an LLM
How to train an LLM on your own data.
Fine-tuning is a power tool that must be handled with care. It takes a lot of compute and time to fine-tune an LLM, so it is important to have a plan for what you are trying to achieve.
To get going, you should start by seeing if you can solve your problem with prompting alone. Oxen.ai has great tooling for trying a prompt and a model across your dataset with its model inference feature. If you find that prompting alone is not giving you the results you want, it is time to consider fine-tuning.
With that in mind, let's start with a few reasons you might want to fine-tune an LLM on your own data.
Three Reasons to Fine-Tune
Proprietary Data 🔒
If you have proprietary or private datasets that the big AI labs will never see, this is a great case for building a custom model. Fine-tuning lets you build models that understand your domain, not just everyone else's, giving you a competitive edge.
Save Costs 💰
Leveraging APIs from providers like OpenAI or Anthropic can get pricey. It can be more cost-effective to fine-tune a smaller model on your own data. If your task is well defined and you have good evals, you can get a good model for a fraction of the cost per API call.
Speed Things Up ⚡
Similar to cost, if you have low-latency or high-throughput requirements, fine-tuned models can process data faster while staying at small parameter counts. Depending on the size of the model, the weights can be downloaded and run in a low-resource environment, even on a CPU.
Example: Medical Question Answering
The domain of medicine is a good example where you might want to fine-tune an LLM. The domain is rich with nuance, and the data often has privacy concerns and cannot be shared publicly. If you want to follow along, you can run this example notebook in your own Oxen.ai account with the same data and model.
As you iterate and experiment with fine-tuning, Oxen.ai will help you run experiments in parallel, version your training datasets and model checkpoints, and look at the results of your experimentation.
Make sure to configure your notebook with an A10G GPU and the following dependencies. Allocate at least 2 hours, 8 CPU cores, and 8 GB of memory for the training to complete in a reasonable amount of time.
The Dataset
The dataset we will be using in this example is the MedQuAD dataset. MedQuAD includes 47,457 medical question-answer pairs created from 12 NIH websites (e.g. cancer.gov, niddk.nih.gov, GARD, MedlinePlus Health Topics). The collection covers 37 question types (e.g. Treatment, Diagnosis, Side Effects) associated with diseases, drugs and other medical entities such as tests.
To load the dataset, we can use the `load_dataset` function from the `oxen.datasets` library. This is a wrapper around the Hugging Face datasets library, and is an easy way to load datasets from the Oxen.ai hub. For fine-tuning to work well, it is a good idea to have at least ~1,000-10,000 unique examples in your dataset. If you can collect more, that's even better.
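For example, loading the data might look like the following. The repository path and file name here are assumptions; point them at the MedQuAD dataset in your own namespace.

```python
from oxen.datasets import load_dataset

# Hypothetical repo path and file name; substitute your copy of MedQuAD
dataset = load_dataset("your-namespace/MedQuAD", "train.parquet")
print(dataset)
```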
Don't have a dataset yet? Check out how to generate a synthetic dataset from a stronger model to bootstrap your own.
We then want to transform this dataset into a format that can be used for training a chatbot. This means mapping the question and answer pairs to a list of messages with roles. This is the format that most LLMs expect for training and inference.
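A minimal sketch of that transformation, assuming the columns are named `question` and `answer`:

```python
SYSTEM_PROMPT = "You are a helpful medical assistant."  # assumed system message

def to_messages(example):
    # Convert one question/answer pair into the chat messages format
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
    }

dataset = dataset.map(to_messages)
```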
The Model
For this example, we will be using the Qwen/Qwen2.5-1.5B-Instruct model. This is a 1.5B parameter model that is quick to train and fast for inference. You can even download the weights and run it on your laptop if you want.
To load the model, we can use the `AutoModelForCausalLM` and `AutoTokenizer` classes from the `transformers` library.
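For example, loading the model in bfloat16 so it fits comfortably on an A10G (the precision choice is ours, not a requirement):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # half precision keeps memory usage low
    device_map="auto",           # place the model on the available GPU
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```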
Before you start training, it is a good idea to get a feel for the model. Start by writing a function to make a prediction given a prompt and system message.
Call the `predict` function with a sample question.
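A sketch of such a function, with an illustrative sample question:

```python
def predict(prompt, system_message="You are a helpful medical assistant."):
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=512)
    # Slice off the prompt tokens so only the generated answer remains
    return tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    )

print(predict("What are the treatments for glaucoma?"))
```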
Once you have predictions working from a model, it is good practice to have some sort of evaluation in place to see whether fine-tuning actually improved the model. For situations where precision is important, you may want to build a Human-in-the-Loop pipeline to evaluate the model's predictions. If you want to automate the process, you can use an LLM-as-a-Judge pipeline instead.
To learn more about how to evaluate your model, check out Eugene Yan's blog post on fixing your evaluation process.
Parameter Efficient Fine-Tuning
To make our fine-tuning process more efficient in terms of memory and time, we can use Parameter-Efficient Fine-Tuning (PEFT), in this case via Low-Rank Adaptation (LoRA). If you want to learn more about LoRA, check out the LoRA paper or our Arxiv Dive on the topic.
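With the `peft` library, the LoRA configuration is only a few lines. The hyperparameters below are illustrative starting points, not tuned values:

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,              # rank of the low-rank update matrices
    lora_alpha=32,     # scaling factor applied to the update
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections
    task_type="CAUSAL_LM",
)
```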
This step is optional, but worth knowing about if you have limited resources. If you do not use parameter-efficient fine-tuning, you will need to select a larger GPU for training.
Branches for Experiments
It is rare that you will get a fine-tune perfect on the first try. You must have an experimental mindset and be willing to iterate. In this case, we will simply save the trained models and results to new branches on the same repository. We will set up an `OxenExperiment` class that will handle creating a new branch, saving the model, and logging the results.
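A condensed sketch of what such a class might look like; the `branches()`, `create_branch()`, and `checkout()` calls follow the oxenai Python client, but treat the details as assumptions:

```python
from datetime import datetime
from oxen import RemoteRepo

class OxenExperiment:
    def __init__(self, repo: RemoteRepo, prefix: str = "experiment"):
        # Name the branch with a prefix, a running number, and a timestamp
        num = sum(1 for b in repo.branches() if b.name.startswith(prefix))
        stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.branch_name = f"{prefix}_{num}_{stamp}"
        self.output_dir = f"models/{self.branch_name}"
        repo.create_branch(self.branch_name)
        repo.checkout(self.branch_name)
        self.repo = repo
```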
Branches are lightweight in Oxen.ai and, by default, are not downloaded to your local machine when you clone. This means you can easily store model weights and other large assets on parallel branches and keep your `main` branch small and manageable.
When you start a training run, you'll see a new branch in the repo with a prefix, number, and timestamp.
You can navigate to this branch and look in the `models` directory to see the model weights and other assets.
Logging and Saving
Once we have the experiment set up, we will want to reference it during training and log our experiment results. To do this, we will set up an `OxenTrainerCallback` that will be called during training to save the model weights and our metrics. This is a subclass of the `TrainerCallback` class from the `transformers` library, which can be passed into our training loop.
Since we are subclassing the `TrainerCallback` class, we implement the `on_save` and `on_log` methods. The `on_save` method is called whenever the model is saved to disk, and the `on_log` method is called whenever the trainer reports the loss and other useful metrics during training.
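A sketch of the callback; `experiment` is the class from the previous section, and `log_df` is the oxenai `DataFrame` described below:

```python
from transformers import TrainerCallback

class OxenTrainerCallback(TrainerCallback):
    def __init__(self, experiment, log_df):
        self.experiment = experiment
        self.log_df = log_df

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Fires at each logging step with the latest metrics
        if logs and "loss" in logs:
            self.log_df.insert_row({"step": state.global_step, "loss": logs["loss"]})

    def on_save(self, args, state, control, **kwargs):
        # Fires whenever a checkpoint is written to disk; upload the
        # checkpoint directory to the experiment branch here
        ...
```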
The most important concepts here are the `Workspace` and `DataFrame` objects from the `oxenai` library. The `Workspace` is a wrapper around the branch we are currently on, which allows us to write data to the remote branch without committing the changes. Think of it like your local repo's unstaged changes, but for remote branches. To navigate to your workspaces, use the branch dropdown and then look at the active workspaces for a file.
During training it would be expensive to commit the changes to the branch at every step, so instead we use a `Workspace` to write the temporary results, and then commit the changes to the branch after training is complete.
The `DataFrame` allows us to write rows and columns to the log file. We can read from this to make plots and analyze the results. When clicking on a data frame's workspace, you can see a preview of the data that is written during the `on_log` method.
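Putting the two together might look like this; the constructor arguments and file path are assumptions based on the oxenai client:

```python
from oxen import RemoteRepo, Workspace, DataFrame

repo = RemoteRepo("your-namespace/your-repo")  # hypothetical repo
workspace = Workspace(repo, branch=experiment.branch_name)
log_df = DataFrame(workspace, "logs/train_metrics.jsonl")  # hypothetical log file

log_df.insert_row({"step": 0, "loss": 2.31})  # staged on the workspace, not committed
# ...after training completes, commit everything at once:
workspace.commit("Log training metrics")
```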
With all the building blocks in place, we can then chain all of these classes together and specify the `RemoteRepo`, model name, and output directory.
The Training Loop
The `trl` library from Hugging Face is an easy-to-use library for training and fine-tuning models. We can use the `SFTConfig` class to set up our training loop. This determines our batch size, learning rate, number of epochs, and other hyperparameters.
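For example, a config along these lines; every hyperparameter is an illustrative starting point:

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir=experiment.output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=2,
    logging_steps=10,  # how often on_log fires
    save_steps=500,    # how often on_save fires
    bf16=True,
)
```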
Once you have set up the training arguments, you can set up the training loop. Pass in the model, training arguments, PEFT config, the training dataset, and callbacks. Finally, train the model.
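Wiring it all together might look like the following sketch:

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    callbacks=[OxenTrainerCallback(experiment, log_df)],
)
trainer.train()
```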
This should take just under 2 hours with the settings above. Once the training is complete, you will be able to download the model weights from the experiment branch and use them for inference.
Evaluation
Just because the fine-tune has completed does not mean your job is done. Now you must evaluate the model to see if it is any good. With the dataset we have been using, it is hard to do an exact string-match evaluation on outputs to tell whether the fine-tuned model is better than the original.
Instead, we will use an LLM-as-a-Judge pipeline to evaluate the model's predictions. This will allow us to quickly see whether the fine-tuned model is better than the original.
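A minimal sketch of such a judge; the prompt is illustrative, and `call_judge_model` is a placeholder for whichever strong model or API you choose as the judge:

```python
JUDGE_PROMPT = """You are grading a medical question-answering system.

Question: {question}
Reference answer: {reference}
Model answer: {prediction}

Is the model answer factually consistent with the reference? Answer "yes" or "no"."""

def judge(question, reference, prediction):
    prompt = JUDGE_PROMPT.format(
        question=question, reference=reference, prediction=prediction
    )
    # call_judge_model is a hypothetical helper; swap in your judge of choice
    return call_judge_model(prompt).strip().lower().startswith("yes")
```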