ray_compress
import os
import os.path as osp
import sys
import shutil
import subprocess
import tempfile
import argparse
import random
import time
import json
import datetime
import numpy as np
from loguru import logger

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

import ray

def setup_logging(log_dir: str = "logs"):
    """Setup logging with loguru to both file and console."""
    # Remove default logger
    logger.remove()
    # Create logs directory
    os.makedirs(log_dir, exist_ok=True)
    # Add console handler with INFO level
    logger.add(
        sys.stderr,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
        level="INFO",
    )
    # Add file handler with DEBUG level
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"compress_{timestamp}_pid{pid}.log"
    log_path = os.path.join(log_dir, log_filename)
    logger.add(
        log_path,
        format="{time:YYYY-MM-DD HH:mm:ss} | PID:{process} | {level: <8} | {message}",
        level="DEBUG",
        rotation="100 MB",
        retention="7 days",
    )
    logger.info(f"Logging initialized. Process PID: {pid}")
    logger.info(f"Detailed logs will be written to: {log_path}")

def get_all_files(directory: str = ".") -> list[str]:
    """Get all files and symlinks in the directory recursively.

    Returns both regular files and symlinks.
    """
    try:
        # Get all files (fd will follow symlinks with -L/--follow)
        # This will recursively follow symlinked directories and find all files
        result = subprocess.run(
            ["fd", "--type", "f", "--follow", "--base-directory", directory, "."],
            capture_output=True,
            text=True,
            check=True,
        )
        files = result.stdout.strip().split("\n") if result.stdout.strip() else []
        files = [f for f in files if f]
        logger.info(f"Found {len(files)} files/symlinks")
        return files
    except subprocess.CalledProcessError as e:
        logger.error(f"Error running fd command: {e.stderr}")
        logger.error(f"fd command failed with return code: {e.returncode}")
        raise RuntimeError(f"File discovery failed: {e}")
    except FileNotFoundError:
        logger.error("fd command not found. Please install fd-find package.")
        logger.error("On Ubuntu: sudo apt install fd-find")
        logger.error("On macOS: brew install fd")
        raise RuntimeError("fd command is required but not installed")

def log_tar_contents(part_path: str, files: list[str]) -> None:
    """Log tar file contents to a temporary log file."""
    part_name = os.path.basename(part_path)
    log_file = part_path.replace(".tar", "_contents.log")
    try:
        with open(log_file, "w") as f:
            f.write(f"TAR FILE: {part_name}\n")
            f.write(f"CONTAINS {len(files)} FILES:\n")
            f.write("-" * 50 + "\n")
            for file in sorted(files):
                f.write(f"{file}\n")
            f.write("-" * 50 + "\n\n")
        logger.info(f"Logged contents of {part_name} to {log_file}")
    except Exception as e:
        logger.warning(f"Failed to log contents for {part_name}: {e}")

@ray.remote
def get_tar_members_remote(tar_file: str) -> list[str]:
    """Ray remote function to get members of a single tar file."""
    from loguru import logger
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    try:
        logger.info(f"Verifying {os.path.basename(tar_file)}...")
        # Use an argument list (no shell) so paths with spaces or shell metacharacters are safe
        result = subprocess.run(["tar", "-tf", tar_file], capture_output=True, text=True)
        files_in_tar = result.stdout.strip().splitlines()
        members = [file for file in files_in_tar if file]
        logger.info(f"  Found {len(members)} files in {os.path.basename(tar_file)}")
        return members
    except Exception as e:
        logger.error(f"Failed to verify {tar_file}: {e}")
        return []

# tarfile and Path are needed below; os, subprocess and tempfile are already imported above
import tarfile
from pathlib import Path

def create_tar_part(files: list[str], part_path: str, base_dir: str) -> bool:
    """Create a tar archive for a subset of files.

    files: list of paths relative to base_dir
    base_dir: directory that serves as the root for the relative paths inside the tar
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
    try:
        base = Path(base_dir).resolve()
        logger.debug(f"Creating tar part {os.path.basename(part_path)} with {len(files)} files")
        logger.debug(f"Base directory: {base}")
    except Exception as e:
        logger.error(f"[create_tar_part] Failed to resolve base directory {base_dir}: {e}")
        return False
    # 1) Normalize paths so that everything is relative to base_dir
    rel_paths = []
    failed_paths = []
    for p in files:
        try:
            # Relative paths are used as-is; absolute paths are converted to relative ones
            if os.path.isabs(p):
                p_abs = Path(p).resolve()
                try:
                    rel = p_abs.relative_to(base)
                    rel_paths.append(rel.as_posix())
                except ValueError:
                    logger.error(f"[create_tar_part] File outside base_dir: {p}")
                    failed_paths.append(p)
            else:
                # Verify that the relative path actually exists
                full_path = base / p
                # Also check islink, since exists() follows symlinks and hides broken ones
                if full_path.exists() or os.path.islink(str(full_path)):
                    # Normalize the path to handle separators and special characters
                    normalized_path = Path(p).as_posix()
                    rel_paths.append(normalized_path)
                else:
                    logger.warning(f"[create_tar_part] File not found: {p}")
                    failed_paths.append(p)
        except Exception as e:
            logger.error(f"[create_tar_part] Error processing path {p}: {e}")
            failed_paths.append(p)
    if failed_paths:
        logger.error(f"[create_tar_part] {len(failed_paths)} files failed path processing")
        if len(failed_paths) <= 10:
            for fp in failed_paths:
                logger.error(f"  Failed path: {fp}")
        else:
            logger.error(f"  First 5 failed paths: {failed_paths[:5]}")
            logger.error(f"  ... and {len(failed_paths) - 5} more")
        # Do not continue if any files failed - this ensures no silent data loss
        raise RuntimeError(f"Failed to process {len(failed_paths)} files - aborting to prevent data loss")
    if not rel_paths:
        logger.error("[create_tar_part] No valid files to archive")
        return False
    # Deduplicate and sort (optional)
    rel_paths = sorted(set(rel_paths))
    temp_list = None
    try:
        # 2) Write a NUL-separated file list for tar's --null -T
        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
            try:
                file_list_content = b"\0".join(s.encode("utf-8") for s in rel_paths)
                f.write(file_list_content)
                temp_list = f.name
                logger.debug(f"[create_tar_part] Written {len(rel_paths)} file paths to {temp_list}")
            except UnicodeEncodeError as e:
                logger.error(f"[create_tar_part] Unicode encoding error: {e}")
                return False
        # 3) Create the archive (key options: -C <base> plus --null -T)
        # We don't use -h here; symlinks are stored as symlinks.
        # This preserves the directory structure and symlink relationships.
        cmd = ["tar", "-cf", part_path, "-C", str(base), "--null", "-T", temp_list]
        logger.debug(f"[create_tar_part] Running tar command: {' '.join(cmd)}")
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)  # 1 hour timeout
        if result.returncode != 0:
            logger.error(f"[create_tar_part] tar command failed with return code {result.returncode}")
            logger.error(f"[create_tar_part] tar stderr: {result.stderr}")
            logger.error(f"[create_tar_part] tar stdout: {result.stdout}")
            logger.error(f"[create_tar_part] Command was: {' '.join(cmd)}")
            return False
        logger.debug("[create_tar_part] tar command completed successfully")
        # 4) Verification: read the tar back and check its contents
        try:
            with tarfile.open(part_path, "r") as tf:
                # Count both files and symlinks
                members_rel = [m.name for m in tf.getmembers() if (m.isfile() or m.issym())]
                logger.debug(f"[create_tar_part] Found {len(members_rel)} files/symlinks in tar")
        except tarfile.TarError as e:
            logger.error(f"[create_tar_part] Failed to read tar file for verification: {e}")
            return False
        except Exception as e:
            logger.error(f"[create_tar_part] Unexpected error reading tar file: {e}")
            return False
        # Basic verification - ensure we have at least some files
        if len(members_rel) == 0:
            logger.error("[create_tar_part] No files found in tar archive!")
            return False
        # Simple verification - we should have the same number of items
        expected_files = len(rel_paths)
        actual_files = len(members_rel)
        if expected_files != actual_files:
            logger.error(f"[create_tar_part] File count mismatch: expected {expected_files}, got {actual_files}")
            logger.error(f"Expected files: {rel_paths[:5]}")
            logger.error(f"Actual files: {members_rel[:5]}")
            return False
        # Check that all expected files are present (may have duplicates due to symlinks)
        expected_set = set(rel_paths)
        actual_set = set(members_rel)
        # All expected files should be in the tar (but tar may have duplicates)
        missing = expected_set - actual_set
        if missing:
            logger.error(f"[create_tar_part] Missing files in tar: {len(missing)}")
            for f in list(missing)[:5]:
                logger.error(f"  Missing: {f}")
            return False
        logger.info(f"[create_tar_part] Successfully created and verified {os.path.basename(part_path)} with {len(members_rel)} files")
        return True
    except Exception as e:
        logger.exception(f"[create_tar_part] Exception: {e}")
        return False
    finally:
        if temp_list and os.path.exists(temp_list):
            os.unlink(temp_list)
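
# For reference, the equivalent shell pattern used above is (paths are illustrative):
#   tar -cf part_0.tar -C /data/my_dataset --null -T filelist.nul
# where filelist.nul holds NUL-separated paths relative to /data/my_dataset. A part can be
# listed with `tar -tf part_0.tar` and restored with `tar -xf part_0.tar -C <dest>`.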

@ray.remote
def create_tar_parts_batch(file_chunks: list[tuple[list[str], str]], base_dir: str) -> list[bool]:
    """Ray remote function to create multiple tar archives."""
    # Re-setup logging in the Ray worker
    from loguru import logger
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    results = []
    failed_parts = []
    for i, (chunk_files, part_path) in enumerate(file_chunks):
        logger.info(f"Creating tar part {i+1}/{len(file_chunks)}: {os.path.basename(part_path)}")
        logger.info(f"  Files in this part: {len(chunk_files)}")
        result = create_tar_part(chunk_files, part_path, base_dir)
        results.append(result)
        if not result:
            failed_parts.append(part_path)
            logger.error(f"Failed to create tar part: {part_path}")
        else:
            logger.info(f"Successfully created: {os.path.basename(part_path)}")
    if failed_parts:
        logger.error(f"Failed to create {len(failed_parts)} tar parts:")
        for failed_part in failed_parts:
            logger.error(f"  {failed_part}")
    return results

def distribute_files_simple(all_files: list[str], num_parts: int) -> list[list[str]]:
    """Simple file distribution - randomly shuffles and splits into equal chunks."""
    start_time = time.time()
    random.seed(42)
    random.shuffle(all_files)
    # Simple distribution: split into equal chunks
    chunk_size = (len(all_files) + num_parts - 1) // num_parts
    parts = [all_files[i : i + chunk_size] for i in range(0, len(all_files), chunk_size)]
    logger.info(f"Distributed {len(all_files)} files into {len(parts)} parts in {time.time() - start_time:.4f} seconds")
    return parts
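
# Worked example of the ceiling-division split above: with 10 files and num_parts=4,
# chunk_size = (10 + 4 - 1) // 4 = 3, giving parts of sizes [3, 3, 3, 1].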

def main():
    parser = argparse.ArgumentParser(
        description="Pack a directory into parallel tar parts (stored in .archive) for later compression"
    )
    parser.add_argument(
        "--jobs",
        "-j",
        type=int,
        default=None,
        help="Number of parallel jobs (default: number of CPU cores)",
    )
    parser.add_argument(
        "--directory",
        "-d",
        type=str,
        default=".",
        help="Directory to compress (default: current directory)",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="logs",
        help="Directory for log files (default: logs)",
    )
    args = parser.parse_args()
    # Setup logging
    setup_logging(args.log_dir)
    # Initialize Ray
    try:
        ray.init(runtime_env={"pip": ["loguru"]})
        logger.info("Ray initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize Ray: {e}")
        return 1
    # Determine number of jobs (Ray schedules workers itself; this value also sets how many parts are created)
    if args.jobs is None:
        args.jobs = os.cpu_count() or 4
    logger.info("Starting compression with Ray parallel processing")
    logger.info(f"Target directory: {args.directory}")
    logger.info(f"Log directory: {args.log_dir}")
    # Convert directory to absolute path
    args.directory = os.path.abspath(args.directory)
    # Setup parts directory inside the target directory
    parts_dir = os.path.join(args.directory, ".archive")
    # If .archive already exists, remove it completely
    if os.path.exists(parts_dir):
        logger.info(f"Removing existing archive directory: {parts_dir}")
        shutil.rmtree(parts_dir)
    # Create fresh archive directory
    os.makedirs(parts_dir)
    # Get all files
    logger.info("Step 1: Collecting all files...")
    all_files = get_all_files(args.directory)
    if not all_files:
        logger.error("No files found to archive")
        os.rmdir(parts_dir)
        ray.shutdown()
        return 1
    logger.info(f"Found {len(all_files)} files to archive")
    # Distribute files simply and evenly
    logger.info("Step 2: Distributing files into parts...")
    file_chunks = distribute_files_simple(all_files, num_parts=args.jobs)
    logger.info(f"Split {len(all_files)} files into {len(file_chunks)} parts")
    # Verify no files were lost during distribution
    total_files_in_chunks = sum(len(chunk) for chunk in file_chunks)
    if total_files_in_chunks != len(all_files):
        logger.error(f"CRITICAL: File count mismatch during distribution! Expected {len(all_files)}, got {total_files_in_chunks}")
        ray.shutdown()
        return 1
    # Create partial archives
    logger.info("Step 3: Creating partial uncompressed archives...")
    file_chunks_with_paths = [
        (chunk, os.path.join(parts_dir, f"part_{i}.tar")) for i, chunk in enumerate(file_chunks)
    ]
    # Split into batches for Ray processing
    batch_size = max(1, len(file_chunks_with_paths) // args.jobs)
    batches = [file_chunks_with_paths[i:i + batch_size] for i in range(0, len(file_chunks_with_paths), batch_size)]
    # Create Ray tasks
    tasks = [create_tar_parts_batch.remote(batch, args.directory) for batch in batches]
    # Wait for all tasks to complete
    try:
        all_results = ray.get(tasks)
        results = [r for batch in all_results for r in batch]
    except Exception as e:
        logger.error(f"Ray task execution failed: {e}")
        ray.shutdown()
        return 1
    if not all(results):
        logger.error("Some parts failed to create")
        failed_count = sum(1 for r in results if not r)
        logger.error(f"Failed to create {failed_count} out of {len(results)} tar parts")
        ray.shutdown()
        return 1
    # Generate manifest for integrity verification
    logger.info("Step 4: Generating manifest for integrity verification...")
    manifest = {
        "total_files": len(all_files),
        "files": all_files,
        "tar_parts": [f"part_{i}.tar" for i in range(len(file_chunks))],
        "timestamp": datetime.datetime.now().isoformat(),
        "source_directory": args.directory,
    }
    manifest_path = os.path.join(parts_dir, "manifest.json")
    try:
        with open(manifest_path, "w") as f:
            json.dump(manifest, f, indent=2)
        logger.info(f"Manifest written to {manifest_path}")
    except Exception as e:
        logger.error(f"Failed to write manifest: {e}")
        ray.shutdown()
        return 1
    # Verify integrity by checking tar contents using Ray
    logger.info("Step 5: Verifying archive integrity with Ray parallel processing...")
    # Create one Ray verification task per tar part
    tar_paths = [os.path.join(parts_dir, f"part_{i}.tar") for i in range(len(file_chunks))]
    verify_tasks = [get_tar_members_remote.remote(tar_path) for tar_path in tar_paths]
    # Wait for all verification tasks to complete
    try:
        logger.info(f"Running {len(verify_tasks)} parallel verification tasks...")
        start_verify = time.time()
        all_tar_members = ray.get(verify_tasks)
        verify_time = time.time() - start_verify
        logger.info(f"Verification completed in {verify_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Ray verification tasks failed: {e}")
        ray.shutdown()
        return 1
    # Aggregate all archived files
    archived_files = set()
    for i, members in enumerate(all_tar_members):
        if not members:
            logger.error(f"Failed to verify part_{i}.tar - no members returned")
            ray.shutdown()
            return 1
        archived_files.update(members)
        logger.debug(f"part_{i}.tar contains {len(members)} files")
    logger.info(f"Total unique files archived: {len(archived_files)}")
    # Check for missing files
    missing_files = set(all_files) - archived_files
    if missing_files:
        logger.error(f"CRITICAL: {len(missing_files)} files were not archived!")
        for f in list(missing_files)[:10]:
            logger.error(f"  Missing: {f}")
        if len(missing_files) > 10:
            logger.error(f"  ... and {len(missing_files) - 10} more files")
        ray.shutdown()
        return 1
    logger.success(f"Successfully verified all {len(all_files)} files in archive")
    print(f"Done. Final archive is {args.directory}/.archive")
    # Clean shutdown
    ray.shutdown()
    return 0


if __name__ == "__main__":
    sys.exit(main())
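
# Example invocation (the script filename ray_compress.py is an assumption):
#   python ray_compress.py --directory /data/my_dataset --jobs 16
# This writes /data/my_dataset/.archive/part_*.tar plus manifest.json. The parts are left
# uncompressed; they can optionally be compressed afterwards, e.g. `zstd -T0 part_*.tar`.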

#!/usr/bin/env python3
import os
import sys
import shutil
import tarfile
import argparse
import time
import json
from typing import List, Tuple, Optional
from loguru import logger
import ray

def setup_logging(log_dir: str = "logs"):
    """Setup logging with loguru to both file and console."""
    # Remove default logger
    logger.remove()
    # Create logs directory
    os.makedirs(log_dir, exist_ok=True)
    # Add console handler with INFO level
    logger.add(
        sys.stderr,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
        level="INFO",
    )
    # File handler with DEBUG level
    import datetime
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"decompress_{timestamp}_pid{pid}.log"
    log_path = os.path.join(log_dir, log_filename)
    logger.add(
        log_path,
        format="{time:YYYY-MM-DD HH:mm:ss} | PID:{process} | {level: <8} | {message}",
        level="DEBUG",
        rotation="100 MB",
        retention="7 days",
    )
    logger.info(f"Logging initialized. Process PID: {pid}")
    logger.info(f"Detailed logs will be written to: {log_path}")

# ---------- Path safety utilities ----------
def _safe_join(root: str, name: str) -> str:
    """Join and normalize, ensuring the path stays within root."""
    root = os.path.abspath(root)
    # Reject absolute paths and NUL bytes early
    if os.path.isabs(name) or ("\x00" in name):
        raise ValueError(f"Unsafe path (absolute or contains NUL): {name}")
    # Normalize against root
    dest = os.path.normpath(os.path.join(root, name))
    if not (dest == root or dest.startswith(root + os.sep)):
        raise ValueError(f"Unsafe path (escapes root): {name}")
    return dest
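
# Illustration with hypothetical inputs: given root="/out",
#   _safe_join("/out", "a/b.txt")       -> "/out/a/b.txt"
#   _safe_join("/out", "../etc/passwd") -> raises ValueError (escapes root)
#   _safe_join("/out", "/etc/passwd")   -> raises ValueError (absolute path)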

def _is_supported_member(m: tarfile.TarInfo) -> bool:
    """Supported member types: regular files, directories, symlinks, and hard links."""
    return m.isreg() or m.isdir() or m.issym() or m.islnk()

# ---------- Extraction core (tarfile-based) ----------
def _prescan_members(tar_path: str, output_dir: str) -> List[tarfile.TarInfo]:
    """
    Pre-scan tar members:
    - ensure names are safe
    - only keep supported types (regular files, dirs, symlinks, hard links)
    Return the filtered member list, sorted with directories first.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
    members: List[tarfile.TarInfo] = []
    try:
        with tarfile.open(tar_path, mode="r:*") as tf:
            for m in tf.getmembers():
                if not _is_supported_member(m):
                    logger.info(f"[{tar_path}] Skipping unsupported member type: {m.name} ({m.type})")
                    continue
                # Will raise on unsafe paths
                _ = _safe_join(output_dir, m.name)
                members.append(m)
    except Exception as e:
        raise RuntimeError(f"Pre-scan failed for {tar_path}: {e}")
    # Ensure directories are created before files
    members.sort(key=lambda x: 0 if x.isdir() else 1)
    return members

def _extract_member(tf: tarfile.TarFile, m: tarfile.TarInfo, output_dir: str) -> Optional[str]:
    """
    Extract a single member safely. Returns the destination path for files/links, or None for dirs.
    Performs a size check for regular files and sets mode/mtime.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
    dst = _safe_join(output_dir, m.name)
    if m.isdir():
        os.makedirs(dst, exist_ok=True)
        try:
            os.chmod(dst, m.mode & 0o777)
        except Exception as e:
            logger.warning(f"Failed to set permissions on directory {dst}: {e}")
        return None
    if m.isreg():
        # Ensure parent exists
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        # Stream copy in chunks
        src_f = tf.extractfile(m)
        if src_f is None:
            raise RuntimeError(f"extractfile() returned None for {m.name}")
        # Write atomically: write to a temp file, then rename
        tmp_dst = dst + ".partial.__tmp__"
        try:
            with src_f as fin, open(tmp_dst, "wb") as fout:
                shutil.copyfileobj(fin, fout, length=1024 * 1024)  # 1 MiB chunks
            # Size verification
            st = os.stat(tmp_dst)
            if st.st_size != m.size:
                raise RuntimeError(
                    f"Size mismatch for {m.name}: wrote {st.st_size} vs header {m.size}"
                )
        except Exception:
            # Clean up temp file on any error
            try:
                if os.path.exists(tmp_dst):
                    os.remove(tmp_dst)
            except Exception as cleanup_e:
                logger.warning(f"Failed to clean up temp file {tmp_dst}: {cleanup_e}")
            raise
        # Set mode/mtime then rename into place
        try:
            os.chmod(tmp_dst, m.mode & 0o777)
        except Exception as e:
            logger.warning(f"Failed to set permissions on file {m.name}: {e}")
        try:
            # Set both atime and mtime to m.mtime
            os.utime(tmp_dst, (m.mtime, m.mtime))
        except Exception as e:
            logger.warning(f"Failed to set timestamp on file {m.name}: {e}")
        # Atomic replace
        os.replace(tmp_dst, dst)
        return dst
    if m.issym():
        # Create symbolic link
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        # Remove existing file/link if it exists
        if os.path.lexists(dst):
            os.unlink(dst)
        # Security note: symlinks are recreated as-is by design, so absolute targets are
        # allowed, but we warn because they may point outside output_dir.
        if os.path.isabs(m.linkname):
            logger.warning(f"Symlink {m.name} has absolute target path: {m.linkname}")
        try:
            os.symlink(m.linkname, dst)
            logger.debug(f"Created symbolic link: {dst} -> {m.linkname}")
            return dst
        except Exception as e:
            logger.error(f"Failed to create symbolic link {dst} -> {m.linkname}: {e}")
            raise RuntimeError(f"Symbolic link creation failed for {m.name}: {e}")
    if m.islnk():
        # Create hard link
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        link_target = _safe_join(output_dir, m.linkname)
        # Remove existing file/link if it exists
        if os.path.lexists(dst):
            os.unlink(dst)
        try:
            os.link(link_target, dst)
            logger.debug(f"Created hard link: {dst} -> {link_target}")
            return dst
        except Exception as e:
            logger.error(f"Failed to create hard link {dst} -> {link_target}: {e}")
            raise RuntimeError(f"Hard link creation failed for {m.name}: {e}")
    # Should not be reached due to _is_supported_member
    return None

def extract_tar_part(tar_file: str, output_dir: str, max_retries: int = 3) -> Tuple[bool, int, int]:
    """
    Extract a single tar safely using tarfile with pre-scan and per-file verification.
    Returns (success, expected_files, extracted_files) for validation.
    Retries on transient errors.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
    os.makedirs(output_dir, exist_ok=True)
    # Pre-scan for safety and to decide the plan
    try:
        members = _prescan_members(tar_file, output_dir)
    except Exception as e:
        logger.error(str(e))
        return False, 0, 0
    if not members:
        logger.info(f"[{tar_file}] No extractable members (after filtering).")
        return True, 0, 0
    expected_files = len(members)
    extracted_files = 0
    for attempt in range(1, max_retries + 1):
        extracted_files = 0
        try:
            with tarfile.open(tar_file, mode="r:*") as tf:
                # Build a dict for fast lookup if needed
                names = {m.name: m for m in members}
                # Iterate exactly our filtered list to keep ordering (dirs first)
                for m in members:
                    result = _extract_member(tf, names[m.name], output_dir)
                    if result is not None:  # Successfully extracted (file/link)
                        extracted_files += 1
                    elif m.isdir():  # Directory creation counts as success
                        extracted_files += 1
            logger.info(f"[{os.path.basename(tar_file)}] Extracted {extracted_files}/{expected_files} items")
            return True, expected_files, extracted_files
        except Exception as e:
            logger.error(f"[{tar_file}] Extract attempt {attempt}/{max_retries} failed: {e}")
            if attempt < max_retries:
                time.sleep(min(2 ** attempt, 10))
            else:
                logger.error(f"[{tar_file}] All {max_retries} attempts failed")
    return False, expected_files, extracted_files
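
# Minimal direct-call sketch for debugging a single part without Ray (paths are illustrative):
#   ok, expected, extracted = extract_tar_part("/data/.archive/part_0.tar", "/restore/target")
#   assert ok and expected == extracted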

# ---------- Ray integration ----------
@ray.remote
def extract_tar_parts_batch(tar_files: List[str], output_dir: str) -> List[Tuple[bool, int, int]]:
    # Re-setup logging in the Ray worker
    from loguru import logger
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    results = []
    for tar_file in tar_files:
        result = extract_tar_part(tar_file, output_dir)
        results.append(result)
    return results


# ---------- Cleanup utilities ----------
@ray.remote
def cleanup_temp_files_remote(directory: str, subdirs: List[str]) -> int:
    """Ray remote function to clean up temp files in specific subdirectories."""
    from loguru import logger
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    count = 0
    for subdir in subdirs:
        search_dir = os.path.join(directory, subdir) if subdir else directory
        if not os.path.exists(search_dir):
            continue
        try:
            for root, dirs, files in os.walk(search_dir):
                for file in files:
                    if file.endswith(".partial.__tmp__"):
                        temp_file = os.path.join(root, file)
                        try:
                            os.unlink(temp_file)
                            logger.info(f"Cleaned up leftover temp file: {temp_file}")
                            count += 1
                        except Exception as e:
                            logger.warning(f"Failed to clean temp file {temp_file}: {e}")
        except Exception as e:
            logger.warning(f"Failed to scan for temp files in {search_dir}: {e}")
    return count

def cleanup_temp_files(directory: str) -> int:
    """Clean up any leftover .partial.__tmp__ files from previous runs."""
    # Import logger locally to avoid pickle issues
    from loguru import logger
    count = 0
    try:
        for root, dirs, files in os.walk(directory):
            for file in files:
                if file.endswith(".partial.__tmp__"):
                    temp_file = os.path.join(root, file)
                    try:
                        os.unlink(temp_file)
                        logger.info(f"Cleaned up leftover temp file: {temp_file}")
                        count += 1
                    except Exception as e:
                        logger.warning(f"Failed to clean temp file {temp_file}: {e}")
    except Exception as e:
        logger.warning(f"Failed to scan for temp files in {directory}: {e}")
    return count

# ---------- Verification utilities ----------
@ray.remote
def get_tar_members_for_verification(tar_file: str) -> set:
    """Ray remote function to get file members from a tar file for verification."""
    from loguru import logger
    logger.remove()
    logger.add(sys.stderr, level="INFO")
    extracted_files = set()
    try:
        with tarfile.open(tar_file, mode="r:*") as tf:
            for member in tf.getmembers():
                if member.isfile() or member.issym():
                    extracted_files.add(member.name)
        logger.info(f"Found {len(extracted_files)} files in {os.path.basename(tar_file)}")
    except Exception as e:
        logger.error(f"Failed to read tar file {tar_file}: {e}")
        raise
    return extracted_files
| # ---------- CLI & orchestration ---------- | |
| def _discover_tar_parts(input_path: str) -> List[str]: | |
| # Import logger locally to avoid pickle issues | |
| from loguru import logger | |
| """Discover tar part files with flexible naming patterns.""" | |
| archive_dir = input_path | |
| search_dir = archive_dir if os.path.exists(archive_dir) else input_path | |
| logger.info(f"Searching for tar parts in: {search_dir}") | |
| # Multiple patterns to support different naming conventions | |
| patterns = [ | |
| (lambda f: f.startswith("part_") and f.endswith(".tar"), "part_*.tar"), | |
| (lambda f: f.endswith(".tar"), "*.tar"), | |
| ] | |
| tar_files = [] | |
| for pattern_func, pattern_name in patterns: | |
| found = [] | |
| try: | |
| for f in os.listdir(search_dir): | |
| if pattern_func(f) and os.path.isfile(os.path.join(search_dir, f)): | |
| found.append(os.path.join(search_dir, f)) | |
| except Exception as e: | |
| logger.warning(f"Failed to list directory {search_dir}: {e}") | |
| continue | |
| if found: | |
| tar_files = sorted(found) | |
| logger.info(f"Found {len(tar_files)} files matching pattern {pattern_name}") | |
| break | |
| if not tar_files: | |
| logger.warning(f"No tar files found in {search_dir}") | |
| # Try to list what's actually there | |
| try: | |
| all_files = os.listdir(search_dir) | |
| logger.info(f"Directory contains {len(all_files)} items: {all_files[:10]}...") | |
| except Exception: | |
| pass | |
| return tar_files | |

def main() -> int:
    parser = argparse.ArgumentParser(
        description="Decompress a directory of tar parts into a target folder (safe tarfile version)"
    )
    parser.add_argument("input", type=str, help="Input directory containing tar parts")
    parser.add_argument("output", type=str, help="Output directory where files will be extracted")
    parser.add_argument(
        "--jobs", "-j", type=int, default=None,
        help="Number of parallel jobs for Ray batching"
    )
    parser.add_argument(
        "--log-dir", type=str, default="logs",
        help="Directory for log files (default: logs)"
    )
    args = parser.parse_args()
    # Setup logging
    setup_logging(args.log_dir)
    # Initialize Ray
    try:
        ray.init(runtime_env={"pip": ["loguru"]})
        logger.info("Ray initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize Ray: {e}")
        return 1
    if args.jobs is None:
        args.jobs = os.cpu_count() or 4
    input_path = os.path.abspath(args.input)
    output_dir = os.path.abspath(args.output)
    if not os.path.isdir(input_path):
        logger.error("Input must be a directory containing tar parts")
        return 1
    os.makedirs(output_dir, exist_ok=True)
    # Clean up any leftover temporary files from previous runs
    temp_cleaned = cleanup_temp_files(output_dir)
    if temp_cleaned > 0:
        logger.info(f"Cleaned up {temp_cleaned} leftover temporary files")
    tar_files = _discover_tar_parts(input_path)
    if not tar_files:
        logger.error(f"No tar part files found in {input_path}")
        ray.shutdown()
        return 1
    # Check for manifest file and verify against it if it exists
    manifest_path = os.path.join(input_path, "manifest.json")
    expected_files = None
    if os.path.exists(manifest_path):
        logger.info(f"Found manifest file: {manifest_path}")
        try:
            with open(manifest_path, "r") as f:
                manifest = json.load(f)
            expected_files = set(manifest.get("files", []))
            logger.info(f"Manifest indicates {len(expected_files)} files should be extracted")
        except Exception as e:
            logger.warning(f"Failed to read manifest file: {e}")
            # Continue without manifest verification
    logger.info("Starting decompression with Ray parallel processing")
    logger.info(f"Found {len(tar_files)} tar parts to extract")
    logger.info("Extracting tar parts...")
    # Use Ray for parallel processing
    files_per_task = max(1, (len(tar_files) + args.jobs - 1) // args.jobs)
    batches = [tar_files[i:i + files_per_task] for i in range(0, len(tar_files), files_per_task)]
    tasks = [extract_tar_parts_batch.remote(b, output_dir) for b in batches]
    all_results = ray.get(tasks)
    results = [r for batch in all_results for r in batch]
    # Analyze results and provide detailed statistics.
    # Batches preserve the order of tar_files, so results line up with tar_files one-to-one.
    successes = [r for r in results if r[0]]
    failures = [(tar, r) for tar, r in zip(tar_files, results) if not r[0]]
    total_expected = sum(r[1] for r in results)
    total_extracted = sum(r[2] for r in results)
    logger.info("Extraction summary:")
    logger.info(f"  Total tar files: {len(results)}")
    logger.info(f"  Successful: {len(successes)}")
    logger.info(f"  Failed: {len(failures)}")
    logger.info(f"  Total files expected: {total_expected}")
    logger.info(f"  Total files extracted: {total_extracted}")
    if failures:
        logger.error(f"Failed to extract {len(failures)} tar files:")
        for tar, (success, expected, extracted) in failures[:5]:  # Show first 5 failures
            logger.error(f"  {tar}: {extracted}/{expected} files")
        if len(failures) > 5:
            logger.error(f"  ... and {len(failures) - 5} more failures")
        ray.shutdown()
        return 1
    # Verify against manifest if available
    if expected_files is not None:
        logger.info("Verifying extracted files against manifest using Ray...")
        # Use Ray to collect all extracted files in parallel
        verify_start = time.time()
        verify_tasks = [get_tar_members_for_verification.remote(tar_file) for tar_file in tar_files]
        try:
            logger.info(f"Running {len(verify_tasks)} parallel verification tasks...")
            all_file_sets = ray.get(verify_tasks)
            verify_time = time.time() - verify_start
            logger.info(f"Verification completed in {verify_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to verify tar files: {e}")
            ray.shutdown()
            return 1
        # Merge all file sets
        extracted_files = set()
        for file_set in all_file_sets:
            extracted_files.update(file_set)
        logger.info(f"Total unique files found: {len(extracted_files)}")
        # Check for missing files
        missing_files = expected_files - extracted_files
        if missing_files:
            logger.error(f"CRITICAL: {len(missing_files)} files in manifest were not extracted!")
            for f in list(missing_files)[:10]:
                logger.error(f"  Missing: {f}")
            if len(missing_files) > 10:
                logger.error(f"  ... and {len(missing_files) - 10} more files")
            ray.shutdown()
            return 1
        # Check for extra files (less critical but worth noting)
        extra_files = extracted_files - expected_files
        if extra_files:
            logger.warning(f"Found {len(extra_files)} files not in manifest (could be directories or new files)")
            for f in list(extra_files)[:5]:
                logger.warning(f"  Extra: {f}")
        logger.success("Manifest verification passed - all expected files were extracted")
    if total_extracted != total_expected:
        logger.error(f"CRITICAL: File count mismatch: expected {total_expected}, extracted {total_extracted}")
        ray.shutdown()
        return 1  # This is a critical failure - data integrity issue
    logger.success(f"Successfully extracted all parts to {output_dir}")
    # Clean shutdown
    ray.shutdown()
    return 0


if __name__ == "__main__":
    sys.exit(main())
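
# Example invocation (the script filename ray_decompress.py is an assumption):
#   python ray_decompress.py /data/my_dataset/.archive /restore/my_dataset --jobs 16
# If the parts were compressed after archiving (e.g. with zstd), decompress them back to
# plain .tar files first (`zstd -d part_*.tar.zst`), since this script expects tar parts.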