@jingwangsg
Last active September 20, 2025 04:25
ray_compress
import os
import os.path as osp
import sys
import shutil
import subprocess
import tempfile
import argparse
import random
import time
import json
import datetime
import numpy as np
from loguru import logger
try:
from tqdm import tqdm
except ImportError:
def tqdm(iterable, **kwargs):
return iterable
import ray
def setup_logging(log_dir: str = "logs"):
"""Setup logging with loguru to both file and console."""
# Remove default logger
logger.remove()
# Create logs directory
os.makedirs(log_dir, exist_ok=True)
# Add console handler with INFO level
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
level="INFO"
)
# Add file handler with DEBUG level
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
pid = os.getpid()
log_filename = f"compress_{timestamp}_pid{pid}.log"
log_path = os.path.join(log_dir, log_filename)
logger.add(
log_path,
format="{time:YYYY-MM-DD HH:mm:ss} | PID:{process} | {level: <8} | {message}",
level="DEBUG",
rotation="100 MB",
retention="7 days"
)
logger.info(f"Logging initialized. Process PID: {pid}")
logger.info(f"Detailed logs will be written to: {log_path}")
def get_all_files(directory: str = ".") -> list[str]:
"""Get all files and symlinks in the directory recursively.
Returns both regular files and symlinks.
"""
try:
# Get all files (fd will follow symlinks with -L/--follow)
# This will recursively follow symlinked directories and find all files
result = subprocess.run(
["fd", "--type", "f", "--follow", "--base-directory", directory, "."],
capture_output=True,
text=True,
check=True,
)
files = result.stdout.strip().split("\n") if result.stdout.strip() else []
files = [f for f in files if f]
logger.info(f"Found {len(files)} files/symlinks")
return files
except subprocess.CalledProcessError as e:
logger.error(f"Error running fd command: {e.stderr}")
logger.error(f"fd command failed with return code: {e.returncode}")
raise RuntimeError(f"File discovery failed: {e}")
except FileNotFoundError:
logger.error("fd command not found. Please install fd-find package.")
logger.error("On Ubuntu: sudo apt install fd-find")
logger.error("On macOS: brew install fd")
raise RuntimeError("fd command is required but not installed")
def log_tar_contents(part_path: str, files: list[str]) -> None:
"""Log tar file contents to a temporary log file."""
part_name = os.path.basename(part_path)
log_file = part_path.replace(".tar", "_contents.log")
try:
with open(log_file, "w") as f:
f.write(f"TAR FILE: {part_name}\n")
f.write(f"CONTAINS {len(files)} FILES:\n")
f.write("-" * 50 + "\n")
for file in sorted(files):
f.write(f"{file}\n")
f.write("-" * 50 + "\n\n")
logger.info(f"Logged contents of {part_name} to {log_file}")
except Exception as e:
logger.warning(f"Failed to log contents for {part_name}: {e}")
@ray.remote
def get_tar_members_remote(tar_file: str) -> list[str]:
"""Ray remote function to get members of a single tar file."""
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="INFO")
try:
logger.info(f"Verifying {os.path.basename(tar_file)}...")
        # Use the list form (no shell) so paths with spaces are handled safely,
        # and check=True so a corrupt tar raises instead of silently returning [].
        cmd = ["tar", "-tf", tar_file]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        files_in_tar = result.stdout.strip().splitlines()
        members = [file for file in files_in_tar if file]
logger.info(f" Found {len(members)} files in {os.path.basename(tar_file)}")
return members
except Exception as e:
logger.error(f"Failed to verify {tar_file}: {e}")
return []
import tarfile
from pathlib import Path
def create_tar_part(files: list[str], part_path: str, base_dir: str) -> bool:
    """Create a tar archive for a subset of files.
    files: list of paths relative to base_dir
    base_dir: root directory that the paths inside the tar are relative to
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
try:
base = Path(base_dir).resolve()
logger.debug(f"Creating tar part {os.path.basename(part_path)} with {len(files)} files")
logger.debug(f"Base directory: {base}")
except Exception as e:
logger.error(f"[create_tar_part] Failed to resolve base directory {base_dir}: {e}")
return False
    # 1) Process the paths, ensuring they are all relative to base_dir
rel_paths = []
failed_paths = []
for p in files:
try:
            # If the path is relative, use it directly; if absolute, convert it to a relative path
if os.path.isabs(p):
p_abs = Path(p).resolve()
try:
rel = p_abs.relative_to(base)
rel_paths.append(rel.as_posix())
except ValueError:
logger.error(f"[create_tar_part] File outside base_dir: {p}")
failed_paths.append(p)
else:
                # Verify that the file at the relative path exists
full_path = base / p
# Use lexists to check for symlinks too (exists() follows symlinks)
if full_path.exists() or os.path.islink(str(full_path)):
                    # Normalize the path and handle special characters
normalized_path = Path(p).as_posix()
rel_paths.append(normalized_path)
else:
logger.warning(f"[create_tar_part] File not found: {p}")
failed_paths.append(p)
except Exception as e:
logger.error(f"[create_tar_part] Error processing path {p}: {e}")
failed_paths.append(p)
if failed_paths:
logger.error(f"[create_tar_part] {len(failed_paths)} files failed path processing")
if len(failed_paths) <= 10:
for fp in failed_paths:
logger.error(f" Failed path: {fp}")
else:
logger.error(f" First 5 failed paths: {failed_paths[:5]}")
logger.error(f" ... and {len(failed_paths) - 5} more")
# Do not continue if any files failed - this ensures no silent data loss
raise RuntimeError(f"Failed to process {len(failed_paths)} files - aborting to prevent data loss")
if not rel_paths:
logger.error(f"[create_tar_part] No valid files to archive")
return False
    # Deduplicate and sort (optional)
rel_paths = sorted(set(rel_paths))
temp_list = None
try:
        # 2) Write a NUL-separated file list, to pair with tar's --null -T
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
try:
file_list_content = b"\0".join(s.encode('utf-8') for s in rel_paths)
f.write(file_list_content)
temp_list = f.name
logger.debug(f"[create_tar_part] Written {len(rel_paths)} file paths to {temp_list}")
except UnicodeEncodeError as e:
logger.error(f"[create_tar_part] Unicode encoding error: {e}")
return False
        # 3) Create the archive (key flags: -C <base> plus --null -T)
# We don't use -h here, we keep symlinks as symlinks
# This preserves the directory structure and symlink relationships
cmd = ["tar", "-cf", part_path, "-C", str(base), "--null", "-T", temp_list]
logger.debug(f"[create_tar_part] Running tar command: {' '.join(cmd)}")
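        # Illustrative expansion of the command above (hypothetical paths):
        #   tar -cf /data/.archive/part_0.tar -C /data --null -T /tmp/tmpXXXXXX
        # -C makes member names inside the tar relative to base_dir, and
        # --null -T reads the NUL-separated file list written above.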
result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600) # 1 hour timeout
if result.returncode != 0:
logger.error(f"[create_tar_part] tar command failed with return code {result.returncode}")
logger.error(f"[create_tar_part] tar stderr: {result.stderr}")
logger.error(f"[create_tar_part] tar stdout: {result.stdout}")
logger.error(f"[create_tar_part] Command was: {' '.join(cmd)}")
return False
logger.debug(f"[create_tar_part] tar command completed successfully")
        # 4) Verification: check the tar archive contents
try:
with tarfile.open(part_path, "r") as tf:
# Count both files and symlinks
members_rel = [m.name for m in tf.getmembers() if (m.isfile() or m.issym())]
logger.debug(f"[create_tar_part] Found {len(members_rel)} files/symlinks in tar")
except tarfile.TarError as e:
logger.error(f"[create_tar_part] Failed to read tar file for verification: {e}")
return False
except Exception as e:
logger.error(f"[create_tar_part] Unexpected error reading tar file: {e}")
return False
# Basic verification - ensure we have at least some files
if len(members_rel) == 0:
logger.error(f"[create_tar_part] No files found in tar archive!")
return False
# Simple verification - we should have the same number of items
expected_files = len(rel_paths)
actual_files = len(members_rel)
if expected_files != actual_files:
logger.error(f"[create_tar_part] File count mismatch: expected {expected_files}, got {actual_files}")
logger.error(f"Expected files: {rel_paths[:5]}")
logger.error(f"Actual files: {members_rel[:5]}")
return False
# Check that all expected files are present (may have duplicates due to symlinks)
expected_set = set(rel_paths)
actual_set = set(members_rel)
# All expected files should be in the tar (but tar may have duplicates)
missing = expected_set - actual_set
if missing:
logger.error(f"[create_tar_part] Missing files in tar: {len(missing)}")
for f in list(missing)[:5]:
logger.error(f" Missing: {f}")
return False
logger.info(f"[create_tar_part] Successfully created and verified {os.path.basename(part_path)} with {len(members_rel)} files")
return True
except Exception as e:
logger.exception(f"[create_tar_part] Exception: {e}")
return False
finally:
if temp_list and os.path.exists(temp_list):
os.unlink(temp_list)
@ray.remote
def create_tar_parts_batch(file_chunks: list[tuple[list[str], str]], base_dir: str) -> list[bool]:
"""Ray remote function to create multiple tar archives."""
# Re-setup logging in Ray worker
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="INFO")
results = []
failed_parts = []
for i, (chunk_files, part_path) in enumerate(file_chunks):
logger.info(f"Creating tar part {i+1}/{len(file_chunks)}: {os.path.basename(part_path)}")
logger.info(f" Files in this part: {len(chunk_files)}")
result = create_tar_part(chunk_files, part_path, base_dir)
results.append(result)
if not result:
failed_parts.append(part_path)
logger.error(f"Failed to create tar part: {part_path}")
else:
logger.info(f"Successfully created: {os.path.basename(part_path)}")
if failed_parts:
logger.error(f"Failed to create {len(failed_parts)} tar parts:")
for failed_part in failed_parts:
logger.error(f" {failed_part}")
return results
def distribute_files_simple(all_files: list[str], num_parts: int) -> list[list[str]]:
"""Simple file distribution - randomly shuffles and splits into equal chunks."""
start_time = time.time()
random.seed(42)
random.shuffle(all_files)
# Simple distribution: split into equal chunks
chunk_size = (len(all_files) + num_parts - 1) // num_parts
parts = [all_files[i : i + chunk_size] for i in range(0, len(all_files), chunk_size)]
logger.info(f"Distributed {len(all_files)} files into {len(parts)} parts in {time.time() - start_time:.4f} seconds")
return parts
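# Worked example of the chunking above: with 10 files and num_parts=4,
# chunk_size = (10 + 4 - 1) // 4 = 3, giving parts of sizes [3, 3, 3, 1].
# Note the last part can be smaller, and for some inputs fewer than
# num_parts parts are produced (e.g. 5 files, 4 parts -> sizes [2, 2, 1]).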
def main():
parser = argparse.ArgumentParser(
description="Compress directory into tar.zst archive with parallel processing"
)
parser.add_argument(
"--jobs",
"-j",
type=int,
default=None,
help="Number of parallel jobs (default: number of CPU cores)",
)
parser.add_argument(
"--directory",
"-d",
type=str,
default=".",
help="Directory to compress (default: current directory)",
)
parser.add_argument(
"--log-dir",
type=str,
default="logs",
help="Directory for log files (default: logs)",
)
args = parser.parse_args()
# Setup logging
setup_logging(args.log_dir)
# Initialize Ray
try:
ray.init(runtime_env={"pip": ["loguru"]})
logger.info("Ray initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Ray: {e}")
return 1
    # Determine number of jobs (used below to decide how many tar parts and task batches to create)
if args.jobs is None:
args.jobs = os.cpu_count() or 4
logger.info(f"Starting compression with Ray parallel processing")
logger.info(f"Target directory: {args.directory}")
logger.info(f"Log directory: {args.log_dir}")
# Convert directory to absolute path
args.directory = os.path.abspath(args.directory)
# Setup parts directory inside the target directory
parts_dir = os.path.join(args.directory, ".archive")
# If .archive already exists, remove it completely
if os.path.exists(parts_dir):
logger.info(f"Removing existing archive directory: {parts_dir}")
shutil.rmtree(parts_dir)
# Create fresh archive directory
os.makedirs(parts_dir)
# Get all files
logger.info("Step 1: Collecting all files...")
all_files = get_all_files(args.directory)
if not all_files:
logger.error("No files found to archive")
os.rmdir(parts_dir)
ray.shutdown()
return 1
logger.info(f"Found {len(all_files)} files to archive")
# Distribute files simply and evenly
logger.info("Step 2: Distributing files into parts...")
file_chunks = distribute_files_simple(all_files, num_parts=args.jobs)
logger.info(f"Split {len(all_files)} files into {len(file_chunks)} parts")
# Verify no files were lost during distribution
total_files_in_chunks = sum(len(chunk) for chunk in file_chunks)
if total_files_in_chunks != len(all_files):
logger.error(f"CRITICAL: File count mismatch during distribution! Expected {len(all_files)}, got {total_files_in_chunks}")
ray.shutdown()
return 1
# Create partial archives
logger.info("Step 3: Creating partial uncompressed archives...")
file_chunks_with_paths = [
(chunk, os.path.join(parts_dir, f"part_{i}.tar")) for i, chunk in enumerate(file_chunks)
]
# Split into batches for Ray processing
batch_size = max(1, len(file_chunks_with_paths) // args.jobs)
batches = [file_chunks_with_paths[i:i + batch_size] for i in range(0, len(file_chunks_with_paths), batch_size)]
# Create Ray tasks
tasks = [create_tar_parts_batch.remote(batch, args.directory) for batch in batches]
# Wait for all tasks to complete
try:
all_results = ray.get(tasks)
results = [r for batch in all_results for r in batch]
except Exception as e:
logger.error(f"Ray task execution failed: {e}")
ray.shutdown()
return 1
if not all(results):
logger.error("Some parts failed to create")
failed_count = sum(1 for r in results if not r)
logger.error(f"Failed to create {failed_count} out of {len(results)} tar parts")
ray.shutdown()
return 1
# Generate manifest for integrity verification
logger.info("Step 4: Generating manifest for integrity verification...")
manifest = {
"total_files": len(all_files),
"files": all_files,
"tar_parts": [f"part_{i}.tar" for i in range(len(file_chunks))],
"timestamp": datetime.datetime.now().isoformat(),
"source_directory": args.directory
}
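    # Illustrative shape of manifest.json (values are hypothetical):
    # {
    #   "total_files": 3,
    #   "files": ["a.txt", "sub/b.bin", "sub/c.txt"],
    #   "tar_parts": ["part_0.tar", "part_1.tar"],
    #   "timestamp": "2025-01-01T12:00:00",
    #   "source_directory": "/data/dataset"
    # }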
manifest_path = os.path.join(parts_dir, "manifest.json")
try:
with open(manifest_path, "w") as f:
json.dump(manifest, f, indent=2)
logger.info(f"Manifest written to {manifest_path}")
except Exception as e:
logger.error(f"Failed to write manifest: {e}")
ray.shutdown()
return 1
# Verify integrity by checking tar contents using Ray
logger.info("Step 5: Verifying archive integrity with Ray parallel processing...")
# Create Ray tasks for each tar file verification
tar_paths = [os.path.join(parts_dir, f"part_{i}.tar") for i in range(len(file_chunks))]
verify_tasks = [get_tar_members_remote.remote(tar_path) for tar_path in tar_paths]
# Wait for all verification tasks to complete
try:
logger.info(f"Running {len(verify_tasks)} parallel verification tasks...")
start_verify = time.time()
all_tar_members = ray.get(verify_tasks)
verify_time = time.time() - start_verify
logger.info(f"Verification completed in {verify_time:.2f} seconds")
except Exception as e:
logger.error(f"Ray verification tasks failed: {e}")
ray.shutdown()
return 1
# Aggregate all archived files
archived_files = set()
for i, members in enumerate(all_tar_members):
if not members:
logger.error(f"Failed to verify part_{i}.tar - no members returned")
ray.shutdown()
return 1
archived_files.update(members)
logger.debug(f"part_{i}.tar contains {len(members)} files")
logger.info(f"Total unique files archived: {len(archived_files)}")
# Check for missing files
missing_files = set(all_files) - archived_files
if missing_files:
logger.error(f"CRITICAL: {len(missing_files)} files were not archived!")
for f in list(missing_files)[:10]:
logger.error(f" Missing: {f}")
if len(missing_files) > 10:
logger.error(f" ... and {len(missing_files) - 10} more files")
ray.shutdown()
return 1
logger.success(f"Successfully verified all {len(all_files)} files in archive")
print(f"Done. Final archive is {args.directory}/.archive")
# Clean shutdown
ray.shutdown()
return 0
if __name__ == "__main__":
sys.exit(main())
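# Example invocation of the compression script above (a sketch; the file name
# ray_compress.py is assumed, the gist only names the snippet "ray_compress"):
#   python ray_compress.py --directory /data/dataset --jobs 16 --log-dir logs
# Partial archives (part_*.tar) and manifest.json are written to <directory>/.archive.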
#!/usr/bin/env python3
import os
import sys
import shutil
import tarfile
import argparse
import time
import json
from typing import List, Tuple, Optional
from loguru import logger
import ray
def setup_logging(log_dir: str = "logs"):
"""Setup logging with loguru to both file and console."""
# Remove default logger
logger.remove()
# Create logs directory
os.makedirs(log_dir, exist_ok=True)
# Add console handler with INFO level
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
level="INFO"
)
# File handler with DEBUG level
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
pid = os.getpid()
log_filename = f"decompress_{timestamp}_pid{pid}.log"
log_path = os.path.join(log_dir, log_filename)
logger.add(
log_path,
format="{time:YYYY-MM-DD HH:mm:ss} | PID:{process} | {level: <8} | {message}",
level="DEBUG",
rotation="100 MB",
retention="7 days"
)
logger.info(f"Logging initialized. Process PID: {pid}")
logger.info(f"Detailed logs will be written to: {log_path}")
# ---------- Path safety utilities ----------
def _safe_join(root: str, name: str) -> str:
"""Join and normalize ensuring the path stays within root."""
root = os.path.abspath(root)
# Reject absolute and null bytes early
if os.path.isabs(name) or ("\x00" in name):
raise ValueError(f"Unsafe path (absolute or contains NUL): {name}")
# Normalize against root
dest = os.path.normpath(os.path.join(root, name))
if not (dest == root or dest.startswith(root + os.sep)):
raise ValueError(f"Unsafe path (escapes root): {name}")
return dest
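# Illustrative behaviour of _safe_join (paths are hypothetical):
#   _safe_join("/out", "sub/a.txt")      -> "/out/sub/a.txt"
#   _safe_join("/out", "../etc/passwd")  -> raises ValueError (escapes root)
#   _safe_join("/out", "/etc/passwd")    -> raises ValueError (absolute path)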
def _is_supported_member(m: tarfile.TarInfo) -> bool:
    """Supported member types: regular files, directories, symlinks, and hard links."""
    return m.isreg() or m.isdir() or m.issym() or m.islnk()
# ---------- Extraction core (tarfile-based) ----------
def _prescan_members(tar_path: str, output_dir: str) -> List[tarfile.TarInfo]:
    """
    Pre-scan tar members:
      - ensure names are safe
      - only keep supported types (regular files, directories, symlinks, hard links)
    Return the filtered, sorted (dirs first) member list.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
members: List[tarfile.TarInfo] = []
try:
with tarfile.open(tar_path, mode="r:*") as tf:
for m in tf.getmembers():
if not _is_supported_member(m):
logger.info(f"[{tar_path}] Skipping unsupported member type: {m.name} ({m.type})")
continue
# Will raise on unsafe
_ = _safe_join(output_dir, m.name)
members.append(m)
except Exception as e:
raise RuntimeError(f"Pre-scan failed for {tar_path}: {e}")
# Ensure directories are created before files
members.sort(key=lambda x: 0 if x.isdir() else 1)
return members
def _extract_member(tf: tarfile.TarFile, m: tarfile.TarInfo, output_dir: str) -> Optional[str]:
    """
    Extract a single member safely. Returns the destination path for files and links, or None for dirs.
    Performs a size check for regular files and sets mode/mtime.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
dst = _safe_join(output_dir, m.name)
if m.isdir():
os.makedirs(dst, exist_ok=True)
try:
os.chmod(dst, m.mode & 0o777)
except Exception as e:
logger.warning(f"Failed to set permissions on directory {dst}: {e}")
return None
if m.isreg():
# Ensure parent exists
os.makedirs(os.path.dirname(dst), exist_ok=True)
# Stream copy in chunks
src_f = tf.extractfile(m)
if src_f is None:
raise RuntimeError(f"extractfile() returned None for {m.name}")
# Write atomically: write to tmp then rename
tmp_dst = dst + ".partial.__tmp__"
try:
with src_f as fin, open(tmp_dst, "wb") as fout:
shutil.copyfileobj(fin, fout, length=1024 * 1024) # 1 MiB chunks
# Size verification
st = os.stat(tmp_dst)
if st.st_size != m.size:
raise RuntimeError(
f"Size mismatch for {m.name}: wrote {st.st_size} vs header {m.size}"
)
except Exception:
# Clean up temp file on any error
try:
if os.path.exists(tmp_dst):
os.remove(tmp_dst)
except Exception as cleanup_e:
logger.warning(f"Failed to clean up temp file {tmp_dst}: {cleanup_e}")
raise
# Set mode/mtime then rename into place
try:
os.chmod(tmp_dst, m.mode & 0o777)
except Exception as e:
logger.warning(f"Failed to set permissions on file {m.name}: {e}")
try:
# Set both atime and mtime to m.mtime
os.utime(tmp_dst, (m.mtime, m.mtime))
except Exception as e:
logger.warning(f"Failed to set timestamp on file {m.name}: {e}")
# Atomic replace
os.replace(tmp_dst, dst)
return dst
if m.issym():
# Create symbolic link
os.makedirs(os.path.dirname(dst), exist_ok=True)
# Remove existing file/link if it exists
if os.path.lexists(dst):
os.unlink(dst)
# Security check: warn about absolute symlink targets
# Note: User requested to follow symlinks by default, so we allow but warn
if os.path.isabs(m.linkname):
logger.warning(f"Symlink {m.name} has absolute target path: {m.linkname}")
try:
os.symlink(m.linkname, dst)
logger.debug(f"Created symbolic link: {dst} -> {m.linkname}")
return dst
except Exception as e:
logger.error(f"Failed to create symbolic link {dst} -> {m.linkname}: {e}")
raise RuntimeError(f"Symbolic link creation failed for {m.name}: {e}")
if m.islnk():
# Create hard link
os.makedirs(os.path.dirname(dst), exist_ok=True)
link_target = _safe_join(output_dir, m.linkname)
# Remove existing file/link if it exists
if os.path.lexists(dst):
os.unlink(dst)
try:
os.link(link_target, dst)
logger.debug(f"Created hard link: {dst} -> {link_target}")
return dst
except Exception as e:
logger.error(f"Failed to create hard link {dst} -> {link_target}: {e}")
raise RuntimeError(f"Hard link creation failed for {m.name}: {e}")
# Should not reach here due to _is_supported_member
return None
def extract_tar_part(tar_file: str, output_dir: str, max_retries: int = 3) -> Tuple[bool, int, int]:
    """
    Extract a single tar safely using tarfile with pre-scan and per-file verification.
    Returns (success, expected_files, extracted_files) for validation.
    Retries on transient errors.
    """
    # Import logger locally to avoid pickle issues
    from loguru import logger
os.makedirs(output_dir, exist_ok=True)
# Pre-scan for safety and to decide the plan
try:
members = _prescan_members(tar_file, output_dir)
except Exception as e:
logger.error(str(e))
return False, 0, 0
if not members:
logger.info(f"[{tar_file}] No extractable members (after filtering).")
return True, 0, 0
expected_files = len(members)
extracted_files = 0
for attempt in range(1, max_retries + 1):
extracted_files = 0
try:
with tarfile.open(tar_file, mode="r:*") as tf:
# Build a dict for fast lookup if needed
names = {m.name: m for m in members}
# Iterate exactly our filtered list to keep ordering (dirs first)
for m in members:
result = _extract_member(tf, names[m.name], output_dir)
if result is not None: # Successfully extracted (file/link)
extracted_files += 1
elif m.isdir(): # Directory creation counts as success
extracted_files += 1
logger.info(f"[{os.path.basename(tar_file)}] Extracted {extracted_files}/{expected_files} items")
return True, expected_files, extracted_files
except Exception as e:
logger.error(f"[{tar_file}] Extract attempt {attempt}/{max_retries} failed: {e}")
if attempt < max_retries:
time.sleep(min(2 ** attempt, 10))
else:
logger.error(f"[{tar_file}] All {max_retries} attempts failed")
return False, expected_files, extracted_files
# ---------- Ray integration ----------
@ray.remote
def extract_tar_parts_batch(tar_files: List[str], output_dir: str) -> List[Tuple[bool, int, int]]:
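    """Ray remote function to extract multiple tar archives."""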
# Re-setup logging in Ray worker
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="INFO")
results = []
for tar_file in tar_files:
result = extract_tar_part(tar_file, output_dir)
results.append(result)
return results
# ---------- Cleanup utilities ----------
@ray.remote
def cleanup_temp_files_remote(directory: str, subdirs: List[str]) -> int:
"""Ray remote function to clean up temp files in specific subdirectories."""
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="INFO")
count = 0
for subdir in subdirs:
search_dir = os.path.join(directory, subdir) if subdir else directory
if not os.path.exists(search_dir):
continue
try:
for root, dirs, files in os.walk(search_dir):
for file in files:
if file.endswith('.partial.__tmp__'):
temp_file = os.path.join(root, file)
try:
os.unlink(temp_file)
logger.info(f"Cleaned up leftover temp file: {temp_file}")
count += 1
except Exception as e:
logger.warning(f"Failed to clean temp file {temp_file}: {e}")
except Exception as e:
logger.warning(f"Failed to scan for temp files in {search_dir}: {e}")
return count
def cleanup_temp_files(directory: str) -> int:
    """Clean up any leftover .partial.__tmp__ files from previous runs."""
    # Import logger locally to avoid pickle issues
    from loguru import logger
count = 0
try:
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.partial.__tmp__'):
temp_file = os.path.join(root, file)
try:
os.unlink(temp_file)
logger.info(f"Cleaned up leftover temp file: {temp_file}")
count += 1
except Exception as e:
logger.warning(f"Failed to clean temp file {temp_file}: {e}")
except Exception as e:
logger.warning(f"Failed to scan for temp files in {directory}: {e}")
return count
# ---------- Verification utilities ----------
@ray.remote
def get_tar_members_for_verification(tar_file: str) -> set:
"""Ray remote function to get file members from a tar file for verification."""
from loguru import logger
logger.remove()
logger.add(sys.stderr, level="INFO")
extracted_files = set()
try:
with tarfile.open(tar_file, mode="r:*") as tf:
for member in tf.getmembers():
if member.isfile() or member.issym():
extracted_files.add(member.name)
logger.info(f"Found {len(extracted_files)} files in {os.path.basename(tar_file)}")
except Exception as e:
logger.error(f"Failed to read tar file {tar_file}: {e}")
raise
return extracted_files
# ---------- CLI & orchestration ----------
def _discover_tar_parts(input_path: str) -> List[str]:
    """Discover tar part files with flexible naming patterns."""
    # Import logger locally to avoid pickle issues
    from loguru import logger
    search_dir = input_path
    logger.info(f"Searching for tar parts in: {search_dir}")
# Multiple patterns to support different naming conventions
patterns = [
(lambda f: f.startswith("part_") and f.endswith(".tar"), "part_*.tar"),
(lambda f: f.endswith(".tar"), "*.tar"),
]
tar_files = []
for pattern_func, pattern_name in patterns:
found = []
try:
for f in os.listdir(search_dir):
if pattern_func(f) and os.path.isfile(os.path.join(search_dir, f)):
found.append(os.path.join(search_dir, f))
except Exception as e:
logger.warning(f"Failed to list directory {search_dir}: {e}")
continue
if found:
tar_files = sorted(found)
logger.info(f"Found {len(tar_files)} files matching pattern {pattern_name}")
break
if not tar_files:
logger.warning(f"No tar files found in {search_dir}")
# Try to list what's actually there
try:
all_files = os.listdir(search_dir)
logger.info(f"Directory contains {len(all_files)} items: {all_files[:10]}...")
except Exception:
pass
return tar_files
def main() -> int:
parser = argparse.ArgumentParser(
description="Decompress directory with tar parts into target folder (safe tarfile version)"
)
parser.add_argument("input", type=str, help="Input directory containing tar parts")
parser.add_argument("output", type=str, help="Output directory where files will be extracted")
parser.add_argument(
"--jobs", "-j", type=int, default=None,
help="Number of parallel jobs for Ray batching"
)
parser.add_argument(
"--log-dir", type=str, default="logs",
help="Directory for log files (default: logs)"
)
args = parser.parse_args()
# Setup logging
setup_logging(args.log_dir)
# Initialize Ray
try:
ray.init(runtime_env={"pip": ["loguru"]})
logger.info("Ray initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Ray: {e}")
return 1
if args.jobs is None:
args.jobs = os.cpu_count() or 4
input_path = os.path.abspath(args.input)
output_dir = os.path.abspath(args.output)
    if not os.path.isdir(input_path):
        logger.error("Input must be a directory containing tar parts")
        ray.shutdown()
        return 1
os.makedirs(output_dir, exist_ok=True)
# Clean up any leftover temporary files from previous runs
temp_cleaned = cleanup_temp_files(output_dir)
if temp_cleaned > 0:
logger.info(f"Cleaned up {temp_cleaned} leftover temporary files")
tar_files = _discover_tar_parts(input_path)
if not tar_files:
logger.error(f"No tar part files found in {input_path}")
ray.shutdown()
return 1
# Check for manifest file and verify if it exists
manifest_path = os.path.join(input_path, "manifest.json")
expected_files = None
if os.path.exists(manifest_path):
logger.info(f"Found manifest file: {manifest_path}")
try:
with open(manifest_path, "r") as f:
manifest = json.load(f)
expected_files = set(manifest.get("files", []))
logger.info(f"Manifest indicates {len(expected_files)} files should be extracted")
except Exception as e:
logger.warning(f"Failed to read manifest file: {e}")
# Continue without manifest verification
logger.info(f"Starting decompression with Ray parallel processing")
logger.info(f"Found {len(tar_files)} tar parts to extract")
logger.info("Extracting tar parts...")
# Use Ray for parallel processing
files_per_task = max(1, (len(tar_files) + args.jobs - 1) // args.jobs)
batches = [tar_files[i:i + files_per_task] for i in range(0, len(tar_files), files_per_task)]
tasks = [extract_tar_parts_batch.remote(b, output_dir) for b in batches]
all_results = ray.get(tasks)
results = [r for batch in all_results for r in batch]
# Analyze results and provide detailed statistics
successes = [r for r in results if r[0]]
failures = [r for r in results if not r[0]]
total_expected = sum(r[1] for r in results)
total_extracted = sum(r[2] for r in results)
logger.info(f"Extraction summary:")
logger.info(f" Total tar files: {len(results)}")
logger.info(f" Successful: {len(successes)}")
logger.info(f" Failed: {len(failures)}")
logger.info(f" Total files expected: {total_expected}")
logger.info(f" Total files extracted: {total_extracted}")
    if failures:
        logger.error(f"Failed to extract {len(failures)} tar files:")
        shown = 0
        # Walk tar_files and results in parallel so each failure is reported
        # against the tar file it actually came from.
        for tar_file, (success, expected, extracted) in zip(tar_files, results):
            if not success and shown < 5:  # Show first 5 failures
                logger.error(f"  {tar_file}: {extracted}/{expected} files")
                shown += 1
        if len(failures) > 5:
            logger.error(f"  ... and {len(failures) - 5} more failures")
        ray.shutdown()
        return 1
# Verify against manifest if available
if expected_files is not None:
logger.info("Verifying extracted files against manifest using Ray...")
# Use Ray to collect all extracted files in parallel
verify_start = time.time()
verify_tasks = [get_tar_members_for_verification.remote(tar_file) for tar_file in tar_files]
try:
logger.info(f"Running {len(verify_tasks)} parallel verification tasks...")
all_file_sets = ray.get(verify_tasks)
verify_time = time.time() - verify_start
logger.info(f"Verification completed in {verify_time:.2f} seconds")
except Exception as e:
logger.error(f"Failed to verify tar files: {e}")
ray.shutdown()
return 1
# Merge all file sets
extracted_files = set()
for file_set in all_file_sets:
extracted_files.update(file_set)
logger.info(f"Total unique files found: {len(extracted_files)}")
# Check for missing files
missing_files = expected_files - extracted_files
if missing_files:
logger.error(f"CRITICAL: {len(missing_files)} files in manifest were not extracted!")
for f in list(missing_files)[:10]:
logger.error(f" Missing: {f}")
if len(missing_files) > 10:
logger.error(f" ... and {len(missing_files) - 10} more files")
ray.shutdown()
return 1
# Check for extra files (less critical but worth noting)
extra_files = extracted_files - expected_files
if extra_files:
logger.warning(f"Found {len(extra_files)} files not in manifest (could be directories or new files)")
for f in list(extra_files)[:5]:
logger.warning(f" Extra: {f}")
logger.success("Manifest verification passed - all expected files were extracted")
if total_extracted != total_expected:
logger.error(f"CRITICAL: File count mismatch: expected {total_expected}, extracted {total_extracted}")
ray.shutdown()
return 1 # This is a critical failure - data integrity issue
logger.success(f"Successfully extracted all parts to {output_dir}")
# Clean shutdown
ray.shutdown()
return 0
if __name__ == "__main__":
sys.exit(main())
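# Example invocation of the decompression script above (a sketch; the file name
# ray_decompress.py is assumed, it is not given in the gist):
#   python ray_decompress.py /data/dataset/.archive /restore/dataset --jobs 16
# The input directory must contain the part_*.tar files (and optionally the
# manifest.json) produced by the compression script.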