import torch
import torch.distributed as dist

# NOTE: these helper imports reflect where the utilities normally live in
# compressed-tensors; exact paths may differ in the branch this gist targets.
from compressed_tensors.utils import (
    align_module_device,
    get_execution_device,
    getattr_chain,
    update_offload_parameter,
)
from loguru import logger

# CompressionLogger, quantize_weight, greedy_bin_packing, and _dist_comms_impl
# are defined elsewhere in the surrounding GPTQ code and are not shown here.


class GPTQModifier ...

    def compress_modules(self):
        """
        Quantize modules which have been calibrated
        """
        ### Not Distributed: compress everything on this process and return
        if not (dist.is_initialized() and dist.get_world_size() > 1):
            self.compress_module_list(list(self._num_samples.keys()))
            return

        ### Distributed
        ## Assign modules to ranks, then accumulate each Hessian on its assigned rank
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # load balancing: greedy bin packing, weighted by Hessian size
        module_list, rank_to_modules, module_to_rank = greedy_bin_packing(
            list(self._hessians.keys()),
            world_size,
            item_weight_fn=lambda mod: self._hessians[mod].shape[0],
        )

        # send each module's Hessian contribution to its target rank
        self._reduce_hessian_to_target_rank(module_list, module_to_rank)

        # each rank compresses all modules it was assigned
        self.compress_module_list(rank_to_modules[rank])

        ## Broadcast quantized parameters to the other ranks
        self._broadcast_quantized_params(module_list, module_to_rank)
    def compress_module_list(self, module_list):
        for module in module_list:
            name = self._module_names[module]
            quant_args = getattr_chain(module, "quantization_scheme.weights")
            logger.info(f"Quantizing {name}")

            with (
                torch.no_grad(),
                align_module_device(module),
                self._maybe_onload_hessian(module),
                CompressionLogger(module) as comp_logger,
            ):
                loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                    module=module,
                    quant_args=quant_args,
                    hessians_dict=self._hessians,
                    blocksize=self.block_size,
                    percdamp=self.dampening_frac,
                )
                comp_logger.set_loss(loss)

            # write the quantized tensors back to the (possibly offloaded) module
            update_offload_parameter(module, "weight", quantized_weight)
            update_offload_parameter(module, "weight_scale", scale)
            update_offload_parameter(module, "weight_zero_point", zero_point)
            if g_idx is not None:
                update_offload_parameter(module, "weight_g_idx", g_idx)

            # self._hessians[module] already deleted by quantize_weight
            self._num_samples.pop(module, None)
    def _reduce_hessian_to_target_rank(self, module_list, module_to_rank):
        def get_hessian_data(module):
            device = get_execution_device(module_list[0]) if len(module_list) > 0 else "cpu"
            H = self._hessians[module]
            n = torch.Tensor([self._num_samples.get(module)]).to(device)
            # weight by sample count so the reduce produces a sample-weighted sum
            H = H * n
            return H, n

        def comm_fn(data, target_rank):
            H, n = data
            h_comm = dist.reduce(H, op=dist.ReduceOp.SUM, dst=target_rank, async_op=True)
            n_comm = dist.reduce(n, op=dist.ReduceOp.SUM, dst=target_rank, async_op=True)
            return [h_comm, n_comm]

        def store_data(module, data):
            H, n = data
            # renormalize by the global sample count on the receiving rank
            H /= n
            self._hessians[module] = H

        _dist_comms_impl(
            module_list,
            module_to_rank,
            get_data_fn=get_hessian_data,
            comm_fn=comm_fn,
            store_data_fn=store_data,
            should_store_data_fn=lambda target_rank: dist.get_rank() == target_rank,
            context_fn=self._maybe_onload_hessian,
        )
    def _broadcast_quantized_params(self, module_list, module_to_rank):
        def get_params(module):
            # these parameters were written by compress_module_list on the source rank
            weight = getattr_chain(module, "weight", None)
            weight_scale = getattr_chain(module, "weight_scale", None)
            weight_zero_point = getattr_chain(module, "weight_zero_point", None)
            data = [weight, weight_scale, weight_zero_point]

            weight_g_idx = getattr_chain(module, "weight_g_idx", None)
            if weight_g_idx is not None:
                data.append(weight_g_idx)
            return data

        def comm_params(data, src_rank):
            pending_comms = []
            for datum in data:
                comm = dist.broadcast(datum, src=src_rank, async_op=True)
                pending_comms.append(comm)
            return pending_comms

        _dist_comms_impl(
            module_list,
            module_to_rank,
            get_data_fn=get_params,
            comm_fn=comm_params,
        )
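

The gist calls greedy_bin_packing but does not include it. Below is a minimal sketch of what it could look like, inferred purely from the call site above: the argument name item_weight_fn and the three-value return are taken from that call, everything else is an assumption. The idea is longest-first greedy packing: sort modules heaviest-first and always drop the next one into the currently lightest rank, returning a deterministic ordering plus both directions of the assignment so every rank issues its collectives in the same order.

def greedy_bin_packing(items, num_bins, item_weight_fn):
    """Hypothetical sketch: assign items to num_bins, heaviest-first into the lightest bin."""
    bin_weights = [0.0] * num_bins
    item_to_bin = {}
    bin_to_items = {b: [] for b in range(num_bins)}

    # heaviest-first so large Hessians get spread across ranks early (LPT heuristic)
    ordered = sorted(items, key=item_weight_fn, reverse=True)
    for item in ordered:
        target = min(range(num_bins), key=lambda b: bin_weights[b])
        bin_weights[target] += item_weight_fn(item)
        item_to_bin[item] = target
        bin_to_items[target].append(item)

    # the same deterministic ordering is returned on every rank so collectives line up
    return ordered, bin_to_items, item_to_bin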
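_dist_comms_impl is likewise not shown. Here is a sketch of the pattern the two call sites imply, under the assumption that comm_fn returns async work handles to wait on and that store_data_fn only runs on ranks where should_store_data_fn approves. For the broadcast call site no store step is needed, since the collective writes into the parameter tensors in place. A real implementation would likely overlap communication across modules; this version keeps each module's transfer inside its onload context for simplicity.

from contextlib import nullcontext


def _dist_comms_impl(
    module_list,
    module_to_rank,
    get_data_fn,
    comm_fn,
    store_data_fn=None,
    should_store_data_fn=None,
    context_fn=None,
):
    """Hypothetical sketch: per-module async collectives, then wait and optionally store."""
    for module in module_list:
        target_rank = module_to_rank[module]
        ctx = context_fn(module) if context_fn is not None else nullcontext()
        with ctx:
            # fetch this module's tensors and launch its reduce/broadcast ops
            data = get_data_fn(module)
            pending = comm_fn(data, target_rank) or []
            for work in pending:
                work.wait()
            # hand the reduced result back only on the rank that owns this module
            if (
                store_data_fn is not None
                and should_store_data_fn is not None
                and should_store_data_fn(target_rank)
            ):
                store_data_fn(module, data)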