@HDCharles
Created February 13, 2026
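# --- Hypothetical sketch of the greedy_bin_packing helper used below. The
# real helper is project-internal and not included in this gist; the signature
# and the (items, bin_to_items, item_to_bin) return triple are inferred from
# the call site in compress_modules. This is a standard longest-processing-time
# greedy: place the heaviest remaining item onto the currently lightest bin.
import heapq


def greedy_bin_packing(items, num_bins, item_weight_fn):
    ordered = sorted(items, key=item_weight_fn, reverse=True)
    heap = [(0, b) for b in range(num_bins)]  # (total assigned weight, bin index)
    heapq.heapify(heap)
    bin_to_items = {b: [] for b in range(num_bins)}
    item_to_bin = {}
    for item in ordered:
        weight, b = heapq.heappop(heap)  # lightest bin so far
        bin_to_items[b].append(item)
        item_to_bin[item] = b
        heapq.heappush(heap, (weight + item_weight_fn(item), b))
    # ordered is identical on every rank, so collectives execute in lockstep
    return ordered, bin_to_items, item_to_bin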
from contextlib import contextmanager

import torch
import torch.distributed as dist
from loguru import logger

# Other helpers referenced below (quantize_weight, CompressionLogger,
# getattr_chain, align_module_device, get_execution_device,
# update_offload_parameter) are project imports whose paths are elided
# in this gist.


class GPTQModifier:  # ... (bases and other members elided)
    def compress_modules(self):
        """
        Quantize modules which have been calibrated.
        """
        ### Not distributed: compress everything on this process
        if not (dist.is_initialized() and dist.get_world_size() > 1):
            self.compress_module_list(list(self._num_samples.keys()))
            return

        ### Distributed
        ## Assign modules to ranks
        ## Accumulate each hessian on its assigned rank
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # load balancing: greedy bin packing, weighted by hessian size
        module_list, rank_to_modules, module_to_rank = greedy_bin_packing(
            list(self._hessians.keys()),
            world_size,
            item_weight_fn=lambda mod: self._hessians[mod].shape[0],
        )
        # send hessian info to the target rank for each module
        self._reduce_hessian_to_target_rank(module_list, module_to_rank)
        # each rank compresses all modules it was assigned
        self.compress_module_list(rank_to_modules[rank])
        ## broadcast quantized parameters back to the other ranks
        self._broadcast_quantized_params(module_list, module_to_rank)
    def compress_module_list(self, module_list):
        for module in module_list:
            name = self._module_names[module]
            quant_args = getattr_chain(module, "quantization_scheme.weights")
            logger.info(f"Quantizing {name}")
            with (
                torch.no_grad(),
                align_module_device(module),
                self._maybe_onload_hessian(module),
                CompressionLogger(module) as comp_logger,
            ):
                loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                    module=module,
                    quant_args=quant_args,
                    hessians_dict=self._hessians,
                    blocksize=self.block_size,
                    percdamp=self.dampening_frac,
                )
                # comp_logger.set_loss(loss)

            update_offload_parameter(module, "weight", quantized_weight)
            update_offload_parameter(module, "weight_scale", scale)
            update_offload_parameter(module, "weight_zero_point", zero_point)
            if g_idx is not None:
                update_offload_parameter(module, "weight_g_idx", g_idx)

            # self._hessians[module] already deleted by quantize_weight
            self._num_samples.pop(module, None)
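    # --- Hypothetical sketch of _maybe_onload_hessian, which is not shown in
    # the gist. Assumption: hessians may be offloaded (e.g. to CPU) between
    # calibration and compression, and should sit on the module's execution
    # device while the module is being processed.
    @contextmanager
    def _maybe_onload_hessian(self, module):
        H = self._hessians[module]
        original_device = H.device
        exec_device = get_execution_device(module)
        if original_device != exec_device:
            self._hessians[module] = H.to(exec_device)
        try:
            yield
        finally:
            # quantize_weight may have deleted the entry, hence the check
            if module in self._hessians and original_device != exec_device:
                self._hessians[module] = self._hessians[module].to(original_device)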
    def _reduce_hessian_to_target_rank(self, module_list, module_to_rank):
        def get_hessian_data(module):
            H = self._hessians[module]
            n = torch.tensor(
                [self._num_samples[module]], dtype=H.dtype, device=H.device
            )
            # de-average: the reduce below sums n_r * H_r over ranks r, and
            # store_data divides by the summed n to recover the weighted mean
            H = H * n
            return H, n

        def comm_fn(data, target_rank):
            H, n = data
            h_comm = dist.reduce(H, op=dist.ReduceOp.SUM, dst=target_rank, async_op=True)
            n_comm = dist.reduce(n, op=dist.ReduceOp.SUM, dst=target_rank, async_op=True)
            return [h_comm, n_comm]

        def store_data(module, data):
            H, n = data
            H /= n
            self._hessians[module] = H

        _dist_comms_impl(
            module_list,
            module_to_rank,
            get_data_fn=get_hessian_data,
            comm_fn=comm_fn,
            store_data_fn=store_data,
            should_store_data_fn=lambda target_rank: dist.get_rank() == target_rank,
            context_fn=self._maybe_onload_hessian,
        )
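    # Toy check of the weighted-average reduction above (single process):
    #   H1, n1 = torch.full((2, 2), 1.0), 10.0   # "rank 0": mean over 10 samples
    #   H2, n2 = torch.full((2, 2), 4.0), 30.0   # "rank 1": mean over 30 samples
    #   (H1 * n1 + H2 * n2) / (n1 + n2)          # -> all entries 3.25 = 130 / 40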
    def _broadcast_quantized_params(self, module_list, module_to_rank):
        def get_params(module):
            # fetch the parameters written back by compress_module_list;
            # the quantized values live under "weight"
            weight = getattr_chain(module, "weight", None)
            weight_scale = getattr_chain(module, "weight_scale", None)
            weight_zero_point = getattr_chain(module, "weight_zero_point", None)
            data = [weight, weight_scale, weight_zero_point]
            weight_g_idx = getattr_chain(module, "weight_g_idx", None)
            if weight_g_idx is not None:
                data.append(weight_g_idx)
            return data

        def comm_params(data, src_rank):
            pending_comms = []
            for datum in data:
                comm = dist.broadcast(datum, src=src_rank, async_op=True)
                pending_comms.append(comm)
            return pending_comms

        _dist_comms_impl(
            module_list,
            module_to_rank,
            get_data_fn=get_params,
            comm_fn=comm_params,
        )
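
# --- Hypothetical sketch of _dist_comms_impl, which is not included in the
# gist. Its shape is inferred from the two call sites above: walk the shared
# module order, build a payload, launch async collectives, wait on them, and
# optionally store the result. Parameter names and defaults are assumptions.
from contextlib import nullcontext


def _dist_comms_impl(
    module_list,
    module_to_rank,
    get_data_fn,
    comm_fn,
    store_data_fn=None,
    should_store_data_fn=None,
    context_fn=None,
):
    for module in module_list:  # identical order on every rank
        target_rank = module_to_rank[module]
        ctx = context_fn(module) if context_fn is not None else nullcontext()
        with ctx:
            data = get_data_fn(module)
            # launch the async collectives, then wait before reading buffers
            for comm in comm_fn(data, target_rank) or []:
                comm.wait()
            if store_data_fn is not None and (
                should_store_data_fn is None
                or should_store_data_fn(target_rank)
            ):
                store_data_fn(module, data)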