[Paper Review] Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes
“Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes”, ACL 2023. All figures are taken from the paper.
TL;DR
- Distilling step-by-step is a training mechanism that extracts rationales from LLMs and uses them as additional supervision when training small task-specific models.
- The resulting small models can be deployed without the LLM.
- They can outperform the original LLM with less training data and a smaller model size.
Knowledge Distillation
Knowledge distillation (KD) is an ML technique that transfers knowledge from a larger model (the teacher model) to a smaller model (the student model), thus improving the performance of the smaller model. How? The teacher model generates noisy labels, also called pseudo labels, and the student model is trained to predict them. This enables training without a large human-labeled dataset, which is expensive to collect. In NLP, there is a natural choice of teacher model: the LLM!
LLMs can do more than generate answers: with Chain-of-Thought (CoT) prompting, they also generate rationales that explain those answers. Smaller models can then learn from both the teacher’s answers and its rationales. KD has become increasingly popular in NLP, which is why I decided to post a KD paper.
Distilling step-by-step
As seen below, distilling step-by-step is a straightforward training mechanism.
Step 1. On the teacher side, CoT prompting is used to generate both answers and rationales.
Step 2. On the student side, multi-task learning is applied to minimize both the label loss and the rationale loss.
The figure below describes step 1. The LLM is prompted with few-shot CoT examples to generate a rationale and an answer for a given question. Step 1 is clear, so let’s take a closer look at step 2.
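Before moving on, here is a minimal sketch of what step 1 might look like in code. The prompt layout (`Q:` / `Rationale:` / `Answer:`), the helper names `build_cot_prompt` and `parse_completion`, and the example texts are all assumptions made for illustration, not the paper’s actual prompts. A hard-coded string stands in for the LLM call.

```python
from typing import List, Tuple

# Hypothetical few-shot exemplar: (question, rationale, answer).
Exemplar = Tuple[str, str, str]


def build_cot_prompt(exemplars: List[Exemplar], question: str) -> str:
    """Concatenate worked exemplars, then the new question with an open rationale slot."""
    parts = [f"Q: {q}\nRationale: {r}\nAnswer: {a}" for q, r, a in exemplars]
    parts.append(f"Q: {question}\nRationale:")
    return "\n\n".join(parts)


def parse_completion(completion: str) -> Tuple[str, str]:
    """Split a completion into (rationale, answer) at the assumed 'Answer:' marker."""
    rationale, marker, answer = completion.partition("Answer:")
    if not marker:
        raise ValueError("completion has no 'Answer:' marker")
    return rationale.strip(), answer.strip()


# Made-up exemplar and question, for illustration only.
exemplars = [(
    "Where would you most likely find a pillow? (a) bedroom (b) garage",
    "Pillows are used for sleeping, and people sleep in bedrooms.",
    "(a) bedroom",
)]
prompt = build_cot_prompt(exemplars, "Where do jellyfish live? (a) ocean (b) store")

# Stand-in for the teacher LLM's output; a real setup would send `prompt` to the LLM.
fake_completion = " Jellyfish are sea animals, so they live in the ocean.\nAnswer: (a) ocean"
rationale, answer = parse_completion(fake_completion)
print(rationale)
print(answer)
```

The parsed `rationale` and `answer` play the roles of $\hat{r_i}$ and $\hat{y_i}$ in the losses discussed next.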
In general, the student model $\mathcal{f}$ is trained to minimize the prediction loss on the pseudo labels:
\[\mathcal{L_{label}} = \frac{1}{N} \sum\limits_{i=1}^{N}\ell(\mathcal{f}(\mathcal{x_i}), \hat{y_i}),\]
where $\ell$ represents the cross-entropy between the predicted and target tokens, $\hat{y_i}$ represents the pseudo label generated by the teacher model for input $\mathcal{x_i}$, and $N$ is the number of training examples.
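As a concrete illustration of $\mathcal{L_{label}}$, the sketch below computes it for a toy batch in plain Python. It assumes the student’s output is already available as per-token probability distributions and that $\ell$ averages over the target tokens. The helper names and all numbers are made up for illustration.

```python
import math
from typing import List


def token_cross_entropy(probs: List[List[float]], target_ids: List[int]) -> float:
    """Mean negative log-likelihood of the target tokens (the role of ell)."""
    if len(probs) != len(target_ids):
        raise ValueError("need one predicted distribution per target token")
    nll = [-math.log(dist[t]) for dist, t in zip(probs, target_ids)]
    return sum(nll) / len(nll)


def label_loss(batch_probs: List[List[List[float]]], batch_targets: List[List[int]]) -> float:
    """L_label: average of the per-example cross-entropy over the N examples."""
    losses = [token_cross_entropy(p, t) for p, t in zip(batch_probs, batch_targets)]
    return sum(losses) / len(losses)


# Toy batch with a 3-token vocabulary (made-up numbers).
# Example 1 has a one-token pseudo label, example 2 a two-token pseudo label.
batch_probs = [
    [[0.5, 0.25, 0.25]],
    [[0.25, 0.5, 0.25], [0.25, 0.25, 0.5]],
]
batch_targets = [[0], [1, 2]]

loss = label_loss(batch_probs, batch_targets)
print(loss)
```

Every target token here receives probability 0.5, so each per-example loss, and therefore the batch average, equals $\ln 2$.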
Although the discussion so far has focused on knowledge distillation, distilling step-by-step can also be applied in a standard finetuning setting: when human-annotated labels are available, the pseudo label $\hat{y_i}$ is simply replaced with the standard label $y_i$.
Through multi-task training, the student model $\mathcal{f}$ is trained to minimize $\mathcal{L}=\mathcal{L_{label}}+\lambda\mathcal{L_{rationale}}$, where $\mathcal{L_{label}}$ is the label prediction loss and $\mathcal{L_{rationale}}$ is the rationale generation loss:
\[\mathcal{L_{rationale}} = \frac{1}{N} \sum\limits_{i=1}^{N}\ell(\mathcal{f}(\mathcal{x_i}), \hat{r_i}),\]
where $\hat{r_i}$ is the rationale generated by the teacher model and $\lambda$ weights the rationale loss against the label loss. By learning the teacher’s rationales through $\mathcal{L_{rationale}}$, the student model learns the intermediate reasoning steps. This is why the paper’s title, ‘Distilling step-by-step’, seems fitting.
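The combined objective can be sketched as follows. This is a simplified illustration, not the authors’ implementation: `seq_loss` is a hypothetical callable standing in for the student’s cross-entropy on one (input, target) pair, and the `[label]` / `[rationale]` input prefixes are an assumed way of telling the student which output is being requested.

```python
from typing import Callable, List, Tuple

# One training example: (input x_i, label y_hat_i, rationale r_hat_i).
Example = Tuple[str, str, str]
# Hypothetical interface: seq_loss(source_text, target_text) -> cross-entropy.
SeqLoss = Callable[[str, str], float]


def multitask_loss(batch: List[Example], seq_loss: SeqLoss, lam: float = 1.0) -> float:
    """L = L_label + lam * L_rationale, each averaged over the N examples."""
    n = len(batch)
    # Assumed task prefixes distinguish the two outputs for the same input.
    l_label = sum(seq_loss("[label] " + x, y) for x, y, _ in batch) / n
    l_rationale = sum(seq_loss("[rationale] " + x, r) for x, _, r in batch) / n
    return l_label + lam * l_rationale


def toy_seq_loss(source: str, target: str) -> float:
    """Stand-in for a real model: returns a fixed loss per task (made-up values)."""
    return 0.5 if source.startswith("[label]") else 2.0


# Made-up examples, for illustration only.
batch = [
    ("Where do jellyfish live? (a) ocean (b) store",
     "(a) ocean",
     "Jellyfish are sea animals, so they live in the ocean."),
    ("Where would you most likely find a pillow? (a) bedroom (b) garage",
     "(a) bedroom",
     "Pillows are used for sleeping, and people sleep in bedrooms."),
]

total = multitask_loss(batch, toy_seq_loss, lam=0.5)
print(total)
```

With the toy losses above, $\mathcal{L_{label}} = 0.5$ and $\mathcal{L_{rationale}} = 2.0$, so $\lambda = 0.5$ gives a total of 1.5. In a real setup, `seq_loss` would run a forward pass of the student model, and the total would be backpropagated.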
Experiments
This is only a small portion of the experiments conducted in the paper. For the full experimental results, please refer to the original paper.
Settings:
- LLM: 540B PaLM model
- Downstream task models: T5 models
- Tasks: Natural language inference, commonsense question answering, and arithmetic math word problems
Outperforming LLMs using minimum model size and least training data
The experiment shows the training data size (x-axis) and model size (indicated by the size of the shaded area) needed to surpass the LLM (green dotted line).
The upper figure shows the results in the standard finetuning setting, i.e., the standard label $y_i$ is used instead of the pseudo label $\hat{y_i}$ in $\mathcal{L_{label}}$. The lower figure shows the results in the standard distillation setting, i.e., the pseudo label $\hat{y_i}$ is used in $\mathcal{L_{label}}$.
From the figures above, two observations stand out.
1. Distilling step-by-step consistently outperforms both standard finetuning and standard distillation across all datasets. In other words, learning the teacher’s rationales helps improve the performance of student models.
2. Distilling step-by-step is able to surpass the LLM using less training data in most cases. For instance, on the e-SNLI dataset in the standard finetuning setting (upper left), distilling step-by-step outperforms the LLM using only 0.1% of the dataset. On the SVAMP dataset in the standard distillation setting (lower right), it does not outperform the LLM. Nevertheless, it’s worth noting that distilling step-by-step achieves performance similar to the LLM with a smaller model.
Conclusion
- Distilling step-by-step is a training mechanism that extracts both rationales and answers from LLMs and utilizes them as supervision in KD.
- This approach reduces the training data and model size required to achieve performance equivalent to or better than that of LLMs.
- It enables deployment independent of LLMs.
References
[1] Hsieh, Cheng-Yu, et al. “Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes.” arXiv preprint arXiv:2305.02301 (2023).