Training Deep Nets with Sublinear Memory Cost, Tianqi Chen, Bing Xu, Chiyuan Zhang, Carlos Guestrin, 2016, arXiv preprint arXiv:1604.06174, DOI: 10.48550/arXiv.1604.06174 - This paper introduced gradient checkpointing (also known as activation checkpointing or re-materialization) to reduce memory consumption during deep neural network training.
jax.checkpoint, JAX developers, 2024 - Official documentation for the jax.checkpoint transformation (aliased as jax.remat), providing practical usage details and examples for memory optimization in JAX; a minimal usage sketch appears after this list.
The Reversible Residual Network: Backpropagation Without Storing Activations, Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse, 2017, Advances in Neural Information Processing Systems 30 (NIPS 2017), Curran Associates Inc. - This paper introduces Reversible Residual Networks, an architectural design that further reduces memory by allowing exact reconstruction of activations during the backward pass, extending the principles of recomputation.
ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He, 2020, SC20: International Conference for High Performance Computing, Networking, Storage and Analysis, IEEE, DOI: 10.1109/SC41405.2020.00024 - This paper presents ZeRO, a comprehensive set of memory optimization techniques for large-scale distributed model training, which integrates activation checkpointing within its framework.
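As a companion to the jax.checkpoint entry above, here is a minimal sketch of activation checkpointing in JAX. Only jax.checkpoint, jax.grad, and the jax.numpy / jax.random calls are actual library APIs; the toy layer, loss function, and tensor sizes are hypothetical and chosen purely for illustration.

```python
# Minimal sketch of activation checkpointing with jax.checkpoint (jax.remat).
# The tanh layer stack, loss, and sizes below are hypothetical examples.
import jax
import jax.numpy as jnp


def layer(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)


# Wrapping the layer with jax.checkpoint tells autodiff not to store its
# intermediate activations; they are recomputed during the backward pass.
checkpointed_layer = jax.checkpoint(layer)


def loss(all_params, x):
    for params in all_params:
        x = checkpointed_layer(params, x)
    return jnp.sum(x ** 2)


key = jax.random.PRNGKey(0)
dim = 128
all_params = [
    (jax.random.normal(jax.random.fold_in(key, i), (dim, dim)) / jnp.sqrt(dim),
     jnp.zeros((dim,)))
    for i in range(4)
]
x = jax.random.normal(key, (8, dim))

# Gradients are computed as usual; peak activation memory is reduced at the
# cost of recomputing each checkpointed layer's forward pass once more.
grads = jax.grad(loss)(all_params, x)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
```

In this sketch, each wrapped layer discards its intermediates after the forward pass and recomputes them when the gradient needs them, trading extra compute for a smaller activation footprint, which is the recomputation idea introduced by Chen et al. above.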