Last active
December 29, 2025 11:28
-
-
Save JGalego/e9adaac4f0f707a0f3c733dd84c22920 to your computer and use it in GitHub Desktop.
Fill sorries and prove theorems using Aristotle by Harmonic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "aristotlelib >= 0.6.0", | |
| # ] | |
| # /// | |
| r""" | |
| Fill sorries and prove theorems using Aristotle by Harmonic | |
| This script loads theorem examples from a JSON file and processes them | |
| using different task types: fill, prove, or direct. | |
| Getting Started: | |
| 1. Clone the datasets repository: | |
| git clone https://github.com/harmonic-ai/datasets | |
| 2. Run the tester with a dataset file: | |
| python aristotle_tester.py datasets/minif2f/validation.json \ | |
| --task prove \ | |
| --limit 5 | |
| Example usage: | |
| # Process 10 examples from the test set using fill task | |
| python aristotle_tester.py datasets/minif2f/test.json \ | |
| --task fill \ | |
| --limit 10 | |
| # Process examples from the training set using prove task | |
| python aristotle_tester.py datasets/minif2f/train.json \ | |
| --task prove \ | |
| --limit 3 | |
| """ | |
| # Standard imports | |
| import argparse | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from datetime import datetime | |
| from pathlib import Path | |
| # Library imports | |
| import aristotlelib | |
# Root logging setup: INFO level, timestamped "<time> <LEVEL>: <message>" lines.
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
async def process_example(example: dict, task_type: str, output_dir: Path) -> dict:
    """
    Process a single example from the dataset.

    The example's formal statement is written to a temporary Lean file under
    ``./tmp_examples``, the requested task is submitted through
    ``aristotlelib.Project.prove_from_file`` and, if a solution file comes
    back, its contents are copied into ``output_dir``. The temporary file is
    removed afterwards in every case.

    Args:
        example: Dictionary with 'id', 'formal', 'natural', and 'name' keys.
            Missing keys fall back to 'unknown', '', '' and 'N/A'.
        task_type: One of 'fill', 'prove' or 'direct'. Any other value
            (including 'all') produces an 'error' result.
        output_dir: Directory to save results; it is not created here.

    Returns:
        A dict with 'id', 'task' and 'status' keys. 'status' is 'success'
        (with an extra 'output' path), 'failed' or 'error' (both with an
        extra 'error' message). Exceptions are logged and reported through
        the 'error' status rather than raised.
    """
    example_id = example.get('id', 'unknown')
    formal_code = example.get('formal', '')
    natural_language = example.get('natural', '')

    logger.info("Processing: %s - %s", example_id, example.get('name', 'N/A'))

    # Temporary Lean file lives in the current directory (to access lakefile.lean)
    tmp_dir = Path('./tmp_examples')
    tmp_file_path = tmp_dir / f"{example_id}.lean"

    try:
        # The temp file is created inside the try block so that an I/O failure
        # (e.g. an id that is not a valid file name) is reported as an 'error'
        # result for this example instead of aborting the whole run.
        tmp_dir.mkdir(exist_ok=True)
        tmp_file_path.write_text(formal_code, encoding='utf-8')

        # Run the appropriate task
        if task_type == 'fill':
            logger.info("Running 'fill' task...")
            solution_path = await aristotlelib.Project.prove_from_file(
                input_file_path=tmp_file_path
            )
        elif task_type == 'prove':
            logger.info("Running 'prove' task...")
            solution_path = await aristotlelib.Project.prove_from_file(
                input_content=natural_language,
                project_input_type=aristotlelib.ProjectInputType.INFORMAL,
            )
        elif task_type == 'direct':
            logger.info("Running 'direct' task...")
            logger.info("Natural language: %s", natural_language)
            # For direct task, use both natural and formal
            solution_path = await aristotlelib.Project.prove_from_file(
                input_content=natural_language,
                formal_input_context=tmp_file_path,
                project_input_type=aristotlelib.ProjectInputType.INFORMAL,
            )
        else:
            raise ValueError(f"Unknown task type: {task_type}")

        # Copy result to output directory with example ID, task, and timestamp
        if solution_path and os.path.exists(solution_path):
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = output_dir / f"{example_id}_{task_type}_{timestamp}.lean"
            solution_text = Path(solution_path).read_text(encoding='utf-8')
            output_file.write_text(solution_text, encoding='utf-8')
            logger.info("β Solution saved to: %s", output_file)
            return {
                'id': example_id,
                'task': task_type,
                'status': 'success',
                'output': str(output_file)
            }

        logger.warning("β No solution generated")
        return {
            'id': example_id,
            'task': task_type,
            'status': 'failed',
            'error': 'No solution generated'
        }
    except Exception as exc:  # pylint: disable=broad-exception-caught
        # Deliberately broad: one bad example must not stop the batch.
        logger.error("β Error processing example: %s", exc)
        return {'id': example_id, 'task': task_type, 'status': 'error', 'error': str(exc)}
    finally:
        # Clean up temporary file
        if tmp_file_path.exists():
            tmp_file_path.unlink()
