| """ | |
| FIVER Pilot v8.3: NumPy 2.0 Compatible | |
| Target: 1 Hour | Output: Clean Imitation Learning Data | |
| Fixes: Updates type checking for modern NumPy versions. | |
| source: https://huggingface.co/datasets/sanskxr02/fidelity/commit/61102197dda8967021e7b18d11625d532e4d2c07 | |
| """ | |
import os
import cv2
import json
import uuid
import time
import glob  # needed for the video discovery in main()
import tarfile
import torch
import gc
import shutil
import math
import numpy as np
from dataclasses import dataclass, asdict
from typing import List, Optional, Dict, Tuple, Any
from tqdm import tqdm

# --- LIBRARIES ---
from huggingface_hub import login, hf_hub_download
from ultralytics import YOLO
from scipy.spatial.distance import euclidean
# --- CONFIGURATION ---
CONFIG = {
    "target_duration_sec": 3600.0,   # 1 hour of processed footage
    "data_dir": "./fiver_temp_data",
    "device": "cpu",
    "batch_size": 4,
    "sample_fps": 3,                 # frames analysed per second
    "native_fps": 30,                # assumed FPS of the source videos
    "checkpoint_file": "fiver_v8_checkpoint.json",
    "output_file": "fiver_v8_clean_output.json",
    # HEURISTICS (distances in normalized image coordinates, times in seconds)
    "interaction_dist": 0.15,  # max wrist-to-object distance to count as contact
    "merge_gap": 0.5,          # events closer than this merge into one segment
    "min_duration": 0.3,       # segments shorter than this are discarded
    "motion_thresh": 0.05      # displacement above this labels a segment "move"
}
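# Sampling note: with native_fps=30 and sample_fps=3, the trackers below use a
# vid_stride of 30 // 3 = 10, i.e. every 10th frame of each clip is analysed.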
# --- UTILITIES: RECURSIVE SANITIZER (NUMPY 2.0 FIX) ---
def recursive_sanitize(obj):
    """
    Recursively walks through lists/dicts and converts
    numpy types to native python types.
    Compatible with NumPy 2.0+.
    """
    if isinstance(obj, dict):
        return {k: recursive_sanitize(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [recursive_sanitize(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(recursive_sanitize(v) for v in obj)
    elif isinstance(obj, (bool, np.bool_)):  # check bools first: bool is a subclass of int
        return bool(obj)
    elif isinstance(obj, (np.integer, int)):  # catch all numpy ints
        return int(obj)
    elif isinstance(obj, (np.floating, float)):  # catch all numpy floats
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return recursive_sanitize(obj.tolist())
    else:
        return obj
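# Example with hypothetical values:
#   recursive_sanitize({"t": np.float32(1.5), "ids": np.array([1, 2])})
#   -> {"t": 1.5, "ids": [1, 2]}   (every leaf is now JSON-serializable)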
# --- DATA MODELS ---
@dataclass
class RawEvent:
    timestamp: float
    hand: str
    object_id: int
    wrist_pos: Tuple[float, float]
    obj_pos: Tuple[float, float]
    obj_cat: int
    raw_conf: float

@dataclass
class CleanAction:
    action_id: str
    action_type: str
    start_time: float
    end_time: float
    hand: str
    object_id: Optional[int]
    confidence: float
    start_pos: Optional[Tuple[float, float]] = None
    end_pos: Optional[Tuple[float, float]] = None

@dataclass
class StateTransition:
    action_id: str
    object_id: int
    pre_state: Dict[str, Any]
    post_state: Dict[str, Any]
# --- CHECKPOINT MANAGER ---
class CheckpointManager:
    def __init__(self, filepath):
        self.filepath = filepath
        self.data = {
            "processed_keys": [],
            "total_processed_seconds": 0.0,
            "actions": [],
            "state_transitions": [],
            "objects": {}
        }
        self.load()

    def load(self):
        if os.path.exists(self.filepath):
            try:
                with open(self.filepath, 'r') as f:
                    self.data.update(json.load(f))
            except (json.JSONDecodeError, OSError):
                pass  # corrupt/unreadable checkpoint: start fresh

    def save(self):
        # NUCLEAR OPTION: sanitize everything before dumping, then write
        # atomically (temp file + os.replace) so a crash can't corrupt the file.
        clean_data = recursive_sanitize(self.data)
        with open(self.filepath + ".tmp", 'w') as f:
            json.dump(clean_data, f, indent=2)
        os.replace(self.filepath + ".tmp", self.filepath)

    def is_processed(self, key):
        return key in self.data["processed_keys"]

    def add_result(self, key, duration, actions, transitions, objects):
        if key not in self.data["processed_keys"]:
            self.data["processed_keys"].append(key)
            self.data["total_processed_seconds"] += duration
            self.data["actions"].extend([asdict(a) for a in actions])
            self.data["state_transitions"].extend([asdict(t) for t in transitions])
            for k, v in objects.items():
                self.data["objects"][str(k)] = v
            self.save()
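# Resume sketch (hypothetical filename/key): is_processed() drives skipping.
#   cp = CheckpointManager("demo_checkpoint.json")
#   if not cp.is_processed("clip_0001.mp4"):
#       cp.add_result("clip_0001.mp4", 12.5, [], [], {})  # persists immediately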
# --- POST-PROCESSING LOGIC ---
class TrajectoryRefiner:
    def __init__(self):
        pass

    def get_grid_loc(self, pos: Tuple[float, float]) -> str:
        if not pos:
            return "unknown"
        # Ensure we are working with standard floats
        x, y = float(pos[0]), float(pos[1])
        c = min(int(x * 3), 2)
        r = min(int(y * 3), 2)
        return f"loc_{r}_{c}"
    def compute_confidence(self, duration, dist_moved):
        dur_score = min(1.0, duration / 2.0)
        motion_score = 1.0 if dist_moved > CONFIG["motion_thresh"] else 0.5
        conf = 0.3 + (0.4 * dur_score) + (0.2 * motion_score)
        return float(round(max(0.3, min(0.95, conf)), 2))
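    # Worked example: duration=1.0s, dist_moved=0.10 (> motion_thresh 0.05)
    # -> dur_score=0.5, motion_score=1.0, conf = 0.3 + 0.4*0.5 + 0.2*1.0 = 0.7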
    def process(self, raw_events: List[RawEvent]) -> Tuple[List[CleanAction], List[StateTransition]]:
        if not raw_events:
            return [], []
        # Sort by time
        raw_events.sort(key=lambda x: x.timestamp)
        hand_timelines = {"left": [], "right": []}
        for e in raw_events:
            hand_timelines[e.hand].append(e)
        clean_actions = []
        transitions = []
        for hand, events in hand_timelines.items():
            if not events:
                continue
            # FIX 1: Temporal Collapsing
            merged_segments = []
            current_seg = {
                "obj": events[0].object_id,
                "start": events[0].timestamp,
                "end": events[0].timestamp,
                "start_pos": events[0].obj_pos,
                "end_pos": events[0].obj_pos,
                "count": 1
            }
            for i in range(1, len(events)):
                e = events[i]
                gap = e.timestamp - current_seg["end"]
                if e.object_id == current_seg["obj"] and gap < CONFIG["merge_gap"]:
                    current_seg["end"] = e.timestamp
                    current_seg["end_pos"] = e.obj_pos
                    current_seg["count"] += 1
                else:
                    merged_segments.append(current_seg)
                    current_seg = {
                        "obj": e.object_id,
                        "start": e.timestamp,
                        "end": e.timestamp,
                        "start_pos": e.obj_pos,
                        "end_pos": e.obj_pos,
                        "count": 1
                    }
            merged_segments.append(current_seg)
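            # e.g. with merge_gap=0.5: contacts with object 7 at t=1.0, 1.2, 1.5
            # collapse into one segment (start=1.0, end=1.5, count=3); a gap of
            # 0.5s or more, or an object switch, starts a new segment.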
| # FIX 2 & 3: Relabeling & Grammar | |
| for seg in merged_segments: | |
| duration = seg["end"] - seg["start"] | |
| if duration < CONFIG["min_duration"]: continue | |
| dist = euclidean(seg["start_pos"], seg["end_pos"]) | |
| core_type = "move" if dist > CONFIG["motion_thresh"] else "grasp" | |
| conf = self.compute_confidence(duration, dist) | |
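                # Grammar: each surviving segment expands into a
                # reach -> (move|grasp) -> place triplet, so downstream
                # consumers always see a complete manipulation arc.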
                # Synthetic REACH
                reach_start = max(0.0, seg["start"] - 0.5)
                clean_actions.append(CleanAction(
                    action_id=str(uuid.uuid4())[:8],
                    action_type="reach",
                    start_time=reach_start,
                    end_time=seg["start"],
                    hand=hand,
                    object_id=seg["obj"],
                    confidence=0.8,
                    start_pos=None, end_pos=None
                ))
                # Core Action
                main_act_id = str(uuid.uuid4())[:8]
                clean_actions.append(CleanAction(
                    action_id=main_act_id,
                    action_type=core_type,
                    start_time=seg["start"],
                    end_time=seg["end"],
                    hand=hand,
                    object_id=seg["obj"],
                    confidence=conf,
                    start_pos=seg["start_pos"],
                    end_pos=seg["end_pos"]
                ))
                # Synthetic PLACE
                place_id = str(uuid.uuid4())[:8]
                place_end = seg["end"] + 0.2
                clean_actions.append(CleanAction(
                    action_id=place_id,
                    action_type="place",
                    start_time=seg["end"],
                    end_time=place_end,
                    hand=hand,
                    object_id=seg["obj"],
                    confidence=0.9,
                    start_pos=seg["end_pos"],
                    end_pos=seg["end_pos"]
                ))
                # State Transitions (Terminal)
                transitions.append(StateTransition(
                    action_id=place_id,
                    object_id=seg["obj"],
                    pre_state={"holder": hand, "location": self.get_grid_loc(seg["start_pos"])},
                    post_state={"holder": None, "location": self.get_grid_loc(seg["end_pos"])}
                ))
        return clean_actions, transitions
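# Minimal usage sketch with two hypothetical contacts 0.4s apart on object 7:
#   refiner = TrajectoryRefiner()
#   evts = [RawEvent(1.0, "right", 7, (0.5, 0.5), (0.52, 0.5), 39, 1.0),
#           RawEvent(1.4, "right", 7, (0.5, 0.5), (0.53, 0.5), 39, 1.0)]
#   actions, transitions = refiner.process(evts)
#   -> 1 reach + 1 grasp + 1 place, plus one terminal StateTransition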
# --- PERCEPTION ENGINE ---
class PerceptionEngine:
    def __init__(self):
        print(" [Engine] Loading Models...")
        self.obj_model = YOLO("yolov8n.pt")
        self.pose_model = YOLO("yolov8n-pose.pt")

    def track_clip(self, video_path, pbar):
        vid_stride = int(CONFIG["native_fps"] / CONFIG["sample_fps"])
        results = self.obj_model.track(
            source=video_path, vid_stride=vid_stride, stream=True,
            verbose=False, persist=True, tracker="bytetrack.yaml", device='cpu'
        )
        batch_frames = []
        batch_obj_res = []
        for res in results:
            batch_frames.append(res.orig_img)
            batch_obj_res.append(res)
            pbar.update(1)
            if len(batch_frames) >= CONFIG["batch_size"]:
                yield self._detect_batch(batch_frames, batch_obj_res)
                batch_frames = []
                batch_obj_res = []
        if batch_frames:
            yield self._detect_batch(batch_frames, batch_obj_res)
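    # Design note: track_clip is a generator that streams ByteTrack results and
    # yields per-frame event lists in chunks of CONFIG["batch_size"], so pose
    # inference runs batched while only a handful of frames are held in memory.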
    def _detect_batch(self, frames, obj_results):
        pose_results = self.pose_model(frames, verbose=False, device='cpu', conf=0.40)
        batch_events = []
        for obj_res, pose_res in zip(obj_results, pose_results):
            frame_events = []
            h, w = obj_res.orig_shape
            # Map Objects
            objects = {}
            if obj_res.boxes.id is not None:
                boxes = obj_res.boxes.xyxy.numpy()
                ids = obj_res.boxes.id.int().numpy()
                clss = obj_res.boxes.cls.int().numpy()
                for b, i, c in zip(boxes, ids, clss):
                    # Explicit casting to float/int
                    cx = float((b[0] + b[2]) / 2 / w)
                    cy = float((b[1] + b[3]) / 2 / h)
                    objects[int(i)] = {"cat": int(c), "pos": (cx, cy)}
            # Map Hands (bug fix: use this frame's pose_res, not pose_results[0])
            if pose_res.keypoints is not None:
                kps = pose_res.keypoints.xyn.numpy()
                if len(kps) > 0:
                    # COCO keypoint indices: 9 = left wrist, 10 = right wrist
                    for side, idx in [("left", 9), ("right", 10)]:
                        if kps[0][idx][0] == 0:
                            continue  # keypoint not detected
                        # Explicit casting
                        wrist = (float(kps[0][idx][0]), float(kps[0][idx][1]))
                        closest_id = None
                        min_dist = CONFIG["interaction_dist"]
                        for oid, odata in objects.items():
                            d = euclidean(wrist, odata["pos"])
                            if d < min_dist:
                                min_dist = d
                                closest_id = oid
                        if closest_id is not None:
                            frame_events.append({
                                "hand": side,
                                "object_id": closest_id,
                                "wrist_pos": wrist,
                                "obj_pos": objects[closest_id]["pos"],
                                "obj_cat": objects[closest_id]["cat"],
                                "conf": 1.0
                            })
            batch_events.append(frame_events)
        return batch_events
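# Association example: a wrist at (0.50, 0.60) and an object at (0.55, 0.60)
# are 0.05 apart in normalized coordinates, i.e. under interaction_dist=0.15,
# so a contact event is emitted for the nearest such object.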
# --- MAIN PIPELINE ---
def main():
    print("--- FIVER PILOT v8.3: NUMPY 2.0 COMPATIBLE ---")
    token = input("Enter Hugging Face Write Token (optional if local data exists): ").strip()
    if token:
        login(token=token)
    cp = CheckpointManager(CONFIG["checkpoint_file"])
    engine = PerceptionEngine()
    refiner = TrajectoryRefiner()

    video_queue = []
    if os.path.exists(CONFIG["data_dir"]):
        for root, _, files in os.walk(CONFIG["data_dir"]):
            for f in files:
                if f.endswith(".mp4"):
                    video_queue.append(os.path.join(root, f))
    if not video_queue and token:
        print("⬇️ Downloading Part 000...")
        try:
            tar_path = hf_hub_download(
                repo_id="builddotai/Egocentric-100K",
                filename="factory001/worker001/part000.tar",
                repo_type="dataset", local_dir=CONFIG["data_dir"]
            )
            with tarfile.open(tar_path) as tar:
                tar.extractall(path=os.path.join(CONFIG["data_dir"], "part000"))
            video_queue = glob.glob(os.path.join(CONFIG["data_dir"], "**/*.mp4"), recursive=True)
        except Exception as e:
            print(f"❌ Error: {e}")
            return

    video_queue.sort()
    pbar_global = tqdm(total=CONFIG["target_duration_sec"], initial=cp.data["total_processed_seconds"],
                       unit="sec", desc="Progress", colour="green")

    for video_path in video_queue:
        if cp.data["total_processed_seconds"] >= CONFIG["target_duration_sec"]:
            break
        key = os.path.basename(video_path)
        if cp.is_processed(key):
            continue

        # 1. RAW EXTRACTION
        raw_events_buffer = []
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        cap.release()
        stride = int(CONFIG["native_fps"] / CONFIG["sample_fps"])
        pbar_clip = tqdm(total=total_frames // stride, desc=f"🎬 {key}", unit="frm", leave=False, colour="blue")
        accumulated_time = cp.data["total_processed_seconds"]
        frame_idx = 0
        obj_registry = {}
        try:
            for batch_data in engine.track_clip(video_path, pbar_clip):
                for frame_evts in batch_data:
                    current_ts = accumulated_time + (frame_idx / CONFIG["sample_fps"])
                    for evt in frame_evts:
                        raw_events_buffer.append(RawEvent(
                            timestamp=float(current_ts),  # Explicit float
                            hand=str(evt["hand"]),
                            object_id=int(evt["object_id"]),
                            wrist_pos=evt["wrist_pos"],
                            obj_pos=evt["obj_pos"],
                            obj_cat=int(evt["obj_cat"]),
                            raw_conf=float(evt["conf"])
                        ))
                        obj_registry[evt["object_id"]] = {"id": int(evt["object_id"]), "cat": int(evt["obj_cat"])}
                    frame_idx += 1
            pbar_clip.close()
            # 2. REFINEMENT & SAVE
            clean_actions, transitions = refiner.process(raw_events_buffer)
            duration = total_frames / fps
            pbar_global.update(duration)
            # add_result calls the sanitizer internally via save()
            cp.add_result(key, duration, clean_actions, transitions, obj_registry)
        except Exception as e:
            print(f"Error on {key}: {e}")
        gc.collect()

    pbar_global.close()
    # Final Write
    with open(CONFIG["output_file"], 'w') as f:
        json.dump(recursive_sanitize(cp.data), f, indent=2)
    print(f"\n✅ DONE. Output saved to {CONFIG['output_file']}")

if __name__ == "__main__":
    main()
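# Output shape note (abridged, mirrors CheckpointManager.data):
#   {"processed_keys": [...], "total_processed_seconds": <float>,
#    "actions": [<CleanAction dicts>], "state_transitions": [<StateTransition dicts>],
#    "objects": {"<track id>": {"id": <int>, "cat": <int>}}}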