davidberenstein1957 HF staff committed on
Commit
342e737
·
2 Parent(s): e044b6a d5933e1

Merge branch 'main' of https://github.com/argilla-io/synthetic-data-generator into feat/improve-textcat

Browse files
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -1,19 +1,13 @@
1
  import io
2
  import uuid
3
- from typing import List, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
- import pandas as pd
8
- from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
- from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
12
 
13
- from synthetic_dataset_generator.constants import TEXTCAT_TASK
14
- from synthetic_dataset_generator.utils import (
15
- get_argilla_client,
16
- )
17
 
18
 
19
  def validate_argilla_user_workspace_dataset(
@@ -52,7 +46,7 @@ def push_pipeline_code_to_hub(
52
  oauth_token: Union[OAuthToken, None] = None,
53
  progress=gr.Progress(),
54
  ):
55
- repo_id = validate_push_to_hub(org_name, repo_name)
56
  progress(0.1, desc="Uploading pipeline code")
57
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
58
  upload_file(
 
1
  import io
2
  import uuid
3
+ from typing import Union
4
 
5
  import argilla as rg
6
  import gradio as gr
 
 
 
7
  from gradio import OAuthToken
8
  from huggingface_hub import HfApi, upload_file
9
 
10
+ from synthetic_dataset_generator.utils import get_argilla_client
 
 
 
11
 
12
 
13
  def validate_argilla_user_workspace_dataset(
 
46
  oauth_token: Union[OAuthToken, None] = None,
47
  progress=gr.Progress(),
48
  ):
49
+ repo_id: str | None = validate_push_to_hub(org_name, repo_name)
50
  progress(0.1, desc="Uploading pipeline code")
51
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
52
  upload_file(
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -19,6 +19,7 @@ from huggingface_hub import HfApi, repo_exists
19
 
20
  from synthetic_dataset_generator.apps.base import (
21
  hide_success_message,
 
22
  show_success_message,
23
  validate_argilla_user_workspace_dataset,
24
  validate_push_to_hub,
@@ -346,7 +347,12 @@ def evaluate_sample_dataset(
346
 
347
 
348
  def push_dataset_to_hub(
349
- dataframe: pd.DataFrame, org_name: str, repo_name: str, oauth_token, private
 
 
 
 
 
350
  ):
351
  repo_id = validate_push_to_hub(org_name, repo_name)
352
  distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
@@ -357,6 +363,7 @@ def push_dataset_to_hub(
357
  token=oauth_token.token,
358
  create_pr=False,
359
  )
 
360
 
361
 
362
  def push_dataset(
@@ -371,6 +378,7 @@ def push_dataset(
371
  response_instruction_response: str,
372
  prompt_template: str,
373
  structured_output: dict,
 
374
  oauth_token: Union[gr.OAuthToken, None] = None,
375
  progress=gr.Progress(),
376
  ) -> pd.DataFrame:
@@ -385,7 +393,9 @@ def push_dataset(
385
  structured_output=structured_output,
386
  num_rows=num_rows,
387
  )
388
- push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
 
 
389
  try:
390
  progress(0.1, desc="Setting up user and workspace")
391
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
@@ -854,6 +864,7 @@ with gr.Blocks() as app:
854
  response_instruction_response,
855
  prompt_template,
856
  structured_output,
 
857
  ],
858
  outputs=[success_message],
859
  show_progress=True,
 
19
 
20
  from synthetic_dataset_generator.apps.base import (
21
  hide_success_message,
22
+ push_pipeline_code_to_hub,
23
  show_success_message,
24
  validate_argilla_user_workspace_dataset,
25
  validate_push_to_hub,
 
347
 
348
 
349
  def push_dataset_to_hub(
350
+ dataframe: pd.DataFrame,
351
+ org_name: str,
352
+ repo_name: str,
353
+ oauth_token: Union[gr.OAuthToken, None],
354
+ private: bool,
355
+ pipeline_code: str,
356
  ):
357
  repo_id = validate_push_to_hub(org_name, repo_name)
358
  distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
 
363
  token=oauth_token.token,
364
  create_pr=False,
365
  )
366
+ push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
367
 
368
 
369
  def push_dataset(
 
378
  response_instruction_response: str,
379
  prompt_template: str,
380
  structured_output: dict,
381
+ pipeline_code: str,
382
  oauth_token: Union[gr.OAuthToken, None] = None,
383
  progress=gr.Progress(),
384
  ) -> pd.DataFrame:
 
393
  structured_output=structured_output,
394
  num_rows=num_rows,
395
  )
396
+ push_dataset_to_hub(
397
+ dataframe, org_name, repo_name, oauth_token, private, pipeline_code
398
+ )
399
  try:
400
  progress(0.1, desc="Setting up user and workspace")
401
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
864
  response_instruction_response,
865
  prompt_template,
866
  structured_output,
867
+ pipeline_code,
868
  ],
869
  outputs=[success_message],
870
  show_progress=True,
src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import HfApi
11
 
12
  from synthetic_dataset_generator.apps.base import (
13
  hide_success_message,
 
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
@@ -202,7 +203,14 @@ def generate_dataset(
202
  return dataframe
203
 
204
 
205
- def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
 
 
 
 
 
 
 
206
  repo_id = validate_push_to_hub(org_name, repo_name)
207
  original_dataframe = dataframe.copy(deep=True)
208
  dataframe = convert_dataframe_messages(dataframe)
@@ -214,6 +222,7 @@ def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private):
214
  token=oauth_token.token,
215
  create_pr=False,
216
  )
 
217
  return original_dataframe
218
 
219
 
@@ -225,6 +234,7 @@ def push_dataset(
225
  num_rows: int = 10,
226
  private: bool = False,
227
  temperature: float = 0.9,
 
228
  oauth_token: Union[gr.OAuthToken, None] = None,
229
  progress=gr.Progress(),
230
  ) -> pd.DataFrame:
@@ -234,7 +244,9 @@ def push_dataset(
234
  num_rows=num_rows,
235
  temperature=temperature,
236
  )
237
- push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
 
 
238
  try:
239
  progress(0.1, desc="Setting up user and workspace")
240
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
@@ -528,6 +540,7 @@ with gr.Blocks() as app:
528
  num_rows,
529
  private,
530
  temperature,
 
531
  ],
532
  outputs=[success_message],
533
  show_progress=True,
 
11
 
12
  from synthetic_dataset_generator.apps.base import (
13
  hide_success_message,
14
+ push_pipeline_code_to_hub,
15
  show_success_message,
16
  validate_argilla_user_workspace_dataset,
17
  validate_push_to_hub,
 
203
  return dataframe
204
 
205
 
206
+ def push_dataset_to_hub(
207
+ dataframe: pd.DataFrame,
208
+ org_name: str,
209
+ repo_name: str,
210
+ oauth_token: Union[gr.OAuthToken, None],
211
+ private: bool,
212
+ pipeline_code: str,
213
+ ):
214
  repo_id = validate_push_to_hub(org_name, repo_name)
215
  original_dataframe = dataframe.copy(deep=True)
216
  dataframe = convert_dataframe_messages(dataframe)
 
222
  token=oauth_token.token,
223
  create_pr=False,
224
  )
225
+ push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
226
  return original_dataframe
227
 
228
 
 
234
  num_rows: int = 10,
235
  private: bool = False,
236
  temperature: float = 0.9,
237
+ pipeline_code: str = "",
238
  oauth_token: Union[gr.OAuthToken, None] = None,
239
  progress=gr.Progress(),
240
  ) -> pd.DataFrame:
 
244
  num_rows=num_rows,
245
  temperature=temperature,
246
  )
