Post-training methods in LLMs using RL
Tags : Reinforcement Learning, RL, Machine Learning, Deep Learning, ML, DL, Post Training
Maths
Reinforcement learning: here the agent decides which action to take based on the current state and other variables present at timestep t; it then takes that action, a reward follows, and the model's weights are updated based on the rewards received.
Consider this basic hello-world example of RL:
State : Any place / position where the agent can be
Action : Up, down, left, right; these are the actions the agent can take
Values : The value of a state is the best return we can get from it; we need to find this value function. In this grid it works out to: Cell Value = best neighbouring value - 1
Policy : Given the current state and the neighbouring values, we can create a policy function; a good policy is max(up, down, left, right) over the neighbouring values, and whichever direction gives the best number is the one the policy picks.
So this is the policy: following these arrows we can reach the best / optimal position.
So we need a Value function and a Policy function to navigate the landscape and solve the RL problem successfully.
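A minimal sketch of this hello-world gridworld (assuming a 4x4 grid with the goal in the top-right corner and the "cell value = best neighbouring value - 1" rule); the greedy policy just picks the neighbour with the highest value:

```python
# Minimal gridworld "hello world" (assumed 4x4 grid, goal in the top-right corner).
GOAL = (0, 3)
ACTIONS = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}

def in_grid(r, c, n=4):
    return 0 <= r < n and 0 <= c < n

def compute_values(n=4, sweeps=20):
    values = [[-100.0] * n for _ in range(n)]
    values[GOAL[0]][GOAL[1]] = 0.0
    for _ in range(sweeps):                      # repeated sweeps until values settle
        for r in range(n):
            for c in range(n):
                if (r, c) == GOAL:
                    continue
                best_neighbour = max(
                    values[r + dr][c + dc]
                    for dr, dc in ACTIONS.values()
                    if in_grid(r + dr, c + dc, n)
                )
                values[r][c] = best_neighbour - 1   # cell value = best neighbouring value - 1
    return values

def greedy_policy(values, r, c, n=4):
    # Policy: pick the action whose neighbouring cell has the highest value.
    return max(
        (a for a, (dr, dc) in ACTIONS.items() if in_grid(r + dr, c + dc, n)),
        key=lambda a: values[r + ACTIONS[a][0]][c + ACTIONS[a][1]],
    )

values = compute_values()
print(greedy_policy(values, 3, 0))   # from the bottom-left corner, heads towards the goal
```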
But the catch is that as the number of states increases, finding the policy and value functions becomes very difficult.
So we use Neural Nets to approximate both the policy and the value models.
VALUE NEURAL NETWORK
The value model outputs a single deterministic number for a state, while the policy neural network (PNN) outputs an approximated probability distribution over actions.
The goal is to find the best direction / way to reach the next state given the current state.
Training the Value Neural nets:
L_value(θ) = (VNN_θ(s) - actual_value)^2
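A minimal PyTorch sketch of that squared-error value loss; the tiny `value_net`, the 4-dimensional states and the random targets are just assumed stand-ins:

```python
import torch
import torch.nn as nn

# Assumed tiny value network: maps a state vector to a single scalar value estimate.
value_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))

state = torch.randn(8, 4)            # batch of 8 example states (assumed 4-dim)
target_value = torch.randn(8, 1)     # "actual value" / observed return for each state

predicted_value = value_net(state)
value_loss = ((predicted_value - target_value) ** 2).mean()   # L_value(theta)
value_loss.backward()                # gradients used to update theta
```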
POLICY NEURAL NETWORK
The policy model tells the model where to take the next step from the current state, and training increases the probability of moving in that direction.
To train the policy neural net, we take scenarios and look at the gain. If the value network predicted a value in some direction and the realised gain there is 10, i.e. the actual value is higher than the predicted value, then the value model should increase its estimate there and the policy model should increase the probability of going in that direction.
L_policy(θ) = (PNN_t(a|s) / PNN_{t-1}(a|s)) * (momentum in that direction)
To avoid very large updates, we clip this ratio so the step is not erratic and the policy changes slowly.
It's similar to what we have in optimisers, i.e. we keep nudging the update in a particular direction.
That momentum is called the advantage (A), i.e. the reward signal in that direction !!
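A small sketch of this ratio-times-advantage objective with the clipping step, written as a loss to minimise (so the objective is negated); `logprob_new`, `logprob_old` and `advantage` are assumed per-action tensors:

```python
import torch

def clipped_policy_loss(logprob_new, logprob_old, advantage, clip_eps=0.2):
    # ratio = PNN_t(a|s) / PNN_{t-1}(a|s), computed in log space for stability
    ratio = torch.exp(logprob_new - logprob_old)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
    # take the more pessimistic of the two so a single step cannot be too erratic,
    # then negate because we maximise the objective by minimising the loss
    return -torch.min(unclipped, clipped).mean()

# toy usage with made-up numbers
loss = clipped_policy_loss(torch.tensor([-1.0, -0.5]),
                           torch.tensor([-1.2, -0.4]),
                           torch.tensor([2.0, -1.0]))
```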
Post Training
Stage 1 : SFT ( Supervised Finetuning )
Stage 2 : RLHF ( reinforcement learning from human feedback )
SFT
Supervised finetuning (SFT): we have a supervised dataset, i.e. the training dataset contains both the input and the defined output, and the model learns from the provided samples just like classic supervised training.
Dataset:
Input Examples
Output Examples
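A hedged sketch of what such supervised pairs can look like; the field names and the chat-style template below are assumptions for illustration, not the course's exact format:

```python
# Assumed shape of an SFT dataset: each sample has an input (prompt) and the
# desired output, and the model is trained with ordinary next-token prediction
# on the formatted prompt + response.
sft_dataset = [
    {
        "input": "Give me a 1-sentence introduction of LLM.",
        "output": "An LLM is a neural network trained on large text corpora to predict the next token, which lets it generate and understand language.",
    },
    {
        "input": "Calculate 1+1-1",
        "output": "1+1-1 = 1",
    },
]

def format_sample(sample):
    # Hypothetical chat-style template; the exact template depends on the model family.
    return f"<|user|>\n{sample['input']}\n<|assistant|>\n{sample['output']}"

for sample in sft_dataset:
    print(format_sample(sample))
```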
PPO
The value and policy models both get trained at the same time.
The PPO model looks like this :
This is the policy model: it is the same as the GPT model, with a linear head at the end that produces a 1-d output telling how good or bad the current state is; that value head is just a single linear layer.
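A rough sketch of that architecture, assuming a Hugging Face-style causal LM backbone whose last hidden states feed a single linear value head producing one scalar per token:

```python
import torch.nn as nn

class CausalLMWithValueHead(nn.Module):
    """Assumed structure: a GPT-style backbone (the policy) plus a 1-d linear value head."""

    def __init__(self, backbone, hidden_size):
        super().__init__()
        self.backbone = backbone                      # the pretrained GPT model (policy)
        self.value_head = nn.Linear(hidden_size, 1)   # outputs one scalar per token

    def forward(self, input_ids, attention_mask=None):
        out = self.backbone(input_ids, attention_mask=attention_mask,
                            output_hidden_states=True)
        hidden = out.hidden_states[-1]                # (batch, seq_len, hidden_size)
        logits = out.logits                           # policy distribution over next tokens
        values = self.value_head(hidden).squeeze(-1)  # how good / bad each state is
        return logits, values
```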
The walkthrough for the whole setup is like this: the KL divergence is there to ensure the model doesn't drift too far in its probabilities and start giving irrelevant outputs that are not at all what the user asked for !!
So to regularise the probabilities, we compute the KL divergence against a trained SFT model.
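A sketch of how that KL term can be folded into the PPO reward, using the common per-token approximation log π_policy − log π_SFT; the β weight and the tensors are made-up examples:

```python
import torch

def kl_regularised_reward(reward_score, logprobs_policy, logprobs_sft, beta=0.1):
    # Per-token KL estimate between the current policy and the frozen SFT reference:
    # log pi_policy(a_t|s_t) - log pi_sft(a_t|s_t). Large values = the policy has drifted.
    kl = (logprobs_policy - logprobs_sft).sum()
    # The KL term is subtracted from the reward so PPO is pulled back towards the SFT model.
    return reward_score - beta * kl

r = kl_regularised_reward(torch.tensor(1.5),
                          torch.tensor([-1.0, -0.7]),
                          torch.tensor([-1.1, -0.9]))
```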
GRPO
In GRPO, we ditch the value model; to get the update signal for the post-training parameters we sample multiple outputs from the model for the same prompt, rank them with the reward model, take the group mean, and use each output's deviation from that mean as the update signal / advantage.
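A minimal sketch of that group-relative advantage, assuming `rewards` holds the reward-model scores for several outputs sampled from the same prompt:

```python
import torch

def grpo_advantages(rewards, eps=1e-6):
    # rewards: reward-model scores for a group of outputs sampled from the same prompt.
    # No value network: each output's advantage is its reward normalised against the
    # group mean (and std), so outputs better than the group average get a positive signal.
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    return (rewards - rewards.mean()) / (rewards.std() + eps)

print(grpo_advantages([0.2, 0.9, 0.4, 0.7]))   # the best samples get positive advantages
```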
Post Training (RLHF) methods in LLM Training
Course : https://learn.deeplearning.ai/courses/post-training-of-llms/lesson/ynmgf/introduction-to-post-training
The output from the base model is not at all coherent; it doesn't produce relevant information because it doesn't understand, and has never seen, the chat-template format.
=== Base Model (Before SFT) Output ===
Model Input 1:
Give me an 1-sentence introduction of LLM.
Model Output 1:
⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ ⚙ �
Model Input 2:
Calculate 1+1-1
Model Output 2:
⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ �
Model Input 3:
What's the difference between thread and process?
Model Output 3:
⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ �
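For reference, a sketch of the chat-template formatting such a base model has never seen; `apply_chat_template` is the standard Hugging Face tokenizer helper, and the checkpoint name is just an assumed example:

```python
from transformers import AutoTokenizer

# Assumed example checkpoint; any chat-tuned tokenizer with a chat template works the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [{"role": "user", "content": "Give me a 1-sentence introduction of LLM."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)   # shows the special role tokens a base model was never trained on
```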
SFT can't be ignored in any case; it is very important, since it is what teaches the model what to output, how to output it, and in which template to output it, etc.
The rest of the optimization comes as an additional improvement step, leading to better model outputs!
DPO : good for identity, multilingual, instruction following and safety
It improves model capabilities because of its contrastive nature.
DPO dataset curation starts from the SFT template: each prompt keeps the SFT chat format and is paired with a chosen and a rejected response, and the loss is then computed over these pairs.
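A compact sketch of the contrastive DPO loss over such (chosen, rejected) pairs, assuming the log-probabilities have already been summed over the response tokens for both the policy and the frozen reference (SFT) model:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # Contrastive objective: push the policy to prefer the chosen response over the
    # rejected one more strongly than the reference (SFT) model does.
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()

# toy usage with made-up summed log-probabilities
loss = dpo_loss(torch.tensor([-12.0]), torch.tensor([-15.0]),
                torch.tensor([-13.0]), torch.tensor([-14.0]))
```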
So the overall pipeline is:
Fine-tuning : SFT
Optimization : DPO
Online RL : PPO and GRPO techniques