Knowledge distillation: a way to make a large model more efficient and accessible
With the development of machine learning, more and more deep learning models are appearing whose parameter counts exceed several hundred billion or even reach into the trillions. Some of the largest models right now include M6-10T, which has 10 trillion parameters; GLaM (Generalist Language Model) by Google, which has 1.2 trillion parameters; and GPT-4, which, according to some sources, has 1.76 trillion parameters.
Training and deploying such models requires considerable computational power, often necessitating specialized hardware and substantial amounts of electricity. In addition, such models demand large amounts of memory to store and process their parameters, which makes them impractical for deployment in resource-limited environments.
Given the challenges, knowledge distillation of large language models presents a promising solution that makes the benefits of large-scale models more accessible and practical. By enabling smaller and sometimes even more efficient models that retain much of the performance of their larger counterparts, knowledge distillation helps bridge the gap between groundbreaking AI research and real-world applications. In this article, we discuss knowledge distillation training schemes and their use.
What is knowledge distillation in deep learning?
Knowledge distillation, first introduced in the paper Distilling the Knowledge in a Neural Network by Geoffrey Hinton and his colleagues, is a technique, applied primarily to neural network-based machine learning models, for transferring knowledge from a teacher model to a student model.
A teacher network is a powerful deep-learning model or ensemble of models skilled at capturing complex features and making accurate predictions. Such teacher models are used to train simpler student networks, which can be trained to produce results comparable to those of the teacher network.
Aside from making accurate predictions, during the learning process, a complex neural network is taught to generate meaningful and helpful representations of data, which can be called knowledge. This notion is at the core of knowledge distillation because the predictions reproduced by the student network through a distillation process are based on these thorough representations of data, or knowledge, stored by the teacher network in its hidden layers.
Knowledge distillation is necessary to carry out model compression, since the initial pre-trained teacher model contains many parameters that require considerable resources to deploy, train, and utilize. Model compression through knowledge distillation helps decrease the memory and computational requirements of a large model without substantially reducing its overall performance.
In large language models (LLMs), knowledge distillation has facilitated the transfer of more abstract characteristics, such as the model's style, reasoning capabilities, and alignment with human preferences and values. Knowledge distillation techniques go beyond simply copying the outputs of teacher models; they strive to mimic the underlying "thought processes" of these models.
Knowledge distillation represents one of the most effective ways to reduce the size of a model and speed up its inference. As a result of model compression, an intricate and large deep neural network is condensed into a smaller and simpler one, while preserving the accuracy and performance of the initial model.
A scaled-down model, trained to replicate the behavior of a heavy and accurate teacher model, achieves results similar to the teacher's while benefiting significantly in size and speed thanks to its simplified architecture. In most cases the student's accuracy is slightly lower than the teacher's, although in some cases the student can even outperform the teacher.
Teacher and student models, also known as the teacher-student framework, significantly contribute to the practical deployment of machine learning models by improving efficiency, performance, and generalization.
Step-by-step process of knowledge distillation
Training the teacher model
First, the teacher model learns from a dataset until it achieves high accuracy. Such a teacher model is typically a deep neural network that leverages its enormous capacity to learn complex patterns in the data. Due to its sophistication, it can capture intricate nuances in the data, resulting in accurate predictions.
Generating soft labels
Before predicting a hard target (also known as a label), which is the final output, deep neural networks generate intermediate predictions known as soft targets (soft labels). These serve as the principal data source for training the student model in knowledge distillation.
So, following the training phase, the teacher model is used to produce soft targets for the training data. Soft targets, unlike hard ones, provide rich information, including the relative probabilities of incorrect classes, which help the student model learn more about the underlying data structure. Hard targets indicate only the correct or highest probability class, whereas soft targets provide a distribution of probabilities across all classes.
Thanks to the level of detail provided in the soft targets, the student model requires less training data than the original teacher model, and a higher learning rate can be applied during student training than was used for the teacher.
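To make the difference between hard and soft targets concrete, here is a minimal sketch in PyTorch. The logits, the four-class setup, and the temperature value are illustrative assumptions, not taken from any particular model.

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for a 4-class problem (e.g., cat, dog, fox, car).
logits = torch.tensor([4.0, 2.5, 2.0, -1.0])

hard_target = logits.argmax()                  # tensor(0): only says "class 0 is correct"
soft_targets = F.softmax(logits / 3.0, dim=0)  # temperature T=3 softens the distribution

print(hard_target)   # tensor(0)
print(soft_targets)  # roughly [0.43, 0.26, 0.22, 0.08]: reveals that classes 1 and 2
                     # are far more plausible than class 3, information a hard label discards
```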
Training the student model
The student model is smaller and sometimes even more efficient than the teacher model. It has fewer parameters, which makes it suitable for implementation in low-resource contexts. The loss function is what the student model tries to minimize during training. It measures how well the student model is doing.
In knowledge distillation, two loss functions are applied to bring the student model's problem-solving process closer to that of a teacher model. One is a standard loss function, and the other is called distillation loss.
The standard one in the context of knowledge distillation is called hard loss, which compares the final result of the model with ground truth in supervised learning and with the original data in self-supervised learning. Distillation loss or soft loss compares the student model's soft labels with the teacher model's soft labels using Kullback-Leibler divergence.
Training the smaller model involves using the teacher's detailed predictions (soft labels) and the initial dataset to guide the student. The student model learns to minimize the difference between its own predictions and the soft labels produced by the teacher network. The student model goes through many iterations (epochs) of training, each time adjusting its parameters to better match the teacher's predictions and the actual labels.
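The sketch below shows what one such training step could look like in PyTorch, combining the hard loss (cross-entropy with ground truth) and the soft loss (KL divergence with the teacher's temperature-softened outputs). The temperature T, the weighting alpha, and the assumption that both models return raw logits are illustrative choices, not a definitive recipe.

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, x, labels, optimizer, T=4.0, alpha=0.5):
    """One training step combining hard loss (cross-entropy with ground-truth labels)
    and soft loss (KL divergence with the teacher's temperature-softened outputs)."""
    teacher.eval()
    with torch.no_grad():                       # the teacher only provides targets
        teacher_logits = teacher(x)
    student_logits = student(x)

    hard_loss = F.cross_entropy(student_logits, labels)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)                                 # rescaling used in Hinton et al.

    loss = alpha * hard_loss + (1 - alpha) * soft_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```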
Types of knowledge distillation
Response-based knowledge distillation
Response-based knowledge distillation is a technique that focuses on transferring knowledge from a teacher model to a student model by having the student model mimic the final output layer of the teacher model. This method leverages the rich information contained in the teacher's predictions to train a more efficient student model.
The student model learns to imitate the teacher model's predictions by minimizing a distillation loss. This ensures that the student captures the nuanced information present in the teacher's outputs. The goal of training is to minimize the loss, which means that during several epochs of training, the small student model will get better at generating the same results as the teacher.
Feature-based knowledge distillation
Feature-based knowledge distillation leverages the internal representations or features learned by a teacher model in its intermediate layers to train a student model. This approach goes beyond focusing solely on the output layer and instead utilizes the rich, hierarchical features captured in the intermediate layers of the teacher model. Intermediate layers of the teacher model extract features crucial for the model's final prediction. These features contain valuable information about the data's structure and patterns.
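A minimal sketch of the idea, assuming both networks expose an intermediate feature map: a learned 1x1 projection maps the student's features to the teacher's channel width, and a mean-squared-error "hint" loss pulls them together (in the spirit of FitNets). The channel sizes and the way features are captured are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistiller(nn.Module):
    """Matches an intermediate student feature map to the teacher's ('hint' loss)."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv projects student features to the teacher's channel width
        self.proj = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        return F.mse_loss(self.proj(student_feat), teacher_feat.detach())

# Hypothetical usage with features captured from intermediate layers of each network:
# feat_loss = FeatureDistiller(64, 256)(student_feat, teacher_feat)
# total_loss = hard_loss + beta * feat_loss
```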
Relation-based knowledge distillation
Relation-based knowledge distillation is a technique that captures and transfers the relationship knowledge between data samples and layers within the neural network. This method complements response-based and feature-based knowledge distillation by focusing on the interactions and relationships between different features, rather than the features themselves or the final outputs.
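One common way to express relational knowledge is to compare the pairwise distance structure of a batch of embeddings, as in relational knowledge distillation (RKD). The sketch below is a simplified assumption of that idea: the student is penalized when its sample-to-sample distances diverge from the teacher's.

```python
import torch
import torch.nn.functional as F

def pairwise_distance_matrix(embeddings):
    """Matrix of pairwise Euclidean distances, normalized by the mean distance."""
    d = torch.cdist(embeddings, embeddings, p=2)
    mean_d = d[d > 0].mean()
    return d / (mean_d + 1e-8)

def relation_distillation_loss(student_emb, teacher_emb):
    """Penalizes differences between the two models' sample-to-sample distance
    structures rather than the features or final outputs themselves."""
    return F.smooth_l1_loss(
        pairwise_distance_matrix(student_emb),
        pairwise_distance_matrix(teacher_emb.detach()),
    )
```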
Knowledge distillation techniques
The knowledge distillation approach makes it possible to train the student on unlabeled data, but to maximize quality it is necessary to combine the teacher's predictions with labeled data. According to the survey Knowledge Distillation: A Survey, the number of parameters in the student architecture should not be reduced to less than roughly half the number of teacher parameters.
Knowledge distillation involves transferring knowledge from a larger, well-trained model (teacher) to a smaller, more efficient model (student). There is a model capacity gap between such models that has to be minimized for the student model to generate the same accurate results as the teacher model.
Several schemes and algorithms are designed to facilitate this process, each with specific methods for improving the student model's performance by leveraging the teacher model's knowledge. There are three primary modes available for training student and teacher models:
Offline distillation;
Online distillation;
Self-distillation.
Offline knowledge distillation
Offline knowledge distillation refers to the traditional approach where the teacher model is trained first. Then the student model is trained separately using the soft labels generated by the teacher. Knowledge is distilled from a large pre-trained teacher model into a student, while the teacher itself remains unchanged with its weights frozen.
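The "frozen teacher" point can be made explicit in code. This is a small sketch building on the hypothetical distillation_step function from the training-step example above; only the student's parameters are passed to the optimizer, so the teacher never changes.

```python
import torch

def prepare_frozen_teacher(teacher):
    """In offline distillation the teacher is pre-trained and never updated:
    its weights are frozen and it runs in inference mode only."""
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher

# Hypothetical training loop: only the student is optimized.
# teacher = prepare_frozen_teacher(pretrained_teacher)
# optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
# for x, labels in loader:
#     distillation_step(student, teacher, x, labels, optimizer)
```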
Online knowledge distillation
There may be circumstances when a large pre-trained teacher is not available for a given task, or the teacher model is so huge that there is insufficient storage or processing capacity for it on the device or in the cloud. In such cases, training the teacher is either impossible outright or takes longer than training the desired student model.
To resolve this issue, online knowledge distillation may be used. It refers to a method where the teacher and student models are trained simultaneously, with the student learning from the teacher dynamically during training. Online distillation typically uses a teacher model of the same architecture as the student.
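A sketch of one flavor of online distillation, in the style of deep mutual learning: two networks of the same architecture train at the same time, and each treats the other's softened outputs as soft targets. The temperature and the equal weighting of the losses are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def mutual_learning_step(model_a, model_b, x, labels, opt_a, opt_b, T=2.0):
    """One step of online distillation: both networks are trained simultaneously
    and distill from each other's current outputs."""
    logits_a, logits_b = model_a(x), model_b(x)

    def kd(p_logits, q_logits):
        # the peer's softened output is used as a detached soft target
        return F.kl_div(F.log_softmax(p_logits / T, dim=1),
                        F.softmax(q_logits.detach() / T, dim=1),
                        reduction="batchmean") * T * T

    loss_a = F.cross_entropy(logits_a, labels) + kd(logits_a, logits_b)
    loss_b = F.cross_entropy(logits_b, labels) + kd(logits_b, logits_a)

    opt_a.zero_grad(); opt_b.zero_grad()
    (loss_a + loss_b).backward()   # each model receives gradients from its own loss only
    opt_a.step(); opt_b.step()
    return loss_a.item(), loss_b.item()
```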
Self-distillation
Previous knowledge distillation methods included two separate models, whereas self-distillation is a variant of knowledge distillation in which a single model acts as both the teacher and the student.
In self-distillation, knowledge is transferred from deeper layers of the network to shallow layers of the same network. Shallow classifiers are placed on different layers of the model and removed from that network after the model is trained and functional. Classifiers at deeper layers act as teacher models and train the shallow layers, which, in such cases, represent student models.
Self-distillation comprises two distinct groups of methods. Since the teacher and student are trained simultaneously, the first group of methods can be referred to as an extension of online knowledge distillation. In this type of self-distillation, the information that is accumulated in the model during training is used to further improve the quality of predictions of the same model. In other words, the model learns from its own predictions generated during earlier stages of training.
In the second group of methods, experts select a particular neural network architecture and train one model. Then, using knowledge distillation from the previously trained model, they train the same model. This group of self-distillation methods can also be called offline distillation when a new model of the same architecture is obtained from the previously trained model.
Since the models are identical in architecture, the teacher cannot provide the student with fundamentally new skills, and therefore the final quality of the model will not improve substantially. Hence, this method of self-distillation is applied mainly to study the knowledge transfer effect originally proposed by Hinton.
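A sketch of the second group of methods, sometimes called "born-again" training: a freshly re-initialized copy of the same architecture is trained with the previously trained model as its teacher. The distill_fn helper is hypothetical and stands in for any distillation loop, such as one built around the distillation_step sketch above.

```python
import copy
import torch

def born_again_round(trained_model, distill_fn):
    """Self-distillation with an identical architecture: the trained model becomes
    a frozen teacher, and a re-initialized copy of it is trained as the student."""
    teacher = copy.deepcopy(trained_model).eval()
    for p in teacher.parameters():
        p.requires_grad_(False)

    student = copy.deepcopy(trained_model)      # same architecture
    for m in student.modules():                 # re-initialize so the student starts from scratch
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    distill_fn(student, teacher)                # hypothetical training loop supplied by the caller
    return student
```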
Knowledge distillation applications
Knowledge distillation has a wide range of application areas. It can be utilized in natural language processing (NLP) for LLM compression. For example, DistilBERT is a lightweight, compressed counterpart of the larger BERT language model: it reduced the size of BERT by 40% while retaining 97% of its performance and running 60% faster.
DistilBERT was created to address several key challenges associated with large pre-trained language models like BERT. It particularly focuses on the need for efficiency in terms of computational resources and memory usage. The model architecture of DistilBERT is optimized to retain the most important aspects of BERT’s performance while reducing redundant parameters and layers.
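The size difference is easy to verify with the Hugging Face transformers library, using the publicly released checkpoints. The comparison below is a small illustration; the rounded parameter counts in the comments are approximate.

```python
# Comparing model sizes with the Hugging Face transformers library.
from transformers import AutoModel

bert = AutoModel.from_pretrained("bert-base-uncased")
distilbert = AutoModel.from_pretrained("distilbert-base-uncased")

count = lambda m: sum(p.numel() for p in m.parameters())
print(f"BERT-base:  {count(bert) / 1e6:.0f}M parameters")        # ~110M
print(f"DistilBERT: {count(distilbert) / 1e6:.0f}M parameters")  # ~66M, roughly 40% smaller
```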
Apart from NLP, knowledge distillation is employed in the computer vision domain. For example, the paper Cogni-Net: Cognitive Feature Learning Through Deep Visual Perception describes the CogniNet model, which employs knowledge distillation in the field of brain signal classification. In this approach, a BiLSTM (Bidirectional Long Short-Term Memory) model, which is adept at handling sequential data, is trained to classify brain EEG (electroencephalogram) signals. This training is enhanced through knowledge distillation, where a pre-trained deep convolutional vision network acts as the teacher model.
The distillation process helps the BiLSTM model achieve better accuracy in classifying brain signals. Research demonstrated that leveraging pre-trained vision models to distill knowledge into a recurrent model can effectively bridge the gap between visual perception and brain signal interpretation. The findings highlight the potential of knowledge distillation in enhancing inter-modal learning and understanding complex cognitive processes.
Knowledge distillation algorithms
Several algorithms have been proposed to facilitate the process of knowledge distillation, each with its unique approach and use cases. The following are some prominent algorithms for effective distillation.
Attention-based knowledge distillation
Knowledge distillation involves transferring knowledge from a large, complex teacher model to a smaller, simpler student model. Attention mechanisms can enhance this process by helping the student model learn not just from the teacher model's final outputs but also from the intermediate representations where attention is applied.
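One widely used formulation is attention transfer, where a spatial attention map is derived from each network's activations and the student is pushed to attend to the same regions as the teacher. The sketch below assumes convolutional feature maps of matching spatial size; it is an illustration of the idea rather than a specific published implementation.

```python
import torch
import torch.nn.functional as F

def attention_map(feature_map):
    """Spatial attention map: channel-wise mean of squared activations,
    flattened and L2-normalized."""
    a = feature_map.pow(2).mean(dim=1)           # (N, H, W)
    return F.normalize(a.flatten(1), dim=1)      # (N, H*W)

def attention_transfer_loss(student_feat, teacher_feat):
    """Encourages the student to 'look at' the same spatial regions as the teacher."""
    return (attention_map(student_feat) - attention_map(teacher_feat.detach())).pow(2).mean()
```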
Adversarial learning
Adversarial learning enables teacher and student models to capture the underlying data distribution better. Generative Adversarial Networks (GANs) can be used to augment the training dataset by generating additional realistic data samples. This augmented data can help the student model learn better by providing more diverse training examples, thus improving generalization and performance.
Here’s how adversarial learning and GANs come into play in knowledge distillation. The student model that acts as a generator in GAN tries to create answers (outputs) that look just like the teacher’s answers. A discriminator checks how similar the student’s answers are to the teacher’s. The student gets better by fooling the discriminator into thinking its answers are from the teacher. By using adversarial learning, the student gets better at copying the teacher, making its predictions more accurate over time.
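A simplified sketch of this setup: a small discriminator receives logits and learns to label teacher outputs as "real" and student outputs as "fake", while the student is trained on the usual distillation loss plus a term for fooling the discriminator. The discriminator architecture, loss weighting, and update schedule are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def adversarial_distillation_step(student, teacher, discriminator, x,
                                  opt_student, opt_disc):
    """The discriminator learns to tell teacher logits from student logits;
    the student is rewarded for outputs the discriminator accepts as 'teacher'."""
    with torch.no_grad():
        t_logits = teacher(x)
    s_logits = student(x)

    # 1) Update the discriminator: teacher outputs are 'real', student outputs 'fake'.
    d_real = discriminator(t_logits)
    d_fake = discriminator(s_logits.detach())
    d_loss = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) + \
             F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    opt_disc.zero_grad(); d_loss.backward(); opt_disc.step()

    # 2) Update the student: fool the discriminator, plus the usual distillation loss.
    d_student = discriminator(s_logits)
    g_loss = F.binary_cross_entropy_with_logits(d_student, torch.ones_like(d_student))
    kd_loss = F.kl_div(F.log_softmax(s_logits, dim=1),
                       F.softmax(t_logits, dim=1), reduction="batchmean")
    s_loss = kd_loss + 0.1 * g_loss
    opt_student.zero_grad(); s_loss.backward(); opt_student.step()
    return d_loss.item(), s_loss.item()
```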
Data-free knowledge distillation
Knowledge distillation typically involves training a student model using the same data used to train the teacher model. However, this data might not always be available. Being a complex network, the teacher model often requires a vast amount of data for training, which might be impractical to share, or the data used to train the teacher model (such as patient records) cannot be shared due to privacy concerns.
To address these issues, data-free knowledge distillation techniques have been developed. These methods generate synthetic data that mimics the data distribution of the original training samples used for the teacher model. One such method is the DeepInversion for Object Detection (DIODE) framework, which diverges from traditional GAN-based image synthesis methods. Unlike GANs, which use a generator and discriminator to create realistic images, DeepInversion directly optimizes the images to match the internal feature statistics of the teacher model.
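The core mechanism can be sketched as follows: random noise images are optimized so that the frozen teacher's batch-normalization statistics and predictions look as they would on real data. This is a simplified, classification-style illustration of the DeepInversion idea, not the DIODE framework itself; the loss weighting, step count, and image shape are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def synthesize_batch(teacher, num_images, num_classes, image_shape=(3, 224, 224),
                     steps=2000, lr=0.05):
    """DeepInversion-style synthesis sketch: optimize noise images so the teacher's
    BatchNorm statistics and predictions resemble those seen on real data."""
    teacher.eval()
    images = torch.randn(num_images, *image_shape, requires_grad=True)
    targets = torch.randint(0, num_classes, (num_images,))
    optimizer = torch.optim.Adam([images], lr=lr)

    # Forward hooks collect a statistics-matching penalty at every BatchNorm layer.
    bn_losses = []
    def bn_hook(module, inputs, output):
        x = inputs[0]
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)
        bn_losses.append(F.mse_loss(mean, module.running_mean) +
                         F.mse_loss(var, module.running_var))
    hooks = [m.register_forward_hook(bn_hook)
             for m in teacher.modules() if isinstance(m, nn.BatchNorm2d)]

    for _ in range(steps):
        bn_losses.clear()
        logits = teacher(images)
        loss = F.cross_entropy(logits, targets) + 0.1 * sum(bn_losses)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for h in hooks:
        h.remove()
    return images.detach(), targets
```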
Benefits of knowledge distillation
Knowledge distillation offers several key benefits in the context of machine learning and, more specifically, for complex deep learning models.
Reduced model size
Model compression is perhaps the fundamental challenge that prompted researchers to develop knowledge distillation in the first place, and knowledge distillation does indeed significantly reduce the size of a model.
Doing so makes it more feasible to deploy student models capable of delivering the same results as original teacher models on devices with limited storage and computational power.
Smaller models obtained through knowledge distillation are also capable of faster inference: they process data more quickly, leading to quicker response times in applications and real-time systems.
Resource efficiency
Training a student model using knowledge distillation requires less data, so it is less resource-intensive than training a large model from scratch. Smaller models also consume less power and memory, which is crucial for applications with limited memory availability. This all leads to reduced computational and energy consumption, i.e., overall lower training costs.
Knowledge distillation allows the creation of multiple models for various types of tasks without the need for extensive computational resources. Smaller student models can also be scaled more easily across multiple devices or cloud instances.
Improved performance
Distilled models retain much of the accuracy and performance of the larger models, even with significant reductions in size. The student model can capture essential knowledge from the teacher model, often leading to better generalization on unseen data than a comparable model trained from scratch.
Knowledge distillation is an effective way to make sophisticated models more accessible
Knowledge distillation is a powerful and versatile technique for transferring knowledge from complex, high-performing teacher models to smaller, more efficient student models. Leveraging various forms of knowledge—response-based, feature-based, and relation-based—knowledge distillation allows the student model to capture not only the final predictions but also the intricate feature representations and relationships within the data.
Knowledge distillation solves one of the most pressing challenges: deploying sophisticated models in resource-constrained environments. By transferring knowledge from large, complex models to smaller, more efficient ones, knowledge distillation retains high performance while significantly reducing computational requirements.
In other words, knowledge distillation is an indispensable tool for employing complex models in resource-constrained environments such as mobile devices, IoT systems, real-time applications, and even edge devices with little computing power.
It allows for democratizing advanced AI, making sophisticated technology accessible across diverse platforms and use cases. As the demand for AI continues to grow, the importance of techniques like knowledge distillation will only increase, enabling new applications that were previously constrained by computational limitations.
Article written by:
Toloka Team
Updated:
May 22, 2024