Spaces:
Running
Running
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
import traceback | |
from pathlib import Path | |
from datasets import get_dataset_config_info | |
from huggingface_hub import HfApi | |
from lerobot import available_datasets | |
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata | |
from lerobot.common.datasets.utils import INFO_PATH, write_info | |
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings | |
# Shared Hub client, used below for revision-existence checks and PR uploads.
hub_api: HfApi = HfApi()
# Local working directory; the status log file is written inside it.
LOCAL_DIR: Path = Path("data/")
def fix_dataset(repo_id: str) -> str:
    """Remove a stale "language_instruction" feature from a dataset's info.json on the V20 revision.

    The non-video features declared in the LeRobot metadata (read at revision
    ``V20``) are compared with the features reported by
    ``datasets.get_dataset_config_info`` for the "default" config at the same
    revision. If info.json declares exactly one feature missing from the parquet
    data, and that feature is "language_instruction", it is removed. The updated
    info.json is then written locally and uploaded as a pull request against
    ``V20``.

    Args:
        repo_id: Hugging Face Hub dataset repository id.

    Returns:
        A one-line status string. It reports one of three outcomes: the repo was
        skipped (no ``V20`` revision), the repo was skipped (no difference), or
        a PR was opened (includes the PR URL).

    Raises:
        ValueError: If the parquet data has features not declared in info.json,
            or if info.json declares missing features other than exactly
            {"language_instruction"}.
    """
    if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
        return f"{repo_id}: skipped (not in {V20})."

    # Read the parquet schema from the same revision as info.json, so that both
    # sides of the comparison describe the same snapshot of the repo.
    dataset_info = get_dataset_config_info(repo_id, "default", revision=V20)
    with SuppressWarnings():
        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)

    # Video features are excluded from the comparison. Presumably they are stored
    # outside the parquet columns; TODO confirm.
    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
    parquet_features = set(dataset_info.features)

    diff_parquet_meta = parquet_features - meta_features
    diff_meta_parquet = meta_features - parquet_features

    if diff_parquet_meta:
        raise ValueError(f"In parquet not in info.json: {diff_parquet_meta}")

    if not diff_meta_parquet:
        return f"{repo_id}: skipped (no diff)"

    logging.warning(f"In info.json not in parquet: {diff_meta_parquet}")
    # Explicit check rather than `assert`: asserts are stripped under `python -O`.
    # Without this check, the script would drop one feature and open a PR even
    # though other features are still inconsistent.
    if diff_meta_parquet != {"language_instruction"}:
        raise ValueError(
            f"{repo_id}: unexpected features in info.json not in parquet: {diff_meta_parquet}"
        )

    # NOTE(review): assumes `features` is the same dict object as `info["features"]`.
    # Otherwise this pop would not be reflected in the `info` written below; TODO confirm.
    lerobot_metadata.features.pop("language_instruction")
    write_info(lerobot_metadata.info, lerobot_metadata.root)

    commit_info = hub_api.upload_file(
        path_or_fileobj=lerobot_metadata.root / INFO_PATH,
        path_in_repo=INFO_PATH,
        repo_id=repo_id,
        repo_type="dataset",
        revision=V20,
        commit_message="Remove 'language_instruction'",
        create_pr=True,
    )
    return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix() -> None:
    """Run ``fix_dataset`` on every repo in ``available_datasets`` and record each outcome.

    A status line is produced for each dataset: the value returned by
    ``fix_dataset``, or "failed" plus the traceback if it raised. The line is
    logged and appended to ``LOCAL_DIR / "fix_features_v20.txt"``. The file is
    reopened and closed for every dataset, so results already recorded are kept
    if the run is interrupted.
    """
    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
    logfile = LOCAL_DIR / "fix_features_v20.txt"
    total = len(available_datasets)

    # Count from 1 so the progress line reads 1/N ... N/N instead of 0/N ... N-1/N.
    for num, repo_id in enumerate(available_datasets, start=1):
        print(f"\nConverting {repo_id} ({num}/{total})")
        print("---------------------------------------------------------")
        try:
            status = fix_dataset(repo_id)
        except Exception:
            # Top-level boundary: record the failure and continue with the next dataset.
            status = f"{repo_id}: failed\n {traceback.format_exc()}"

        logging.info(status)
        # Explicit encoding: tracebacks and repo ids may contain non-ASCII text.
        with open(logfile, "a", encoding="utf-8") as file:
            file.write(status + "\n")
if __name__ == "__main__":
    # Script entry point: only runs when the file is executed directly, not on import.
    batch_fix()