Commit edc7860
Parent(s): e5605f2
add streaming support
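This commit adds a `stream` flag to get_dataset(): when no local *.tar.gz shards are found, the loader falls back to streaming the shards from the usm3d/hoho-train-set dataset on the Hugging Face Hub through pipe:curl URLs authorized with the caller's token, and setup() now creates the data directory if it is missing. A minimal sketch of the intended call path (hypothetical session; the top-level import path and an accessible HF token are assumptions, not part of the diff):

    import hoho  # assumed import path for hoho/hoho.py

    hoho.setup()                                       # now also creates LOCAL_DATADIR if missing
    ds = hoho.get_dataset(split='train', stream=True)  # streams from the Hub if no local shards
    sample = next(iter(ds))                            # first decoded record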
hoho/hoho.py CHANGED (+30 -5)
@@ -3,8 +3,13 @@ import json
 import shutil
 from pathlib import Path
 from typing import Dict
+import warnings
 
 from PIL import ImageFile
+
+from huggingface_hub.utils._headers import build_hf_headers  # note: using _headers
+
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 LOCAL_DATADIR = None
@@ -29,11 +34,11 @@ def setup(local_dir='./data/usm-training-data/data'):
     else:
         LOCAL_DATADIR = local_val_datadir
     print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
-
-    # os.system("ls -lahtr")
-    # os.system(f"ls -lahtr {LOCAL_DATADIR}")
 
-
+    if not LOCAL_DATADIR.exists():
+        warnings.warn(f"Data directory {LOCAL_DATADIR} does not exist: creating it...")
+        LOCAL_DATADIR.mkdir(parents=True)
+
     return LOCAL_DATADIR
 
 
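Aside on the directory creation added above: Path.mkdir(parents=True) raises FileExistsError if the directory appears between the exists() check and the call; mkdir(parents=True, exist_ok=True) would make the fallback idempotent and race-free. A standard-library-only sketch of that variant (path taken from setup()'s default argument):

    from pathlib import Path

    data_dir = Path('./data/usm-training-data/data')  # setup()'s default local_dir
    data_dir.mkdir(parents=True, exist_ok=True)       # no error if it already exists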
@@ -286,7 +291,9 @@ def get_params():
 import webdataset as wds
 import numpy as np
 
-def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
+
+SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
+def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset', stream=True):
     if LOCAL_DATADIR is None:
         raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
 
@@ -295,8 +302,24 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
     local_dir = local_dir / split
 
     paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
+    msg = f'no tarfiles found in {local_dir}.'
+    if len(paths) == 0:
+        if stream:
+            if split=='all': split = 'train'
+            warnings.warn('streaming isn\'t supported with \'all\': changing `split` to \'train\'')
+            warnings.warn(msg)
+            if split == 'val':
+                names = [f'data/val/inputs/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
+            elif split == 'train':
+                names = [f'data/train/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
+
+            auth = build_hf_headers()['authorization']
+            paths = [f"pipe:curl -L -s https://huggingface.co/datasets/usm3d/hoho-train-set/resolve/main/{name} -H 'Authorization: {auth}'" for name in names]
+        else:
+            raise FileNotFoundError(msg)
 
     dataset = wds.WebDataset(paths)
+
     if decode is not None:
         dataset = dataset.decode(decode)
     else:
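Note on the fallback introduced above: webdataset treats any path beginning with pipe: as a shell command whose stdout is read as the tar stream, so each shard is downloaded lazily by curl with the caller's Hub token. A standalone sketch of the same pattern (single hypothetical shard built from the train template above; assumes webdataset and huggingface_hub are installed and a token is cached):

    import webdataset as wds
    from huggingface_hub.utils._headers import build_hf_headers

    auth = build_hf_headers()['authorization']   # e.g. 'Bearer hf_...'
    url = ('pipe:curl -L -s '
           'https://huggingface.co/datasets/usm3d/hoho-train-set/resolve/main/'
           'data/train/hoho_v3_000-of-032.tar.gz '
           f"-H 'Authorization: {auth}'")
    ds = wds.WebDataset([url]).decode('pil')     # decode images to PIL lazily
    sample = next(iter(ds))                      # first record from the streamed shard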
@@ -315,6 +338,8 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
         return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
     elif split == 'val':
         return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
+    else:
+        raise NotImplementedError('only train and val are implemented as hf datasets')
 
 
 
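For context on the new else branch: the train and val return paths wrap the webdataset iterator in datasets.IterableDataset.from_generator, which keeps the pipeline lazy instead of materializing it. A minimal sketch of that wrapping, with a toy generator standing in for the webdataset pipeline (assumes a recent datasets release that provides IterableDataset.from_generator):

    import datasets

    def toy_stream():
        # stand-in for `lambda: dataset.iterator()` in the diff
        for i in range(3):
            yield {'uid': i}

    ds = datasets.IterableDataset.from_generator(toy_stream)
    print(list(ds))  # [{'uid': 0}, {'uid': 1}, {'uid': 2}]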