Based on Small-scale proxy?

By Simo Ryu (twitter, github)

Special thanks to Lucas Beyer, Jeremy Bernstein, and Satoki Ishikawa for the feedback! If you have suggestions, reach out to me via the links above!

Goal

The goal here is to provide go-to values & scaling rules for large-scale training, to minimize the hparams you need to sweep! You want to search for the optimal hparams at small scale, with a smaller model, less data, and a smaller batch size. E.g., your small-scale proxy would have:

model_frac: width_small_scale / width_large_scale = 0.1

data_frac: data_small_scale / data_large_scale = 0.1

batch_scale: large_batch_size / small_batch_size = 16
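
To make these ratios concrete, here is a minimal Python sketch (the specific widths, token counts, and batch sizes below are made-up placeholders for illustration, not recommendations):

```python
# Hypothetical proxy / target pair, purely for illustration.
width_small, width_large = 512, 5120       # model width
data_small,  data_large  = 40e9, 400e9     # training tokens
batch_small, batch_large = 0.125e6, 2e6    # batch size in tokens

model_frac  = width_small / width_large    # 0.1
data_frac   = data_small / data_large      # 0.1
batch_scale = batch_large / batch_small    # 16.0
```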

And you would sweep LR, weight STD, etc… hoping the result transfers near-optimally to large scale! You would use the following scaling equations…

| | Small Scale Proxy | Large Scale | Reference |
| --- | --- | --- | --- |
| LR (assuming AdamW) | [2^-8 ~ …]: sweep | lr_small * (batch_scale**0.5) * model_frac * (data_frac ** 0.24) for (width × width) matrices. For others, typically lr_small * (batch_scale**0.5) * (data_frac ** 0.24), but it depends. Residual branches can be further scaled by a factor of 1/sqrt(L), but this is debatable. | https://openreview.net/pdf?id=0ksNeD1SJT, https://openreview.net/forum?id=goEdyJ_nVQI. On infinite-depth rescaling: https://arxiv.org/abs/2310.02244 |
| Initialization STD | [0.1, 0.01, …]: sweep | (1) muP: init_small * model_frac**0.5. (2) modula / agd: ortho(fan_in, fan_out) * sqrt(fan_out / fan_in). For readout layers (such as branch outputs and the final linear layer), 0-init is commonly used. | https://arxiv.org/abs/2203.03466, https://arxiv.org/abs/2304.05187, https://arxiv.org/abs/2405.14813. Jeremy suggests doing (2), which has no hparam for large matrices and is (at infinity) the same elementwise $\Theta(1/\sqrt{\text{fan\_in}})$ as (1) with init_small = 1.0/sqrt(fan_in_small) (read this). For branches, some works rescale by an extra 1/sqrt(L), but IME it doesn't matter much in practice (your depth doesn't grow by a factor of 16 in typical settings); again, empirical results are in the depth-muP paper. |
| LR scheduler | Recommended by Keller Jordan: trapezoidal learning rate schedule (10% cooldown). (Note: when using muP, warmup might not be needed.) | Trapezoidal learning rate schedule (10% cooldown) | https://arxiv.org/abs/2405.18392, https://arxiv.org/abs/2305.19268 |
| AdamW weight decay | 0.1 * model_frac / data_frac | 0.1 (commonly used at large scale) | https://arxiv.org/abs/2405.13698v1; 0.1 was pointed out to be good for quantization: https://arxiv.org/abs/2305.19268 |
| Adam epsilon | default, 1e-9? | 1e-12 ~ 1e-15? (base_epsilon * model_frac ** 1.5) | https://openreview.net/pdf?id=0ksNeD1SJT recommends a smaller epsilon value, as it turns out to be more critical than initially anticipated. |
| Data size | 40B | e.g., 7T (data_small / data_frac) | How should you set your dataset size? It depends on your compute budget and how you trade off training compute against inference budget. Read https://arxiv.org/abs/2401.00448 |
| Batch size | [1M, 2M, …]: variable, depending on hardware / software optimization level | batch_small * batch_scale | How should you set your batch size? The rule of thumb (IMO): as small as possible, under the assumption that you max out your MFU. Read https://arxiv.org/abs/1812.06162 and https://arxiv.org/abs/2001.08361 if you have time. batch_scale can tolerate a factor of 16. |
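
To tie the table together, here is a minimal sketch of the transfer rules as Python helpers (function names and the example values are mine, not from any library; the residual-branch 1/sqrt(L) option and the depth-dependent variants are left out):

```python
def transfer_lr(lr_small, batch_scale, model_frac, data_frac, hidden_matrix=True):
    """Transfer a swept small-scale AdamW LR to the large-scale run.

    hidden_matrix=True adds the extra model_frac factor used for
    (width x width) matrices; other parameters typically omit it.
    """
    lr = lr_small * (batch_scale ** 0.5) * (data_frac ** 0.24)
    if hidden_matrix:
        lr *= model_frac
    return lr


def transfer_init_std(init_small, model_frac):
    """muP-style option (1): shrink the swept init std as width grows."""
    return init_small * (model_frac ** 0.5)


def proxy_weight_decay(wd_large, model_frac, data_frac):
    """Weight decay stays 0.1 at large scale; the proxy value is rescaled instead."""
    return wd_large * model_frac / data_frac


def transfer_adam_eps(eps_base, model_frac):
    """Shrink Adam epsilon for the wider model (model_frac ** 1.5 rule)."""
    return eps_base * (model_frac ** 1.5)


# Example: LR swept to 2**-6 on the proxy, transferred to the large run.
lr_large = transfer_lr(2 ** -6, batch_scale=16, model_frac=0.1, data_frac=0.1)
```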


On LR as width & data size grow:

A large part of muP is about scaling down the learning rate as your model gets bigger. There are two complementary pictures of this.
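
Before getting to those pictures, here is one heavily simplified sketch of what "scaling down the LR for the wider model" can look like in practice: per-parameter AdamW groups where square (width × width) weights get the extra model_frac factor from the table, and everything else keeps the base LR. The "square 2-D weight ⇒ hidden matrix" heuristic and the toy model are my own shorthand, not part of muP or any library.

```python
import torch
import torch.nn as nn

def mup_style_param_groups(model: nn.Module, lr_small: float, model_frac: float):
    """Split parameters into (width x width) hidden matrices vs. everything else.

    Crude heuristic: square 2-D weights are treated as hidden matrices and get
    their LR shrunk by model_frac; embeddings, biases, norms, and rectangular
    readout layers keep the base LR (per the table's "it depends" caveat).
    """
    hidden, other = [], []
    for _, p in model.named_parameters():
        (hidden if p.ndim == 2 and p.shape[0] == p.shape[1] else other).append(p)
    return [
        {"params": hidden, "lr": lr_small * model_frac},
        {"params": other, "lr": lr_small},
    ]

# Hypothetical toy model, just to show the grouping.
model = nn.Sequential(nn.Embedding(1000, 512), nn.Linear(512, 512), nn.Linear(512, 1000))
optimizer = torch.optim.AdamW(
    mup_style_param_groups(model, lr_small=2 ** -6, model_frac=0.1),
    weight_decay=0.1,
)
```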