not-lain committed on
Commit c7892ad · verified · 1 Parent(s): c1b939f

Create dataset_uploader.py

Files changed (1)
  1. dataset_uploader.py +174 -0
dataset_uploader.py ADDED
@@ -0,0 +1,174 @@
+ import os
+ import json
+ import tempfile
+ import uuid
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ from huggingface_hub import CommitScheduler
+ from huggingface_hub.hf_api import HfApi
+
+ ###################################
+ # Parquet scheduler               #
+ # Uploads data in parquet format  #
+ ###################################
+
+
+ class ParquetScheduler(CommitScheduler):
+     """
+     Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. Each
+     `.append` call results in one row in your final dataset.
+
+     ```py
+     # Start scheduler
+     >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")
+
+     # Append some data to be uploaded
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     ```
+
+     The scheduler will automatically infer the schema from the data it pushes.
+     Optionally, you can manually set the schema yourself:
+
+     ```py
+     >>> scheduler = ParquetScheduler(
+     ...     repo_id="my-parquet-dataset",
+     ...     schema={
+     ...         "prompt": {"_type": "Value", "dtype": "string"},
+     ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
+     ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
+     ...         "image": {"_type": "Image"},
+     ...     },
+     ... )
+     ```
+
+     See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
+     possible values.
+     """
+
+     def __init__(
+         self,
+         *,
+         repo_id: str,
+         schema: Optional[Dict[str, Dict[str, str]]] = None,
+         every: Union[int, float] = 5,
+         path_in_repo: Optional[str] = "data",
+         repo_type: Optional[str] = "dataset",
+         revision: Optional[str] = None,
+         private: bool = False,
+         token: Optional[str] = None,
+         allow_patterns: Union[List[str], str, None] = None,
+         ignore_patterns: Union[List[str], str, None] = None,
+         hf_api: Optional[HfApi] = None,
+     ) -> None:
+         super().__init__(
+             repo_id=repo_id,
+             folder_path="dummy",  # not used by the scheduler
+             every=every,
+             path_in_repo=path_in_repo,
+             repo_type=repo_type,
+             revision=revision,
+             private=private,
+             token=token,
+             allow_patterns=allow_patterns,
+             ignore_patterns=ignore_patterns,
+             hf_api=hf_api,
+         )
+
+         self._rows: List[Dict[str, Any]] = []
+         self._schema = schema
+
+     def append(self, row: Dict[str, Any]) -> None:
+         """Add a new item to be uploaded."""
+         with self.lock:
+             self._rows.append(row)
+
+     def push_to_hub(self):
+         # Check for new rows to push
+         with self.lock:
+             rows = self._rows
+             self._rows = []
+         if not rows:
+             return
+         print(f"Got {len(rows)} item(s) to commit.")
+
+         # Load images + create 'features' config for datasets library
+         schema: Dict[str, Dict] = self._schema or {}
+         path_to_cleanup: List[Path] = []
+         for row in rows:
+             for key, value in row.items():
+                 # Infer schema (for `datasets` library)
+                 if key not in schema:
+                     schema[key] = _infer_schema(key, value)
+
+                 # Load binary files if necessary
+                 if schema[key]["_type"] in ("Image", "Audio"):
+                     # It's an image or audio: we load the bytes and remember to clean up the file
+                     file_path = Path(value)
+                     if file_path.is_file():
+                         row[key] = {
+                             "path": file_path.name,
+                             "bytes": file_path.read_bytes(),
+                         }
+                         path_to_cleanup.append(file_path)
+
+         # Complete rows if needed
+         for row in rows:
+             for feature in schema:
+                 if feature not in row:
+                     row[feature] = None
+
+         # Export items to Arrow format
+         table = pa.Table.from_pylist(rows)
+
+         # Add metadata (used by datasets library)
+         table = table.replace_schema_metadata(
+             {"huggingface": json.dumps({"info": {"features": schema}})}
+         )
+
+         # Write to parquet file
+         archive_file = tempfile.NamedTemporaryFile(delete=False)
+         pq.write_table(table, archive_file.name)
+         archive_file.close()
+
+         # Upload
+         self.api.upload_file(
+             repo_id=self.repo_id,
+             repo_type=self.repo_type,
+             revision=self.revision,
+             path_in_repo=f"{uuid.uuid4()}.parquet",
+             path_or_fileobj=archive_file.name,
+         )
+         print("Commit completed.")
+
+         # Cleanup
+         os.unlink(archive_file.name)
+         for path in path_to_cleanup:
+             path.unlink(missing_ok=True)
+
+
+ def _infer_schema(key: str, value: Any) -> Dict[str, str]:
+     """
+     Infer schema for the `datasets` library.
+
+     See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
+     """
+     # If the column name contains one of these keywords, the column is
+     # inferred to the corresponding feature type.
+     if "image" in key:
+         return {"_type": "Image"}
+     if "audio" in key:
+         return {"_type": "Audio"}
+     # Check bool before int: bool is a subclass of int in Python.
+     if isinstance(value, bool):
+         return {"_type": "Value", "dtype": "bool"}
+     if isinstance(value, int):
+         return {"_type": "Value", "dtype": "int64"}
+     if isinstance(value, float):
+         return {"_type": "Value", "dtype": "float64"}
+     if isinstance(value, bytes):
+         return {"_type": "Value", "dtype": "binary"}
+     # Otherwise, fall back to a string column
+     return {"_type": "Value", "dtype": "string"}
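
For context, a minimal usage sketch of the added module (not part of this commit): the repo id, prompt, and file path below are placeholders, and the periodic `push_to_hub` calls are triggered in the background by the inherited `CommitScheduler`.

```py
from dataset_uploader import ParquetScheduler

# Placeholder repo id: replace with your own namespace/dataset name.
scheduler = ParquetScheduler(repo_id="your-username/my-parquet-dataset", every=5)

# Keys containing "image" are inferred as Image features, so a local file path
# is enough: the scheduler reads the bytes before uploading and deletes the
# file once it has been committed.
scheduler.append(
    {
        "prompt": "a photo of an astronaut riding a horse",  # placeholder data
        "guidance_scale": 7,
        "image": "outputs/astronaut.png",  # placeholder path
    }
)
```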