Created
January 8, 2026 09:06
-
-
Save Qu3tzal/bdfd34892292877351726ed2204cb45e to your computer and use it in GitHub Desktop.
Convert BridgeV2 TFDS dataset to LeRobot compatible dataset.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import logging | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import tensorflow_datasets as tfds | |
| from lerobot.common.datasets.lerobot_dataset import LeRobotDataset | |
# Number of shards the raw BridgeV2 TFDS dataset is split into.
BRIDGEV2_SHARDS = 1024
# Recording frequency of BridgeV2 trajectories, in frames per second.
BRIDGEV2_FPS = 5
# Robot platform used to collect the BridgeV2 data.
BRIDGEV2_ROBOT_TYPE = "WidowX"
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
# Maps LeRobot feature names to their dtype/shape/axis-name metadata, as
# required by LeRobotDataset.create(features=...).
BRIDGEV2_FEATURES = {
    # Episode-level flags (repeated on every frame): whether the episode has
    # a language annotation and which of the four camera streams are present.
    "has_language": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "has_img_0": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "has_img_1": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "has_img_2": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "has_img_3": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # true on first step of the episode
    "is_first": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # true on last step of the episode
    "is_last": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # true on last step of the episode if it is a terminal step, True for demos
    "is_terminal": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # language_instruction is also stored as "task" to follow LeRobot standard
    "language_instruction": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    # End-effector pose + gripper state (7-DoF).
    "observation.state": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        },
    },
    # Four fixed 256x256 RGB camera streams.
    "observation.images.image_0": {
        "dtype": "image",
        "shape": (256, 256, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    "observation.images.image_1": {
        "dtype": "image",
        "shape": (256, 256, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    "observation.images.image_2": {
        "dtype": "image",
        "shape": (256, 256, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    "observation.images.image_3": {
        "dtype": "image",
        "shape": (256, 256, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    # Delta end-effector action + gripper command (7-DoF), same axis layout
    # as observation.state.
    "action": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        },
    },
    "discount": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    "reward": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
}
def generate_lerobot_frames(tf_episode):
    """Yield LeRobot-format frame dicts for every step of one TFDS BridgeV2 episode.

    Args:
        tf_episode: One episode from the ``bridge_dataset`` TFDS split: a dict
            with "episode_metadata" and an iterable "steps", whose leaf values
            are TF tensors read out via ``.numpy()``.

    Yields:
        dict: One frame per step, keyed as in BRIDGEV2_FEATURES (plus "task"),
        with episode-level metadata repeated on every frame and all float64
        numpy arrays down-cast to float32.
    """
    metadata = tf_episode["episode_metadata"]
    # Episode-level flags are constant across the episode; extract them once.
    frame_meta = {
        f"has_img_{cam}": np.array([metadata[f"has_image_{cam}"].numpy()])
        for cam in range(4)
    }
    frame_meta["has_language"] = np.array([metadata["has_language"].numpy()])
    for step in tf_episode["steps"]:
        frame = {
            "is_first": np.array([step["is_first"].numpy()]),
            "is_last": np.array([step["is_last"].numpy()]),
            "is_terminal": np.array([step["is_terminal"].numpy()]),
            "language_instruction": step["language_instruction"].numpy().decode(),
            "observation.state": step["observation"]["state"].numpy(),
            "discount": np.array([step["discount"].numpy()]),
            "reward": np.array([step["reward"].numpy()]),
            "action": step["action"].numpy(),
        }
        # The four camera streams follow the same naming pattern.
        for cam in range(4):
            frame[f"observation.images.image_{cam}"] = step["observation"][f"image_{cam}"].numpy()
        # language_instruction is also stored as "task" to follow LeRobot standard
        frame["task"] = frame["language_instruction"]
        # Meta data that are the same for all frames in the episode
        frame.update(frame_meta)
        # Cast fp64 to fp32 (LeRobot features declare float32)
        for key, value in frame.items():
            if isinstance(value, np.ndarray) and value.dtype == np.float64:
                frame[key] = value.astype(np.float32)
        yield frame
def port_bridgev2(
    raw_dir: Path,
    repo_id: str,
    push_to_hub: bool = False,
    num_episodes: int | None = 100,
):
    """Convert the raw BridgeV2 TFDS dataset into a LeRobot dataset.

    Args:
        raw_dir: Directory containing the raw ``bridge_dataset`` TFDS data.
        repo_id: Hugging Face repository identifier ("user/name") for the
            converted dataset.
        push_to_hub: If True, upload the converted dataset to the Hub.
        num_episodes: Number of episodes to convert, or None to convert the
            whole "train" split. Defaults to 100, which preserves the
            previous hard-coded cap.
    """
    raw_dataset = tfds.load("bridge_dataset", data_dir=raw_dir, split="train")
    # The original script always did .take(100); keep that as the default but
    # allow converting the full split.
    if num_episodes is not None:
        raw_dataset = raw_dataset.take(num_episodes)
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=BRIDGEV2_ROBOT_TYPE,
        fps=BRIDGEV2_FPS,
        features=BRIDGEV2_FEATURES,
    )
    start_time = time.time()
    total_episodes = raw_dataset.cardinality().numpy().item()
    logging.info(f"Number of episodes {total_episodes}")
    for episode_index, episode in enumerate(raw_dataset):
        elapsed_time = time.time() - start_time
        logging.info(
            f"{episode_index} / {total_episodes} episodes processed (after {elapsed_time:.3f} seconds)"
        )
        for frame in generate_lerobot_frames(episode):
            lerobot_dataset.add_frame(frame)
        # Flush the accumulated frames to disk before starting the next episode.
        lerobot_dataset.save_episode()
        logging.info("Save_episode")
    if push_to_hub:
        lerobot_dataset.push_to_hub(
            # Add openx tag, since it belongs to the openx collection of datasets
            tags=["openx", "bridgev2"],
            private=False,
        )
def main():
    """CLI entry point: parse arguments and run the BridgeV2 -> LeRobot port."""
    # Without this, the logging.info progress messages in port_bridgev2 are
    # dropped (root logger defaults to WARNING).
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--raw-dir",
        type=Path,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version`).",
        default="./BridgeV2/rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
        default="",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Upload to hub.",
        default=False,
    )
    args = parser.parse_args()
    # Argument dest names (raw_dir, repo_id, push_to_hub) match the
    # port_bridgev2 parameter names, so they can be splatted directly.
    port_bridgev2(**vars(args))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment