TIL ·

`optimize_for_inference` isn’t a function that usually comes up in beginner PyTorch materials. I just read about it in Aurélien Géron’s new edition of his hands-on ML book. This edition (considered a 1st edition) focuses on PyTorch in Part II. I loved all his previous editions because he really takes beginners on a deep dive into DL and generously sprinkles links to recent/important papers on each topic. In this case, he explains that `optimize_for_inference` performs, among other things, fusion of batch normalization layers with the preceding linear layers.
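A minimal sketch of how it’s invoked (assuming a scripted module in eval mode; `torch.jit.optimize_for_inference` freezes the module and runs inference-oriented passes over it):

```python
import torch

# A tiny model with a Linear layer followed by BatchNorm -- a fusion candidate.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.BatchNorm1d(8),
    torch.nn.ReLU(),
).eval()  # eval mode is required before freezing/optimizing

scripted = torch.jit.script(model)
frozen = torch.jit.optimize_for_inference(scripted)

# The optimized module should be numerically equivalent to the original.
x = torch.randn(2, 4)
assert torch.allclose(model(x), frozen(x), atol=1e-5)
```

The exact set of optimization passes applied depends on the PyTorch version and the target backend, but the output should match the original model up to floating-point tolerance.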

Batch normalization learns $\gamma$ and $\beta$, which scale and shift the BN layer’s inputs after normalizing them (using $\mu$ and $\sigma$; at inference time these are the running statistics accumulated during training). So, if the previous layer computes $XW + b$, the BN layer computes $\gamma \odot (XW + b - \mu) / \sigma + \beta$, where $\odot$ denotes element-wise multiplication. The BN layer can therefore be eliminated by folding it into the linear layer: set $W' = \gamma \odot W / \sigma$ and $b' = \gamma \odot (b - \mu) / \sigma + \beta$.
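The folding above can be checked numerically. A small demo, with two caveats: PyTorch’s `Linear` stores its weight as `(out_features, in_features)` and computes $xW^\top + b$, so $\gamma/\sigma$ scales the weight row-wise; and in practice $\sigma = \sqrt{\text{running\_var} + \epsilon}$:

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(4, 8)
bn = torch.nn.BatchNorm1d(8).eval()  # eval mode: BN uses running statistics

# Give the BN layer non-trivial statistics so the demo isn't a no-op.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

gamma, beta, mu = bn.weight, bn.bias, bn.running_mean
sigma = torch.sqrt(bn.running_var + bn.eps)

# Fold BN into the linear layer: W' = gamma * W / sigma, b' = gamma * (b - mu) / sigma + beta
fused = torch.nn.Linear(4, 8)
with torch.no_grad():
    fused.weight.copy_(lin.weight * (gamma / sigma).unsqueeze(1))
    fused.bias.copy_(gamma * (lin.bias - mu) / sigma + beta)

x = torch.randn(2, 4)
assert torch.allclose(bn(lin(x)), fused(x), atol=1e-5)
```

The fused layer does one matrix multiply instead of a matrix multiply followed by a normalize-scale-shift, which is where the inference speedup comes from.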

It appears that fusing BN with convolutional layers has long been a well-established technique for speeding up inference, and it’s even possible to reduce memory requirements during training by recognizing ahead of time that fusion will occur.
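PyTorch ships a helper for the conv case: `torch.nn.utils.fusion.fuse_conv_bn_eval` folds an eval-mode BN layer into the preceding conv and returns a single fused `Conv2d`. A quick check that the fused module matches the original pair:

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 16, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(16).eval()

# Non-trivial running statistics, as they would be after training.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

fused = fuse_conv_bn_eval(conv, bn)  # a single Conv2d with folded weights and bias

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```

Both layers must be in eval mode, since the fold bakes in the running statistics, which would otherwise keep changing during training.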