@cosminscn
Last active August 28, 2025 01:03
Distributed training notes - aug 27 25

A100 - spec 312 TFLOPS (BF16/FP16 dense, tensor cores)

40GB or 80GB HBM2/HBM2e; 40MB L2 cache
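A quick roofline sanity check on those specs. A minimal sketch, assuming ~2.0 TB/s HBM bandwidth (roughly the A100 80GB SXM figure; the 40GB part is lower):

```python
# Back-of-envelope roofline for an A100.
# Assumed numbers: 312 TFLOPS dense BF16/FP16, ~2.0 TB/s HBM bandwidth.
peak_flops = 312e12        # FLOP/s, BF16 tensor-core dense
hbm_bandwidth = 2.0e12     # bytes/s, approx. A100 80GB SXM

# Arithmetic intensity (FLOP/byte) a kernel needs to be compute-bound
break_even_intensity = peak_flops / hbm_bandwidth
print(f"break-even intensity: {break_even_intensity:.0f} FLOP/byte")  # 156
```

Anything doing fewer than ~156 FLOP per byte moved from HBM is bandwidth-bound on this part.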

Large model run

  • DeepSpeed ZeRO-3 - should read the DeepSpeed/ZeRO paper; looks like their baseline was model parallelism with batch size 2?

  • bloom blog https://huggingface.co/blog/bloom-megatron-deepspeed

  • activations

    • 12 (input / proj / attention / nonlin) x hidden_dim x local_batch x seq_len x transformer_layers x 2 bytes (activation size, fp16)
  • params

    • transformer_layers x 12 x hidden_dim^2 (8 hidden_dim^2 for the 4x MLP up/down + 4 hidden_dim^2 for attention Q/K/V/O) x 2 bytes (fp16) or 4 (fp32)

activations_per_layer / params_per_layer == local_batch x seq_len / hidden_dim == 2 x 512 / 1024 == 1? seems high?
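The accounting above as a runnable sketch, in bytes, assuming fp16 (2 bytes/element) and the toy sizes from the note (local_batch=2, seq_len=512, hidden_dim=1024; layer count is arbitrary since it cancels in the ratio):

```python
# Activation vs. parameter memory per the formulas above, fp16 assumed.
def activation_bytes(hidden_dim, local_batch, seq_len, n_layers, elem_size=2):
    # 12 ~ input / proj / attention / nonlin activations kept per layer
    return 12 * hidden_dim * local_batch * seq_len * n_layers * elem_size

def param_bytes(hidden_dim, n_layers, elem_size=2):
    # 12 * hidden_dim^2 per layer: 8 h^2 for the 4x MLP, 4 h^2 for Q/K/V/O
    return n_layers * 12 * hidden_dim * hidden_dim * elem_size

h, b, s, L = 1024, 2, 512, 24
acts = activation_bytes(h, b, s, L)
params = param_bytes(h, L)
print(acts / params)  # == local_batch * seq_len / hidden_dim == 1.0
```

The layer count and the factor of 12 cancel, leaving exactly the b*s/h ratio in the note.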

for big models only hidden_dim gets larger, so the ratio should drop?

hmm do I need to compare per layer? or assume fsdp?

FSDP

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i
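The communication pattern above can be simulated in a single process. A minimal sketch, assuming 2 "ranks" that each own one row-shard of every layer's weights; all-gather is modeled as concatenating shards, reduce-scatter as summing gradients and re-splitting:

```python
# Single-process simulation of the FSDP per-layer comm pattern.
import numpy as np

n_ranks = 2
layers = [np.arange(8.0).reshape(2, 4) for _ in range(3)]       # full weights
shards = [np.array_split(w, n_ranks, axis=0) for w in layers]   # what ranks store

def all_gather(layer_shards):
    # Rebuild the full weight from the per-rank shards
    return np.concatenate(layer_shards, axis=0)

def reduce_scatter(per_rank_grads):
    # Reduce (sum) the per-rank gradients, then scatter shards back
    total = sum(per_rank_grads)
    return np.array_split(total, n_ranks, axis=0)

for layer_shards in shards:
    full_w = all_gather(layer_shards)                  # materialize layer
    grads = [np.ones_like(full_w) for _ in range(n_ranks)]  # fake backward
    grad_shards = reduce_scatter(grads)                # each rank keeps one shard
    del full_w                                         # "discard full weights"
```

Each rank only ever holds its own shard plus one transiently materialized layer, which is the memory win over plain DDP.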

simple fsdp implementation with torch compile https://github.com/facebookresearch/capi/blob/main/fsdp.py

JAX course

GPU Blog
