Created
December 8, 2025 14:16
-
-
Save Ademking/8a8cbf734ddca33714be88cfcd89b67b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def log_metrics(self, metrics: dict, prefix: str = "train") -> None:
    """
    Log average metrics to Weights & Biases (wandb).

    Args:
        metrics (dict): Dictionary of metric names (str) to lists of tensor/float values.
        prefix (str): Prefix for metric names (e.g., "train" or "eval").
    """
    try:
        # torch.as_tensor accepts both tensors and plain Python numbers,
        # honoring the docstring's "tensor/float" contract; .detach().cpu()
        # drops any autograd graph and moves values off-device before averaging.
        averaged_metrics = {
            f"{prefix}/{metric}": torch.stack(
                [torch.as_tensor(v).detach().cpu() for v in values]
            ).mean().item()
            for metric, values in metrics.items()
            if values  # skip empty lists: torch.stack([]) raises RuntimeError
        }
        if self.config.train_config.use_wandb:
            self.wandb_logger.log(averaged_metrics, step=self.epoch)
    except Exception as e:
        # Best-effort logging: a wandb/metric failure must never abort training.
        print(f"Error while logging to wandb: {e}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment