Created
December 18, 2025 09:24
-
-
Save naufalso/919d3183272f98842cd7d06ae510b9cf to your computer and use it in GitHub Desktop.
Script to Sync Hugging Face trainer_state.json to a new WandB Run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import wandb | |
| import argparse | |
| import os | |
| import sys | |
def _load_trainer_state(json_path):
    """Load and validate trainer_state.json; print an error and exit(1) on failure."""
    try:
        # Hugging Face writes trainer_state.json as UTF-8, so read it the same
        # way instead of relying on the platform's locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            state_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{json_path}' was not found.", file=sys.stderr)
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Failed to decode JSON from '{json_path}'.", file=sys.stderr)
        sys.exit(1)

    # Everything downstream does key lookups, so reject anything that is not a
    # JSON object *before* a WandB run gets created for it.
    if not isinstance(state_data, dict):
        print(
            f"Error: Expected a JSON object at the top level of '{json_path}'.",
            file=sys.stderr,
        )
        sys.exit(1)
    return state_data


def sync_trainer_state(json_path, project_name, run_name=None, entity=None):
    """
    Read a Hugging Face trainer_state.json file and replay it into a new WandB run.

    Args:
        json_path: Path to the trainer_state.json file.
        project_name: WandB project that receives the run.
        run_name: Optional display name for the run (WandB picks one if None).
        entity: Optional WandB entity (user or team) that owns the project.

    Exits the process with status 1 (after printing a message to stderr) when
    the file is missing, is not valid JSON, or is not a JSON object.
    """
    # 1. Load the JSON content
    state_data = _load_trainer_state(json_path)

    # Top-level fields that describe the final state of training.
    summary_keys = (
        "best_global_step",
        "best_metric",
        "best_model_checkpoint",
        "epoch",
        "global_step",
        "max_steps",
        "num_train_epochs",
        "total_flos",
        "trial_name",
        "trial_params",
    )

    # 2. Initialize WandB
    # reinit=True ensures a clean run if called multiple times in a notebook.
    # Using the run as a context manager guarantees it is finished even when
    # logging raises, instead of leaving a dangling open run.
    with wandb.init(
        project=project_name,
        name=run_name,
        entity=entity,
        reinit=True,
    ) as run:
        print(f"Started WandB run: {run.name}")

        # 3. Log Top-Level Summary/Metadata (shown at the top of the run page)
        for key in summary_keys:
            if key in state_data:
                run.summary[key] = state_data[key]

        # 4. Log History
        # 'log_history' holds the time-series data (loss, learning rate, etc.)
        if "log_history" in state_data:
            history = state_data["log_history"]
            print(f"Found {len(history)} log entries. Syncing...")
            for entry in history:
                # Copy so the loaded data is never mutated by the logging call.
                log_payload = entry.copy()
                # Use the entry's 'step' (when present) as the explicit x-axis.
                # The key is left in the payload too, so it is also recorded
                # as a regular metric. step=None lets WandB auto-assign a step.
                run.log(log_payload, step=log_payload.get("step"))
        else:
            print("Warning: No 'log_history' key found in the JSON file.")

    # 5. The context manager has finished the run at this point.
    print("Sync complete. You can view the run at the URL above.")
if __name__ == "__main__":
    # Command-line entry point: one positional path to the trainer state file
    # plus options naming the destination WandB project, run and entity.
    cli = argparse.ArgumentParser(
        description="Sync Hugging Face trainer_state.json to a new WandB Run"
    )

    # (flag, keyword arguments) pairs, registered in declaration order.
    argument_specs = [
        ("json_file", {"help": "Path to the trainer_state.json file"}),
        ("--project", {"required": True, "help": "Name of the WandB Project"}),
        ("--run_name", {"default": "imported-hf-run", "help": "Name for the WandB Run"}),
        ("--entity", {"default": None, "help": "WandB Entity (username or team name)"}),
    ]
    for flag, spec in argument_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    # Hand the parsed values straight to the sync routine.
    sync_trainer_state(
        options.json_file,
        options.project,
        run_name=options.run_name,
        entity=options.entity,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment