davanstrien (HF staff) committed
Commit 059d73d · verified · 1 Parent(s): 06cab33

Upload 2 files

Files changed (2)
  1. app.py +247 -0
  2. requirements.txt +277 -0
app.py ADDED
@@ -0,0 +1,247 @@
+ import os
+ import random
+ from statistics import mean
+ from typing import Iterator, Union
+
+ import fasttext
+ import gradio as gr
+ from dotenv import load_dotenv
+ from httpx import AsyncClient, Client, Timeout
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import logging
+ from litestar import Litestar, get
+ from toolz import concat, groupby, valmap
+
+ logger = logging.get_logger(__name__)
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+
+ BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
+ DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
+ headers = {
+     "authorization": f"Bearer {HF_TOKEN}",
+ }
+ timeout = Timeout(60, read=120)
+ client = Client(headers=headers, timeout=timeout)
+ async_client = AsyncClient(headers=headers, timeout=timeout)
+ # non-exhaustive list of columns that might contain text usable for language detection;
+ # ordered by preference, i.e. if a dataset has a "text" column we use it first
+ TARGET_COLUMN_NAMES = (
+     "text",
+     "input",
+     "tokens",
+     "prompt",
+     "instruction",
+     "sentence_1",
+     "question",
+     "sentence2",
+     "answer",
+     "sentence",
+     "response",
+     "context",
+     "query",
+     "chosen",
+     "rejected",
+ )
+
+
+ def datasets_server_valid_rows(hub_id: str):
+     resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
+     resp.raise_for_status()
+     return resp.json()["viewer"]
+
+
+ def get_first_config_and_split_name(hub_id: str):
+     resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
+     resp.raise_for_status()
+     data = resp.json()
+     return data["splits"][0]["config"], data["splits"][0]["split"]
+
+
+ def get_dataset_info(hub_id: str, config: str | None = None):
+     if config is None:
+         config = get_first_config_and_split_name(hub_id)
+         if config is None:
+             return None
+         else:
+             config = config[0]
+     resp = client.get(
+         f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
+     )
+     resp.raise_for_status()
+     return resp.json()
+
+
+ async def get_random_rows(
+     hub_id: str,
+     total_length: int,
+     number_of_rows: int,
+     max_request_calls: int,
+     config="default",
+     split="train",
+ ):
+     rows = []
+     rows_per_call = min(
+         number_of_rows // max_request_calls, total_length // max_request_calls
+     )
+     rows_per_call = min(rows_per_call, 100)  # ensure rows_per_call is not more than 100
+     rows_per_call = max(rows_per_call, 1)  # guard against a zero division for very small datasets
+     for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
+         offset = random.randint(0, total_length - rows_per_call)
+         url = f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
+         response = await async_client.get(url)
+         if response.status_code == 200:
+             data = response.json()
+             batch_rows = data.get("rows")
+             rows.extend(batch_rows)
+         else:
+             print(f"Failed to fetch data: {response.status_code}")
+             print(url)
+         if len(rows) >= number_of_rows:
+             break
+     return [row.get("row") for row in rows]
+
+
+ def load_model(repo_id: str) -> fasttext.FastText._FastText:
+     model_path = hf_hub_download(repo_id, filename="model.bin")
+     return fasttext.load_model(model_path)
+
+
+ def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
+     for row in rows:
+         if isinstance(row, str):
+             # split on newlines and drop empty lines
+             lines = row.split("\n")
+             for line in lines:
+                 if line:
+                     yield line
+         elif isinstance(row, list):
+             try:
+                 line = " ".join(row)
+                 if len(line) < min_length:
+                     continue
+                 else:
+                     yield line
+             except TypeError:
+                 continue
+
+
+ FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"
+
+ # model = load_model(DEFAULT_FAST_TEXT_MODEL)
+
+ model = fasttext.load_model(
+     hf_hub_download("facebook/fasttext-language-identification", "model.bin")
+ )
+
+
+ def model_predict(inputs: str, k=1) -> list[dict[str, str | float]]:
+     predictions = model.predict(inputs, k=k)
+     return [
+         {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
+         for label, prob in zip(predictions[0], predictions[1])
+     ]
+
+
+ def get_label(x):
+     return x.get("label")
+
+
+ def get_mean_score(preds):
+     return mean([pred.get("score") for pred in preds])
+
+
+ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
+     """Return the keys whose count is at least `threshold_percent` of the total."""
+     total = sum(counts_dict.values())
+     threshold = total * threshold_percent
+     return {k for k, v in counts_dict.items() if v >= threshold}
+
+
+ def predict_rows(rows, target_column, language_threshold_percent=0.2):
+     rows = (row.get(target_column) for row in rows)
+     rows = (row for row in rows if row is not None)
+     rows = list(yield_clean_rows(rows))
+     predictions = [model_predict(row) for row in rows]
+     predictions = [pred for pred in predictions if pred is not None]
+     predictions = list(concat(predictions))
+     predictions_by_lang = groupby(get_label, predictions)
+     language_counts = valmap(len, predictions_by_lang)
+     keys_to_keep = filter_by_frequency(
+         language_counts, threshold_percent=language_threshold_percent
+     )
+     filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
+     return {
+         "predictions": dict(valmap(get_mean_score, filtered_dict)),
+         "pred": predictions,
+     }
+
+
+ @get("/predict_language/")
+ async def predict_language(
+     hub_id: str,
+     config: str | None = None,
+     split: str | None = None,
+     max_request_calls: int = 10,
+     number_of_rows: int = 1000,
+ ) -> dict[str, float | str]:
+     is_valid = datasets_server_valid_rows(hub_id)
+     if not is_valid:
+         raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
+     if not config:
+         config, split = get_first_config_and_split_name(hub_id)
+     info = get_dataset_info(hub_id, config)
+     if info is None:
+         raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
+     if dataset_info := info.get("dataset_info"):
+         total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
+         logger.info(f"Total rows for split {split}: {total_rows_for_split}")
+         features = dataset_info.get("features")
+         column_names = set(features.keys())
+         logger.info(f"Column names: {column_names}")
+         if not set(column_names).intersection(TARGET_COLUMN_NAMES):
+             raise gr.Error(
+                 f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
+             )
+         for column in TARGET_COLUMN_NAMES:
+             if column in column_names:
+                 target_column = column
+                 logger.info(f"Using column {target_column} for language detection")
+                 break
+         random_rows = await get_random_rows(
+             hub_id,
+             total_rows_for_split,
+             number_of_rows,
+             max_request_calls,
+             config,
+             split,
+         )
+         logger.info(f"Predicting language for {len(random_rows)} rows")
+         predictions = predict_rows(random_rows, target_column)
+         predictions["hub_id"] = hub_id
+         predictions["config"] = config
+         predictions["split"] = split
+         return predictions
+
+
+ app = Litestar([predict_language])
+ # inputs = [
+ #     gr.Text(label="dataset id"),
+ #     gr.Textbox(
+ #         None,
+ #         label="config",
+ #     ),
+ #     gr.Textbox(None, label="split"),
+ # ]
+ # interface = gr.Interface(predict_language, inputs=inputs, outputs="json")
+ # interface.queue()
+ # interface.launch()
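
Once the Litestar app is running, the /predict_language/ endpoint can be exercised with a plain HTTP request. A minimal sketch (the base URL and the example dataset id below are placeholders, not part of the committed code):

    import httpx

    # hypothetical local URL; point this at the deployed Space instead
    resp = httpx.get(
        "http://localhost:8000/predict_language/",
        params={"hub_id": "imdb", "number_of_rows": 200},
        timeout=120,
    )
    # response contains "predictions" (mean score per detected language),
    # the raw per-row "pred" list, plus hub_id/config/split echoes
    print(resp.json())
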
requirements.txt ADDED
@@ -0,0 +1,277 @@
+ #
+ # This file is autogenerated by pip-compile with Python 3.11
+ # by the following command:
+ #
+ #    pip-compile
+ #
+ aiofiles==23.2.1
+     # via gradio
+ aiohttp==3.9.1
+     # via
+     #   datasets
+     #   fsspec
+ aiosignal==1.3.1
+     # via aiohttp
+ altair==5.2.0
+     # via gradio
+ annotated-types==0.6.0
+     # via pydantic
+ anyio==4.2.0
+     # via
+     #   httpx
+     #   litestar
+     #   starlette
+ attrs==23.2.0
+     # via
+     #   aiohttp
+     #   jsonschema
+     #   referencing
+ certifi==2023.11.17
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ charset-normalizer==3.3.2
+     # via requests
+ click==8.1.7
+     # via
+     #   litestar
+     #   rich-click
+     #   typer
+     #   uvicorn
+ colorama==0.4.6
+     # via typer
+ contourpy==1.2.0
+     # via matplotlib
+ cycler==0.12.1
+     # via matplotlib
+ datasets==2.14.4
+     # via -r requirements.in
+ dill==0.3.7
+     # via
+     #   datasets
+     #   multiprocess
+ faker==22.5.0
+     # via polyfactory
+ fastapi==0.109.0
+     # via gradio
+ fasttext==0.9.2
+     # via -r requirements.in
+ ffmpy==0.3.1
+     # via gradio
+ filelock==3.13.1
+     # via huggingface-hub
+ fonttools==4.47.2
+     # via matplotlib
+ frozenlist==1.4.1
+     # via
+     #   aiohttp
+     #   aiosignal
+ fsspec[http]==2023.12.2
+     # via
+     #   datasets
+     #   gradio-client
+     #   huggingface-hub
+ gradio==4.15.0
+     # via -r requirements.in
+ gradio-client==0.8.1
+     # via gradio
+ h11==0.14.0
+     # via
+     #   httpcore
+     #   uvicorn
+ httpcore==1.0.2
+     # via httpx
+ httpx==0.26.0
+     # via
+     #   -r requirements.in
+     #   gradio
+     #   gradio-client
+     #   litestar
+ huggingface-hub==0.20.3
+     # via
+     #   -r requirements.in
+     #   datasets
+     #   gradio
+     #   gradio-client
+ idna==3.6
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+     #   yarl
+ importlib-resources==6.1.1
+     # via gradio
+ iso639-lang==2.2.2
+     # via -r requirements.in
+ jinja2==3.1.3
+     # via
+     #   altair
+     #   gradio
+ jsonschema==4.21.1
+     # via altair
+ jsonschema-specifications==2023.12.1
+     # via jsonschema
+ kiwisolver==1.4.5
+     # via matplotlib
+ litestar==2.5.1
+     # via -r requirements.in
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.4
+     # via
+     #   gradio
+     #   jinja2
+ matplotlib==3.8.2
+     # via gradio
+ mdurl==0.1.2
+     # via markdown-it-py
+ msgspec==0.18.6
+     # via litestar
+ multidict==6.0.4
+     # via
+     #   aiohttp
+     #   litestar
+     #   yarl
+ multiprocess==0.70.15
+     # via datasets
+ numpy==1.26.3
+     # via
+     #   altair
+     #   contourpy
+     #   datasets
+     #   fasttext
+     #   gradio
+     #   matplotlib
+     #   pandas
+     #   pyarrow
+ orjson==3.9.12
+     # via gradio
+ packaging==23.2
+     # via
+     #   altair
+     #   datasets
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   matplotlib
+ pandas==2.2.0
+     # via
+     #   altair
+     #   datasets
+     #   gradio
+ pillow==10.2.0
+     # via
+     #   gradio
+     #   matplotlib
+ polyfactory==2.14.1
+     # via litestar
+ pyarrow==15.0.0
+     # via datasets
+ pybind11==2.11.1
+     # via fasttext
+ pydantic==2.5.3
+     # via
+     #   fastapi
+     #   gradio
+ pydantic-core==2.14.6
+     # via pydantic
+ pydub==0.25.1
+     # via gradio
+ pygments==2.17.2
+     # via rich
+ pyparsing==3.1.1
+     # via matplotlib
+ python-dateutil==2.8.2
+     # via
+     #   faker
+     #   matplotlib
+     #   pandas
+ python-dotenv==1.0.1
+     # via -r requirements.in
+ python-multipart==0.0.6
+     # via gradio
+ pytz==2023.3.post1
+     # via pandas
+ pyyaml==6.0.1
+     # via
+     #   datasets
+     #   gradio
+     #   huggingface-hub
+     #   litestar
+ referencing==0.32.1
+     # via
+     #   jsonschema
+     #   jsonschema-specifications
+ requests==2.31.0
+     # via
+     #   datasets
+     #   fsspec
+     #   huggingface-hub
+ rich==13.7.0
+     # via
+     #   -r requirements.in
+     #   litestar
+     #   rich-click
+     #   typer
+ rich-click==1.7.3
+     # via litestar
+ rpds-py==0.17.1
+     # via
+     #   jsonschema
+     #   referencing
+ ruff==0.1.14
+     # via gradio
+ semantic-version==2.10.0
+     # via gradio
+ shellingham==1.5.4
+     # via typer
+ six==1.16.0
+     # via python-dateutil
+ sniffio==1.3.0
+     # via
+     #   anyio
+     #   httpx
+ starlette==0.35.1
+     # via fastapi
+ tomlkit==0.12.0
+     # via gradio
+ toolz==0.12.0
+     # via
+     #   -r requirements.in
+     #   altair
+ tqdm==4.66.1
+     # via
+     #   datasets
+     #   huggingface-hub
+ typer[all]==0.9.0
+     # via
+     #   gradio
+     #   typer
+ typing-extensions==4.9.0
+     # via
+     #   fastapi
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   litestar
+     #   polyfactory
+     #   pydantic
+     #   pydantic-core
+     #   rich-click
+     #   typer
+ tzdata==2023.4
+     # via pandas
+ urllib3==2.1.0
+     # via requests
+ uvicorn==0.27.0
+     # via gradio
+ websockets==11.0.3
+     # via gradio-client
+ xxhash==3.4.1
+     # via datasets
+ yarl==1.9.4
+     # via aiohttp
+
+ # The following packages are considered to be unsafe in a requirements file:
+ # setuptools