Vishal Goklani (vgoklani) · GitHub Gists
@vgoklani
vgoklani / nanochat_nans.py
Last active December 25, 2025 02:39
A very specific case that leads to NaNs on newer versions of PyTorch
import os
# Must be set before torch initializes the CUDA allocator, or it has no effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB = 65536  # vocabulary size used throughout the repro
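The preview ends at the setup; as a rough illustration (an assumption, not the gist's actual model code), this kind of repro typically drives one step and checks every result for non-finite values:

# Hypothetical harness, not part of the gist: one cross-entropy step plus finiteness checks
logits = torch.randn(8, VOCAB, device="cuda", dtype=torch.bfloat16, requires_grad=True)
targets = torch.randint(0, VOCAB, (8,), device="cuda")
loss = F.cross_entropy(logits.float(), targets)
loss.backward()
assert torch.isfinite(loss), f"loss went non-finite: {loss.item()}"
assert torch.isfinite(logits.grad).all(), "NaNs/Infs in gradients"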
@vgoklani
vgoklani / lr_scheduler.py
Created December 27, 2024 02:37 — forked from andrewnc/lr_scheduler.py
God's Chosen Schedule
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler
from dataclasses import dataclass
from typing import List
@dataclass
class SchedulePhase:
    """Defines a phase in the learning rate schedule"""
    percent: float  # Percentage of total steps this phase covers
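The preview stops at the first field; below is a minimal sketch (a hypothetical helper, not from the gist, assuming the percent fields of all phases sum to 1.0) of how such percent-based phases are typically located during training:

# Hypothetical helper: phases partition total_steps by percent,
# so finding the active phase is a running-sum walk
def active_phase(step: int, total_steps: int, phases: List[SchedulePhase]) -> SchedulePhase:
    boundary = 0.0
    for phase in phases:
        boundary += phase.percent * total_steps
        if step < boundary:
            return phase
    return phases[-1]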
import torch
import torch.nn.functional as F

def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into the representable
    # range of the float8 dtype (as the default cast is unsaturated)
    x_scl = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the float8 tensor with the inverse scale for dequantization
    return x_scl.to(dtype), scale.float().reciprocal()
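A quick round-trip check of the completion above (a usage sketch; the dequantization step assumes the function returns the inverse scale as written):

x = torch.randn(4, 8)
x_f8, inv_scale = to_float8(x)
x_rec = x_f8.to(torch.float32) * inv_scale  # approximate reconstruction
print((x - x_rec).abs().max())              # small quantization error expected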
@vgoklani
vgoklani / pipeline_parallel.py
Created October 2, 2024 10:45 — forked from 3outeille/pipeline_parallel.py
Self-contained example of how pipeline parallelism works (AFAB, all-forward-all-backward, and 1F1B, one-forward-one-backward) in 200 LOC
#VERBOSE=0 torchrun --nproc_per_node 3 self_contained_pp_LOC.py
import os, random, numpy as np, torch, torch.nn as nn, torch.distributed as dist, torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
STEP, local_rank, world_size, verbose = 0, int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]), os.environ.get("VERBOSE", "0") == "1"
def set_all_seed(seed):
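    # Plausible body (an assumption; the preview cuts off here): seed every RNG for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)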
@vgoklani
vgoklani / resnet_mlx.py
Created September 8, 2024 13:07 — forked from awni/resnet_mlx.py
MLX ResNet18 Inference Benchmark
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
import time
class Block(nn.Module):
    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
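        # Plausible continuation (an assumption; the preview cuts off here):
        # a standard ResNet basic block built from MLX layers
        self.conv1 = nn.Conv2d(in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(dims)
        self.conv2 = nn.Conv2d(dims, dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm(dims)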
@vgoklani
vgoklani / pypdfjs.py
Created May 5, 2024 14:39 — forked from mara004/pypdfjs.py
PDF rendering with pdf.js, from Python
# SPDX-FileCopyrightText: 2023 mara004
# SPDX-License-Identifier: CC-BY-4.0 OR Apache-2.0
# See also https://github.com/extremeheat/JSPyBridge/blob/master/examples/python/pdfjs.py
# Py-Depends: pillow, javascript >= 1.1.0 (jspybridge)
# Js-Depends: pdfjs-dist, canvas
# Use `python -m pip install` and `python -m javascript --install`
import argparse
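The preview ends at the imports; below is a minimal sketch of the Python-to-Node bridge the script relies on, assuming JSPyBridge's blocking promise resolution (the file name is hypothetical):

from javascript import require  # JSPyBridge, installed via `python -m pip install javascript`

pdfjs = require("pdfjs-dist")  # Node dependency, installed via `python -m javascript --install pdfjs-dist`
doc = pdfjs.getDocument("example.pdf").promise  # hypothetical file; the bridge resolves the JS promise
print(doc.numPages)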
@vgoklani
vgoklani / modeling_mixtral.py
Created May 5, 2024 13:31 — forked from kalomaze/modeling_mixtral.py
Fixed Mixtral training code for HF Transformers
# coding=utf-8
# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@vgoklani
vgoklani / torch_ddp_verify.py
Created April 17, 2024 22:21 — forked from jxmorris12/torch_ddp_verify.py
Verify parameter weights & gradients in PyTorch
import torch

# `gather` and `get_world_size` are distributed helpers defined elsewhere in the full gist
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    # Unwrap DistributedDataParallel to reach the underlying module
    if hasattr(model, "module"):
        model = model.module
    world_size = get_world_size()
    for name, param in model.named_parameters():
        # Collect this parameter from every rank: shape (world_size, numel)
        gathered_param = gather(param).reshape((world_size, -1))
        # Compare every rank's copy against rank 0's copy
        absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
        rank_params_eq = (absolute_diffs < atol).all()
        assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"
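For self-containment, here is a plausible sketch of the two helpers used above (an assumption; the full gist defines its own versions):

import torch.distributed as dist

def get_world_size() -> int:
    # Number of participating ranks (1 when torch.distributed is not initialized)
    return dist.get_world_size() if dist.is_initialized() else 1

def gather(t: torch.Tensor) -> torch.Tensor:
    # All-gather the tensor from every rank and stack along a new leading dimension
    tensors = [torch.empty_like(t) for _ in range(get_world_size())]
    dist.all_gather(tensors, t.contiguous())
    return torch.stack(tensors)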
from typing import Optional

import torch
import torch.nn as nn

@torch.no_grad()
def measure_time_device(
    model: nn.Module,
    dtype: Optional[torch.dtype] = torch.float32,
    num_repeats: Optional[int] = 100,
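    # ... remaining parameters are cut off in the preview
):
    # Plausible body (an assumption, not the gist's actual code): average the
    # forward-pass latency over num_repeats using CUDA events
    example_input = torch.randn(1, 3, 224, 224, device="cuda", dtype=dtype)  # hypothetical shape
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(num_repeats):
        model(example_input)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_repeats  # milliseconds per forward pass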
cimport cython
import numpy as np
cimport numpy as np
from sklearn.metrics import f1_score
@cython.boundscheck(False)
@cython.wraparound(False)
def f1_opt(np.ndarray[long, ndim=1] label, np.ndarray[double, ndim=1] preds):
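    # Plausible body (an assumption; the preview ends at the signature):
    # binarize the probability predictions at 0.5 and score with sklearn
    return f1_score(label, preds > 0.5)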