Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity

William Fedus, Barret Zoph, Noam Shazeer
Read: 27 September 2021

arXiv 2101.03961 cs.LG
2021
Note(s): mixture of experts, transformer model, machine learning, neural network, sparse model, NLP, google
Papers: shazeer:arxiv:2017

Observes that the mixture-of-experts approach suffers from complexity, high communication costs and training instability. Simplifies it by replacing the “top-k” routing used in mixture of experts (where the outputs of k experts are combined) with a switch that routes each token to just one expert. That is, it uses k=1. This preserves model quality, reduces routing computation and performs better.
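A minimal pure-Python sketch of top-k versus switch (k=1) routing. It assumes a linear router (one weight vector per expert) followed by a softmax over all experts, with the selected gate values p_i(x) used as-is (not renormalized over the chosen k). The expert functions and router weights are made-up toy values, and the function names are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]

def softmax(logits: Sequence[float]) -> Vector:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(x: Sequence[float], router_w: Sequence[Sequence[float]], k: int) -> List[Tuple[int, float]]:
    # One logit per expert: dot product of the token with that expert's router vector.
    logits = [sum(w_j * x_j for w_j, x_j in zip(w, x)) for w in router_w]
    probs = softmax(logits)
    # Keep the k most probable experts together with their gate values p_i(x).
    best = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(i, probs[i]) for i in best]

def moe_layer(x: Vector, router_w, experts, k: int) -> Vector:
    # y = sum over the selected experts of p_i(x) * E_i(x); only k experts are evaluated.
    y = [0.0] * len(x)
    for i, gate in route_top_k(x, router_w, k):
        out = experts[i](x)
        y = [acc + gate * o for acc, o in zip(y, out)]
    return y

def switch_layer(x: Vector, router_w, experts) -> Vector:
    # Switch routing is simply the k = 1 case: one expert per token.
    return moe_layer(x, router_w, experts, k=1)

# Hypothetical toy experts and router weights for illustration.
experts: List[Callable[[Vector], Vector]] = [
    lambda v: [2.0 * a for a in v],
    lambda v: [-a for a in v],
    lambda v: [a + 1.0 for a in v],
]
router_w = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]

print(moe_layer([1.0, 0.0], router_w, experts, k=2))  # evaluates two experts
print(switch_layer([1.0, 0.0], router_w, experts))    # evaluates one expert
```

The only difference between the two calls is how many experts are run per token, which is where the savings in routing computation come from.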

[Does the improvement come at the cost of needing more experts so that there is more redundancy/overlap between experts?]

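When implemented on a TPU, all tensor shapes must be determined at compile time and all matrix operations are dense, but the routing decisions are dynamic. The paper reconciles these by giving each expert a fixed expert capacity of (tokens per batch / number of experts) × capacity factor; tokens that overflow an expert are not processed by it and are passed on through the residual connection.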
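A minimal sketch of capacity-limited dispatch, assuming tokens have already been assigned to experts by a top-1 router. Each expert gets a fixed-shape buffer of `capacity` slots, so the shapes depend only on static sizes. Rounding the capacity up, the -1 "empty slot" marker and the function names are assumptions of this sketch, not details from the paper.

```python
import math
from typing import List, Tuple

def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    # Depends only on static sizes, so it can be fixed at compile time.
    # Rounding up is an assumption of this sketch.
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)

def dispatch(assignments: List[int], num_experts: int, capacity: int) -> Tuple[List[List[int]], List[int]]:
    # Fixed-shape buffers: num_experts x capacity slots, -1 marks an empty slot.
    buffers = [[-1] * capacity for _ in range(num_experts)]
    fill = [0] * num_experts
    dropped = []
    for token, expert in enumerate(assignments):
        if fill[expert] < capacity:
            buffers[expert][fill[expert]] = token
            fill[expert] += 1
        else:
            # Overflow: the expert is full, so this token skips the expert
            # and is carried forward unchanged by the residual connection.
            dropped.append(token)
    return buffers, dropped

assignments = [0, 0, 0, 1, 2, 2, 3, 1]  # hypothetical top-1 router choices for 8 tokens
cap = expert_capacity(len(assignments), num_experts=4, capacity_factor=1.0)
buffers, dropped = dispatch(assignments, 4, cap)
print(cap, buffers, dropped)  # 2 [[0, 1], [3, 7], [4, 5], [6, -1]] [2]
```

Raising the capacity factor (e.g. to 1.25, giving 3 slots here) drops fewer tokens at the cost of more padding and wasted computation in the fixed-shape buffers.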

An important tradeoff between precision and efficiency: the router computation is done in float32, locally on each device, but its results are cast back to bfloat16 before being sent between devices, so no expensive float32 tensors go through the all-to-all communication.
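To get a feel for what is lost when a value is cast to bfloat16 before being sent between devices, here is a minimal sketch that rounds a Python float (via its float32 bit pattern) to the nearest bfloat16 value. It assumes round-to-nearest-even and does not handle NaN or infinity; `to_bfloat16` is a hypothetical helper, not the TPU's conversion routine.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round x (first converted to float32) to the nearest bfloat16 value.

    bfloat16 keeps float32's sign bit and 8-bit exponent but only the top
    7 mantissa bits. Uses round-to-nearest-even; NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 low bits that are discarded.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in [1.0, 1.0 + 2**-7, 1.0 + 2**-10, 0.3]:
    print(value, "->", to_bfloat16(value))
```

With only 8 significant bits, small differences such as 1 + 2^-10 collapse to 1.0, and 0.3 comes back as 0.30078125; this is the kind of error that computing the router's softmax in float32 avoids.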

Switch transformers scale well (with respect to number of parameters) because you can add many experts without increasing the computational cost per token much: each token is still processed by only one expert, so the only extra computation from more experts is choosing which expert to use. The extra capacity (more parameters) means that a sparse switch-transformer model gives better results than a dense model with the same computational cost. However, each expert must be small enough to fit on a single device.
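A rough back-of-the-envelope sketch of why parameters grow with the number of experts while per-token compute barely moves. It assumes each expert is a plain two-matrix feed-forward block (biases ignored), a linear router with one weight vector per expert, and 2 FLOPs per multiply-add; the sizes `d_model=768`, `d_ff=3072` are hypothetical choices for illustration.

```python
from typing import Tuple

def ffn_params(d_model: int, d_ff: int) -> int:
    # Two weight matrices (d_model x d_ff and d_ff x d_model); biases ignored.
    return 2 * d_model * d_ff

def dense_layer_cost(d_model: int, d_ff: int) -> Tuple[int, int]:
    # (parameters, FLOPs per token), counting a multiply-add as 2 FLOPs.
    p = ffn_params(d_model, d_ff)
    return p, 2 * p

def switch_layer_cost(d_model: int, d_ff: int, num_experts: int) -> Tuple[int, int]:
    router_params = d_model * num_experts
    params = num_experts * ffn_params(d_model, d_ff) + router_params
    # Each token goes through the router and exactly one expert.
    flops_per_token = 2 * ffn_params(d_model, d_ff) + 2 * router_params
    return params, flops_per_token

d_model, d_ff = 768, 3072  # hypothetical layer sizes
print("dense", dense_layer_cost(d_model, d_ff))
for e in [8, 32, 128]:
    print(e, "experts", switch_layer_cost(d_model, d_ff, e))
```

Going from 8 to 128 experts multiplies the layer's parameters by about 16 while FLOPs per token rise by only about 2%, all of it from the router.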

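Figure 9 (page 16) gives a nice picture of five ways that model weights and data can be split over cores: data parallelism; model parallelism; model and data parallelism; expert and data parallelism; and expert, model and data parallelism. The impact of each on compute patterns, communication patterns and various capacities is then discussed.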
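A simplified accounting sketch of what one core holds under each of the five strategies, as a fraction of the layer's feed-forward/expert weights and of the global batch. It ignores attention layers, optimizer state and communication buffers, and assumes the expert strategies place one expert per data-parallel group (so the number of experts is `n_cores // model_parallel`). The strategy keys and the function name are hypothetical labels, not the paper's code.

```python
from fractions import Fraction
from typing import Tuple

def per_core_share(strategy: str, n_cores: int, model_parallel: int = 1) -> Tuple[Fraction, Fraction]:
    """Return (share of FFN/expert weights, share of the global batch) held by one core."""
    n, m = n_cores, model_parallel
    if n % m != 0:
        raise ValueError("n_cores must be a multiple of model_parallel")
    shares = {
        # Every core holds all the weights and a 1/n slice of the batch.
        "data": (Fraction(1), Fraction(1, n)),
        # Weights split n ways; every core sees every token.
        "model": (Fraction(1, n), Fraction(1)),
        # m-way model parallelism inside each of n/m data-parallel groups.
        "model+data": (Fraction(1, m), Fraction(m, n)),
        # One whole expert per core; tokens split n ways, then exchanged all-to-all.
        "expert+data": (Fraction(1, n), Fraction(1, n)),
        # n/m experts, each itself split m ways across a model-parallel group.
        "expert+model+data": (Fraction(1, n), Fraction(m, n)),
    }
    return shares[strategy]

for s in ["data", "model", "model+data", "expert+data", "expert+model+data"]:
    w, t = per_core_share(s, n_cores=8, model_parallel=2)
    print(f"{s:18} weights/core={w}  tokens/core={t}")
```

The expert strategies are the only ones where both the weights and the tokens per core shrink as cores are added, which is why total parameters can grow with the number of cores; the price is the all-to-all exchange of tokens between cores.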


  • GShard: Scaling giant models with conditional computation and automatic sharding [lepikhin:arxiv:2020]