247
+ push_dataset_to_hub(
248
+ dataframe, org_name, repo_name, oauth_token, private, pipeline_code
249
+ )
250
  try:
251
  progress(0.1, desc="Setting up user and workspace")
252
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
540
  num_rows,
541
  private,
542
  temperature,
543
+ pipeline_code,
544
  ],
545
  outputs=[success_message],
546
  show_progress=True,
src/synthetic_dataset_generator/utils.py CHANGED
@@ -39,14 +39,13 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):
39
  organizations = [org for org in organizations if org != data["name"]]
40
  organizations = [data["name"]] + organizations
41
  except Exception as e:
42
- data = whoami(oauth_token.token)
43
  warnings.warn(str(e))
44
  gr.Info(
45
  "Your user token does not have the necessary permissions to push to organizations."
46
  "Please check your OAuth permissions in https://huggingface.co/settings/connected-applications."
47
  "Update yout token permissions to include repo.write: https://huggingface.co/settings/tokens."
48
  )
49
- return [data["name"]]
50
 
51
  return organizations
52
 
 
39
  organizations = [org for org in organizations if org != data["name"]]
40
  organizations = [data["name"]] + organizations
41
  except Exception as e:
 
42
  warnings.warn(str(e))
43
  gr.Info(
44
  "Your user token does not have the necessary permissions to push to organizations."
45
  "Please check your OAuth permissions in https://huggingface.co/settings/connected-applications."
46
  "Update yout token permissions to include repo.write: https://huggingface.co/settings/tokens."
47
  )
48
+ return []
49
 
50
  return organizations
51