In recent years, many of the best-performing models in the field of natural-language processing (NLP) have been built on top of BERT language models. Pretrained on large corpora of (unlabeled) public texts, BERT models encode the probabilities of sequences of words. Because a BERT model begins with extensive knowledge of a language as a whole, it can be fine-tuned on a more targeted task — like question answering or machine translation — with relatively little labeled data.
BERT models, however, are very large, and BERT-based NLP models can be slow, even prohibitively so for users with limited computational resources. Their complexity also limits the length of the inputs they can take, since their memory footprint scales with the square of the input length.
At this year’s meeting of the Association for Computational Linguistics (ACL), my colleagues and I presented a new method, called Pyramid-BERT, that reduces the training time, inference time, and memory footprint of BERT-based models, without sacrificing much accuracy. The reduced memory footprint also enables BERT models to operate on longer text sequences.
BERT-based models take sequences of sentences as inputs and output vector representations, or embeddings, both of each sentence as a whole and of each of its constituent words. Downstream applications such as text classification and ranking, however, use only the complete-sentence embeddings. To make BERT-based models more efficient, we progressively eliminate redundant individual-word embeddings in intermediate layers of the network, while trying to minimize the effect on the complete-sentence embeddings.
We compare Pyramid-BERT to several state-of-the-art techniques for making BERT models more efficient and show that we can speed inference up 3- to 3.5-fold while suffering an accuracy drop of only 1.5%, whereas, at the same speeds, the best existing method loses 2.5% of its accuracy.
Moreover, when we apply our method to Performers — variations on BERT models that are specifically designed for long texts — we can reduce the models’ memory footprint by 70%, while actually increasing accuracy. At that compression rate, the best existing approach suffers an accuracy dropoff of 4%.
A token’s progress
Each sentence input to a BERT model is broken into units called tokens. Most tokens are words, but some are multiword phrases, some are subword parts, some are individual letters of acronyms, and so on. The start of each input sequence is marked by a special token called CLS, short for “classification,” for reasons that will soon be clear.
Each token passes through a series of encoders — usually somewhere between four and 12 — each of which produces a new embedding for each input token. Each encoder has an attention mechanism, which decides how much each token’s embedding should reflect information carried by other tokens.
For instance, given the sentence “Bob told his brother that he was starting to get on his nerves,” the attention mechanism should pay more attention to the word “Bob” when encoding the first “his” but more attention to “brother” when encoding “he.” Because the attention mechanism must compare every token in an input sequence with every other token, a BERT model’s memory footprint scales with the square of the input length.
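To make the quadratic cost concrete, here is a minimal sketch of single-head scaled dot-product attention in plain Python. It is a simplification, assuming the raw embeddings serve directly as queries, keys, and values (real BERT encoders apply learned projections first), and the function names are hypothetical. The key point is the weight matrix: one entry for every pair of tokens, so n tokens produce n × n scores.

```python
import math
from typing import List

Vector = List[float]


def attention_weights(embeddings: List[Vector]) -> List[List[float]]:
    """Return the n x n matrix of softmax-normalized attention weights.

    Simplification: queries and keys are the raw embeddings themselves,
    with no learned projection matrices.
    """
    d = len(embeddings[0])
    weights = []
    for q in embeddings:  # one row per token being encoded
        # Compare this token with every token in the sequence.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
                  for k in embeddings]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


def attend(embeddings: List[Vector]) -> List[Vector]:
    """Each new embedding is a weighted average of all input embeddings."""
    w = attention_weights(embeddings)
    n, d = len(embeddings), len(embeddings[0])
    return [[sum(w[i][j] * embeddings[j][c] for j in range(n)) for c in range(d)]
            for i in range(n)]


def score_matrix_entries(seq_len: int) -> int:
    """Attention scores per head per layer: quadratic in the sequence length."""
    return seq_len * seq_len
```

Doubling the input from 512 to 1,024 tokens quadruples the number of scores per head per layer, from 262,144 to 1,048,576.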
As tokens pass through the series of encoders, their embeddings factor in more and more information about other tokens in the sequence, since they’re attending to other tokens that are also factoring in more and more information. By the time the tokens pass through the final encoder, the embedding of the CLS token ends up representing the sentence as a whole (hence the CLS token’s name). But its embedding is also very similar to those of all the other tokens in the sentence. That’s the redundancy we’re trying to remove.
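One simple way to quantify the redundancy described above, offered here as an illustration rather than as the measure used in the paper, is the mean cosine similarity between the CLS embedding and every other token embedding at a given layer. Values close to 1 suggest the other tokens add little beyond what CLS already carries. The sketch assumes the CLS embedding sits at index 0 and that no embedding is the zero vector; the function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two nonzero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def mean_similarity_to_cls(layer: List[Sequence[float]]) -> float:
    """Average cosine similarity between CLS (index 0) and all other tokens."""
    if len(layer) < 2:
        raise ValueError("need the CLS embedding plus at least one other token")
    cls, others = layer[0], layer[1:]
    return sum(cosine(cls, t) for t in others) / len(others)
```

Tracking this value layer by layer would show whether token embeddings converge toward the CLS embedding as they move through the encoders.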
The basic idea is that, in each of the network’s encoders, we preserve the embedding of the CLS token but select a representative subset — a core set — of the other tokens’ embeddings.
Embeddings are vectors, so they can be interpreted as points in a multidimensional space. Ideally, to construct a core set, we would sort the embeddings into clusters of equal diameter and select the center point, or centroid, of each cluster.
Unfortunately, the problem of constructing an optimal core set for a layer of a neural network is NP-hard: no known algorithm solves it exactly in polynomial time, so computing it at every layer would be impractically time consuming.
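One standard way to formalize the ideal core set is as a k-center problem. Writing x_1, …, x_m for the token embeddings at a layer (with x_1 the CLS embedding), d for the distance between embeddings, and k for the target core-set size (notation introduced here for illustration), the optimal core set is:

```latex
S^{*} \;=\; \operatorname*{arg\,min}_{\substack{S \subseteq \{x_1, \dots, x_m\} \\ x_1 \in S,\; |S| = k}} \;\; \max_{1 \le i \le m} \; \min_{s \in S} \; d(x_i, s)
```

Each retained embedding acts as a cluster center, and the objective minimizes the largest distance from any embedding to its nearest center, which keeps every cluster's radius as small as possible. The k-center problem is a classic NP-hard problem, which is why a greedy approximation is used instead.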
As an alternative, our paper proposes a greedy algorithm that selects n members of the core set at a time. At each layer, we take the embedding of the CLS token, and then we find the n embeddings farthest from it in the representational space. We add those, along with the CLS embedding, to our core set. Then we find the n embeddings whose minimum distance from any of the points already in our core set is greatest, and we add those to the core set.
We repeat this process until our core set reaches the desired size. This is provably an adequate approximation of the optimal core set.
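The greedy procedure can be sketched in Python as follows. This is an illustrative reading of the description above, not the paper's code: it assumes Euclidean distance, takes the CLS embedding to be at index 0, and the function name greedy_core_set is hypothetical. In the first round, every embedding's distance to the core set is simply its distance to CLS, so the first batch is the n embeddings farthest from CLS, exactly as described.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def greedy_core_set(embeddings: List[Vector], k: int, n: int) -> List[int]:
    """Pick k token indices, always including the CLS token at index 0.

    Batched farthest-first traversal: starting from CLS, repeatedly add the
    n embeddings whose distance to the nearest already-selected embedding
    is largest, until k embeddings have been selected.
    """
    if not 1 <= k <= len(embeddings):
        raise ValueError("k must be between 1 and len(embeddings)")
    if n < 1:
        raise ValueError("n must be at least 1")

    selected = [0]  # the CLS embedding is always kept
    chosen = {0}
    # min_dist[i]: Euclidean distance from embedding i to its nearest selected embedding
    min_dist = [math.dist(e, embeddings[0]) for e in embeddings]

    while len(selected) < k:
        batch_size = min(n, k - len(selected))
        candidates = [i for i in range(len(embeddings)) if i not in chosen]
        # Farthest from the current core set first; ties keep the original token order.
        candidates.sort(key=lambda i: min_dist[i], reverse=True)
        batch = candidates[:batch_size]
        selected.extend(batch)
        chosen.update(batch)
        # Refresh each embedding's distance to its nearest selected embedding.
        for i, e in enumerate(embeddings):
            for c in batch:
                min_dist[i] = min(min_dist[i], math.dist(e, embeddings[c]))

    return sorted(selected)  # original token order, so CLS stays first
```

Returning the indices in their original order means the pruned sequence can be passed straight to the next encoder, with CLS still in the first position. Selecting n points per round rather than one trades a little selection quality for fewer passes over the embeddings.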
Finally, in our paper, we consider the question of how large the core set of each layer should be. We use an exponential-decay function to determine how much the core set shrinks from one layer to the next, and we investigate the trade-offs between accuracy, on the one hand, and speedup or memory reduction, on the other, that result from different rates of decay.
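A minimal sketch of such a schedule, assuming each layer keeps a fixed fraction (the decay rate) of the original sequence length raised to the layer index; the exact functional form used in the paper may differ, and the function names are hypothetical. The second function gives a rough proxy for the savings by counting attention-score entries only, ignoring all other compute.

```python
from typing import List


def core_set_sizes(seq_len: int, num_layers: int, decay: float,
                   min_size: int = 1) -> List[int]:
    """Hypothetical schedule: layer l keeps about seq_len * decay**l tokens."""
    if not 0 < decay <= 1:
        raise ValueError("decay must be in (0, 1]")
    return [max(min_size, round(seq_len * decay ** layer))
            for layer in range(num_layers)]


def relative_attention_cost(sizes: List[int], seq_len: int) -> float:
    """Attention-score entries summed over layers, relative to no pruning."""
    return sum(s * s for s in sizes) / (len(sizes) * seq_len * seq_len)
```

For example, with a 512-token input, four layers, and a decay of 0.5, the core sets hold 512, 256, 128, and 64 tokens, and the summed attention-score count falls to 85/256 (about 33%) of the unpruned total. A slower decay keeps more tokens, and so presumably more accuracy, at the cost of smaller savings.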
Acknowledgements: Ashish Khetan, Rene Bidart, Zohar Karnin