async def main():
    """Main function to process Lean theorems from JSON dataset.

    Parses the command line, loads the examples from the JSON file, narrows
    them down (``--id`` takes precedence over ``--random`` / ``--limit``),
    runs the requested task(s) on each example one after another, writes
    ``summary.json`` into the output directory and logs the final counts.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        'json_file',
        type=str,
        help='Path to JSON file containing theorem examples'
    )
    parser.add_argument(
        '--task',
        type=str,
        choices=['fill', 'prove', 'direct', 'all'],
        default='fill',
        help='Task type to run: fill (fill sorries), '
             'prove (prove theorems), direct (use natural+formal), '
             'all (run all tasks) (default: fill)'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./results',
        help='Directory to save results (default: ./results)'
    )
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help='Limit number of examples to process (default: all)'
    )
    parser.add_argument(
        '--random',
        action='store_true',
        help='Select examples randomly instead of sequentially'
    )
    parser.add_argument(
        '--id',
        type=str,
        default=None,
        help='Process only the example with this specific ID'
    )
    args = parser.parse_args()

    # A negative limit would silently slice examples off the END of the list,
    # so reject it up front with a regular argparse usage error.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Load JSON file
    logger.info("Loading examples from: %s", args.json_file)
    with open(args.json_file, 'r', encoding='utf-8') as f:
        examples = json.load(f)
    logger.info("Found %d examples", len(examples))

    # Filter by ID if specified
    if args.id:
        examples = [ex for ex in examples if ex.get('id') == args.id]
        if not examples:
            logger.error("No example found with ID: %s", args.id)
            return
        logger.info("Found example with ID: %s", args.id)
    else:
        # Shuffle examples if random flag is set
        if args.random:
            random.shuffle(examples)
            logger.info("Examples shuffled randomly")
        # Apply limit if specified; compare against None so that an explicit
        # `--limit 0` is honoured instead of being treated as "no limit".
        if args.limit is not None:
            examples = examples[:args.limit]
    logger.info("Processing %d examples", len(examples))

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Results will be saved to: %s", output_dir)

    # 'all' expands to every concrete task type, run in this order per example
    tasks = ['fill', 'prove', 'direct'] if args.task == 'all' else [args.task]

    # Process all examples sequentially
    results = []
    for example in examples:
        for task in tasks:
            results.append(await process_example(example, task, output_dir))

    # Save summary
    summary_file = output_dir / 'summary.json'
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Log final statistics
    success_count = sum(1 for r in results if r['status'] == 'success')
    failed_count = sum(1 for r in results if r['status'] == 'failed')
    error_count = sum(1 for r in results if r['status'] == 'error')
    logger.info("SUMMARY:")
    logger.info("> Total: %d", len(results))
    logger.info("> Success: %d", success_count)
    logger.info("> Failed: %d", failed_count)
    logger.info("> Errors: %d", error_count)
    logger.info("Summary saved to: %s", summary_file)
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion and turn
    # Ctrl-C into a clean exit (status 0) without a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Interrupted by user. Exiting gracefully...")
        raise SystemExit(0) from None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment