Convert the BridgeV2 TFDS dataset to a LeRobot-compatible dataset.
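Example invocation (a sketch; the script filename and the repo id placeholder are illustrative, not part of the gist):

    python port_bridgev2_to_lerobot.py --raw-dir ./BridgeV2/rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/ --repo-id <hf-username>/bridgev2 --push-to-hub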
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import time
from pathlib import Path

import numpy as np
import tensorflow_datasets as tfds
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

BRIDGEV2_SHARDS = 1024
BRIDGEV2_FPS = 5
BRIDGEV2_ROBOT_TYPE = "WidowX"

# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
BRIDGEV2_FEATURES = {
"has_language": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"has_img_0": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"has_img_1": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"has_img_2": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"has_img_3": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on first step of the episode
"is_first": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode
"is_last": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode if it is a terminal step, True for demos
"is_terminal": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# language_instruction is also stored as "task" to follow LeRobot standard
"language_instruction": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"observation.state": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
},
},
"observation.images.image_0": {
"dtype": "image",
"shape": (256, 256, 3),
"names": [
"height",
"width",
"channels",
],
},
"observation.images.image_1": {
"dtype": "image",
"shape": (256, 256, 3),
"names": [
"height",
"width",
"channels",
],
},
"observation.images.image_2": {
"dtype": "image",
"shape": (256, 256, 3),
"names": [
"height",
"width",
"channels",
],
},
"observation.images.image_3": {
"dtype": "image",
"shape": (256, 256, 3),
"names": [
"height",
"width",
"channels",
],
},
"action": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
},
},
"discount": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
}


def generate_lerobot_frames(tf_episode):
    """Yield LeRobot-compatible frames for every step of a BridgeV2 TFDS episode."""
    m = tf_episode["episode_metadata"]
    frame_meta = {
        "has_img_0": np.array([m["has_image_0"].numpy()]),
        "has_img_1": np.array([m["has_image_1"].numpy()]),
        "has_img_2": np.array([m["has_image_2"].numpy()]),
        "has_img_3": np.array([m["has_image_3"].numpy()]),
        "has_language": np.array([m["has_language"].numpy()]),
    }
    for f in tf_episode["steps"]:
        frame = {
            "is_first": np.array([f["is_first"].numpy()]),
            "is_last": np.array([f["is_last"].numpy()]),
            "is_terminal": np.array([f["is_terminal"].numpy()]),
            "language_instruction": f["language_instruction"].numpy().decode(),
            "observation.state": f["observation"]["state"].numpy(),
            "observation.images.image_0": f["observation"]["image_0"].numpy(),
            "observation.images.image_1": f["observation"]["image_1"].numpy(),
            "observation.images.image_2": f["observation"]["image_2"].numpy(),
            "observation.images.image_3": f["observation"]["image_3"].numpy(),
            "discount": np.array([f["discount"].numpy()]),
            "reward": np.array([f["reward"].numpy()]),
            "action": f["action"].numpy(),
        }
        # language_instruction is also stored as "task" to follow the LeRobot standard
        frame["task"] = frame["language_instruction"]
        # Metadata that is the same for all frames in the episode
        frame.update(frame_meta)
        # Cast fp64 to fp32
        for key in frame:
            if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
                frame[key] = frame[key].astype(np.float32)
        yield frame


def port_bridgev2(
    raw_dir: Path,
    repo_id: str,
    push_to_hub: bool = False,
):
    """Convert the BridgeV2 TFDS dataset into a LeRobotDataset and optionally push it to the Hub."""
    # NOTE: .take(100) limits the conversion to the first 100 episodes of the train split.
    raw_dataset = tfds.load("bridge_dataset", data_dir=raw_dir, split="train").take(100)
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=BRIDGEV2_ROBOT_TYPE,
        fps=BRIDGEV2_FPS,
        features=BRIDGEV2_FEATURES,
    )

    start_time = time.time()
    num_episodes = raw_dataset.cardinality().numpy().item()
    logging.info(f"Number of episodes: {num_episodes}")
    for episode_index, episode in enumerate(raw_dataset):
        elapsed_time = time.time() - start_time
        logging.info(
            f"{episode_index} / {num_episodes} episodes processed (after {elapsed_time:.3f} seconds)"
        )
        for frame in generate_lerobot_frames(episode):
            lerobot_dataset.add_frame(frame)
        lerobot_dataset.save_episode()
        logging.info("Episode saved")

    if push_to_hub:
        lerobot_dataset.push_to_hub(
            # Add the openx tag, since it belongs to the OpenX collection of datasets
            tags=["openx", "bridgev2"],
            private=False,
        )


def main():
    # Configure logging so the progress messages above are actually emitted.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--raw-dir",
        type=Path,
        help="Directory containing the raw input dataset (e.g. `path/to/dataset` or `path/to/dataset/version`).",
        default="./BridgeV2/rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or user name, a `/`, then the dataset name. Required when --push-to-hub is set.",
        default="",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Upload the converted dataset to the Hugging Face Hub.",
        default=False,
    )
    args = parser.parse_args()
    port_bridgev2(**vars(args))


if __name__ == "__main__":
    main()
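

# Minimal sanity check after conversion (a sketch, not part of the original script):
# it assumes the conversion above completed for the same repo id; the
# "your-hf-username/bridgev2" value below is a placeholder.
#
#   from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
#
#   dataset = LeRobotDataset("your-hf-username/bridgev2")
#   print(len(dataset))        # number of frames
#   print(dataset[0].keys())   # feature names of a single frame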