davidberenstein1957 (HF staff) committed
Commit d5fe7d9 · unverified
Parents: 9bc9bf6 57b7e7b

Merge branch 'main' into main
README.md CHANGED
@@ -104,6 +104,10 @@ Optionally, you can also push your datasets to Argilla for further curation by s
 - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
 - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
 
+To save the generated datasets to a local directory instead of pushing them to the Hugging Face Hub, set the following environment variable:
+
+- `SAVE_LOCAL_DIR`: The local directory to save the generated datasets to.
+
 ### Argilla integration
 
 Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
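
As a quick illustration of the new option, here is a minimal sketch of configuring local saving before starting the generator; the directory path is an example, and the `launch` entry point is assumed from the project README:

import os

# Save generated datasets to a local directory instead of pushing them to the Hub.
# "./generated-datasets" is an illustrative path; any writable directory works.
os.environ["SAVE_LOCAL_DIR"] = "./generated-datasets"

from synthetic_dataset_generator import launch  # assumed entry point

launch()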
examples/fine-tune-deepseek-reasoning-sft.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/ollama-different-model-for-completion.py CHANGED
@@ -18,8 +18,8 @@ os.environ["OLLAMA_BASE_URL"] = (
 os.environ["MODEL"] = "llama3.2" # model for instruction generation
 os.environ["MODEL_COMPLETION"] = "llama3.2:1b" # model for completion generation
 
-os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-1B-Instruct" # tokenizer for instruction generation
-os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-3B-Instruct" # tokenizer for completion generation
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-3B-Instruct" # tokenizer for instruction generation
+os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-1B-Instruct" # tokenizer for completion generation
 
 os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template required for instruction generation
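This change pairs each tokenizer with its matching model size: `llama3.2` (which resolves to the 3B variant on Ollama) now uses the `Llama-3.2-3B-Instruct` tokenizer for instruction generation, and `llama3.2:1b` uses the `Llama-3.2-1B-Instruct` tokenizer for completion generation.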
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "synthetic-dataset-generator"
-version = "0.1.7"
+version = "0.1.8"
 description = "Build datasets using natural language"
 authors = [
     {name = "davidberenstein1957", email = "[email protected]"},
@@ -22,6 +22,7 @@ dependencies = [
     "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm,vision]>=1.5.0,<2.00",
     "gradio[oauth]>=5.4.0,<6.0.0",
     "gradio-huggingfacehub-search>=0.0.12,<1.0.0",
+    "huggingface-hub>=0.26.0,<0.28.0",
     "model2vec>=0.2.4,<1.0.0",
     "nltk>=3.9.1,<4.0.0",
     "pydantic>=2.10.5,<3.0.0",
src/synthetic_dataset_generator/app.py CHANGED
@@ -12,15 +12,16 @@ css = """
 .main_ui_logged_out{opacity: 0.3; pointer-events: none}
 button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
 button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
-.tabitem { border: 0; padding-inline: 0}
+.tabitem {border: 0; padding-inline: 0}
 .gallery-item {background: var(--background-fill-secondary); text-align: left}
-.table-wrap .tbody td { vertical-align: top }
-#system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
+.table-wrap .tbody td {vertical-align: top}
+#system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
 .container {padding-inline: 0 !important}
-#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
 .gradio-container { width: 100% !important; }
 .gradio-row { display: flex !important; flex-direction: row !important; }
 .gradio-column { flex: 1 !important; min-width: 0 !important; }
+#sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
+.datasets {height: 70px;}
 """
 
 image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -12,9 +12,13 @@ from huggingface_hub import HfApi, upload_file, repo_exists
 from unstructured.chunking.title import chunk_by_title
 from unstructured.partition.auto import partition
 
-from synthetic_dataset_generator.constants import MAX_NUM_ROWS
+from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
 from synthetic_dataset_generator.utils import get_argilla_client
 
+if SAVE_LOCAL_DIR is not None:
+    import os
+    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
+
 
 def validate_argilla_user_workspace_dataset(
     dataset_name: str,
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -1,5 +1,6 @@
 import ast
 import json
+import os
 import random
 import uuid
 from typing import Dict, List, Union
@@ -29,6 +30,7 @@ from synthetic_dataset_generator.constants import (
     DEFAULT_BATCH_SIZE,
     MODEL,
     MODEL_COMPLETION,
+    SAVE_LOCAL_DIR,
     SFT_AVAILABLE,
 )
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
@@ -309,7 +311,7 @@ def generate_dataset_from_seed(
         progress(
             step_progress * n_processed / num_rows,
             total=total_steps,
-            desc="Generating questions",
+            desc="Generating instructions",
         )
         remaining_rows = num_rows - n_processed
         batch_size = min(batch_size, remaining_rows)
@@ -368,7 +370,9 @@ def generate_dataset_from_seed(
         follow_up_instructions = list(
             follow_up_generator_instruction.process(inputs=conversations_batch)
         )
-        for conv, follow_up in zip(conversations_batch, follow_up_instructions[0]):
+        for conv, follow_up in zip(
+            conversations_batch, follow_up_instructions[0]
+        ):
             conv["messages"].append(
                 {"role": "user", "content": follow_up["generation"]}
             )
@@ -506,7 +510,7 @@ def push_dataset(
         num_turns=num_turns,
         num_rows=num_rows,
         temperature=temperature,
-        temperature_completion=temperature_completion
+        temperature_completion=temperature_completion,
     )
     push_dataset_to_hub(
         dataframe=dataframe,
@@ -637,6 +641,45 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    repo_id: str,
+    file_paths: list[str],
+    input_type: str,
+    system_prompt: str,
+    document_column: str,
+    num_turns: int,
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+    temperature_completion: Union[float, None] = None,
+) -> pd.DataFrame:
+    if input_type == "prompt-input":
+        dataframe = _get_dataframe()
+    else:
+        dataframe, _ = load_dataset_file(
+            repo_id=repo_id,
+            file_paths=file_paths,
+            input_type=input_type,
+            num_rows=num_rows,
+        )
+    dataframe = generate_dataset(
+        input_type=input_type,
+        dataframe=dataframe,
+        system_prompt=system_prompt,
+        document_column=document_column,
+        num_turns=num_turns,
+        num_rows=num_rows,
+        temperature=temperature,
+        temperature_completion=temperature_completion,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def show_system_prompt_visibility():
     return {system_prompt: gr.Textbox(visible=True)}
 
@@ -672,6 +715,31 @@ def show_temperature_completion():
     return {temperature_completion: gr.Slider(value=0.9, visible=True)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message,
+    }
+
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -781,7 +849,7 @@ with gr.Blocks() as app:
                 )
                 document_column = gr.Dropdown(
                     label="Document Column",
-                    info="Select the document column to generate the RAG dataset",
+                    info="Select the document column to generate the chat data",
                     choices=["Load your data first in step 1."],
                     value="Load your data first in step 1.",
                     interactive=False,
@@ -852,10 +920,23 @@ with gr.Blocks() as app:
                 btn_push_to_hub = gr.Button(
                     "Push to Hub", variant="primary", scale=2
                 )
+                btn_save_local = gr.Button(
+                    "Save locally", variant="primary", scale=2, visible=False
+                )
             with gr.Column(scale=3):
+                csv_file = gr.File(
+                    label="CSV",
+                    elem_classes="datasets",
+                    visible=False,
+                )
+                json_file = gr.File(
+                    label="JSON",
+                    elem_classes="datasets",
+                    visible=False,
+                )
                 success_message = gr.Markdown(
-                    visible=True,
-                    min_height=100,  # don't remove this otherwise progress is not visible
+                    visible=False,
+                    min_height=0,  # don't remove this otherwise progress is not visible
                 )
         with gr.Accordion(
             "Customize your pipeline with distilabel",
@@ -953,6 +1034,9 @@ with gr.Blocks() as app:
         fn=validate_push_to_hub,
         inputs=[org_name, repo_name],
         outputs=[success_message],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -999,6 +1083,49 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            search_in,
+            file_in,
+            input_type,
+            system_prompt,
+            document_column,
+            num_turns,
+            num_rows,
+            temperature,
+            repo_name,
+            temperature_completion,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            search_in,
+            input_type,
+            system_prompt,
+            document_column,
+            num_turns,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
     clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
     clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
@@ -1011,3 +1138,5 @@ with gr.Blocks() as app:
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=show_temperature_completion, outputs=[temperature_completion])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
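
The save_local helpers added in this commit share one pattern: build a pandas DataFrame, wrap it in a datasets.Dataset, and write CSV and JSON files into SAVE_LOCAL_DIR. A standalone sketch of that pattern, with an illustrative path and data:

import os

import pandas as pd
from datasets import Dataset

save_dir = "./generated-datasets"  # stand-in for SAVE_LOCAL_DIR
os.makedirs(save_dir, exist_ok=True)

# Illustrative frame; the apps build this with generate_dataset() instead.
dataframe = pd.DataFrame({"prompt": ["hi"], "completion": ["hello"]})

local_dataset = Dataset.from_pandas(dataframe)
local_dataset.to_csv(os.path.join(save_dir, "demo.csv"), index=False)
local_dataset.to_json(os.path.join(save_dir, "demo.json"))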
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -24,7 +24,12 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
+from synthetic_dataset_generator.constants import (
+    DEFAULT_BATCH_SIZE,
+    MODEL,
+    MODEL_COMPLETION,
+    SAVE_LOCAL_DIR,
+)
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -156,7 +161,7 @@ def generate_dataset(
         is_sample=is_sample,
     )
     response_generator = get_response_generator(
-        temperature = temperature_completion or temperature , is_sample=is_sample
+        temperature=temperature_completion or temperature, is_sample=is_sample
     )
     if reranking:
         reranking_generator = get_sentence_pair_generator(
@@ -486,6 +491,49 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    repo_id: str,
+    file_paths: list[str],
+    input_type: str,
+    system_prompt: str,
+    document_column: str,
+    retrieval_reranking: list[str],
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+    temperature_completion: float,
+) -> pd.DataFrame:
+    retrieval = "Retrieval" in retrieval_reranking
+    reranking = "Reranking" in retrieval_reranking
+
+    if input_type == "prompt-input":
+        dataframe = pd.DataFrame(columns=["context", "question", "response"])
+    else:
+        dataframe, _ = load_dataset_file(
+            repo_id=repo_id,
+            file_paths=file_paths,
+            input_type=input_type,
+            num_rows=num_rows,
+        )
+    dataframe = generate_dataset(
+        input_type=input_type,
+        dataframe=dataframe,
+        system_prompt=system_prompt,
+        document_column=document_column,
+        retrieval=retrieval,
+        reranking=reranking,
+        num_rows=num_rows,
+        temperature=temperature,
+        temperature_completion=temperature_completion,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def show_system_prompt_visibility():
     return {system_prompt: gr.Textbox(visible=True)}
 
@@ -521,6 +569,32 @@ def show_temperature_completion():
     return {temperature_completion: gr.Slider(value=0.9, visible=True)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message,
+    }
+
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -675,10 +749,23 @@ with gr.Blocks() as app:
                     scale=1,
                 )
                 btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
+                btn_save_local = gr.Button(
+                    "Save locally", variant="primary", scale=2, visible=False
+                )
             with gr.Column(scale=3):
+                csv_file = gr.File(
+                    label="CSV",
+                    elem_classes="datasets",
+                    visible=False,
+                )
+                json_file = gr.File(
+                    label="JSON",
+                    elem_classes="datasets",
+                    visible=False,
+                )
                 success_message = gr.Markdown(
-                    visible=True,
-                    min_height=100,  # don't remove this otherwise progress is not visible
+                    visible=False,
+                    min_height=0,  # don't remove this otherwise progress is not visible
                 )
         with gr.Accordion(
             "Customize your pipeline with distilabel",
@@ -776,6 +863,9 @@ with gr.Blocks() as app:
         fn=validate_push_to_hub,
         inputs=[org_name, repo_name],
         outputs=[success_message],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -822,6 +912,49 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            search_in,
+            file_in,
+            input_type,
+            system_prompt,
+            document_column,
+            retrieval_reranking,
+            num_rows,
+            temperature,
+            repo_name,
+            temperature_completion,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            search_in,
+            input_type,
+            system_prompt,
+            document_column,
+            retrieval_reranking,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
     clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
     clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
@@ -835,3 +968,5 @@ with gr.Blocks() as app:
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=show_temperature_completion, outputs=[temperature_completion])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -1,4 +1,5 @@
 import json
+import os
 import random
 import uuid
 from typing import List, Union
@@ -19,7 +20,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SAVE_LOCAL_DIR
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -195,7 +196,7 @@ def generate_dataset(
             set(
                 label.lower().strip()
                 for label in x
-                if label.lower().strip() in labels
+                if isinstance(label, str) and label.lower().strip() in labels
             )
         )
     else:
@@ -406,6 +407,33 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    system_prompt: str,
+    difficulty: str,
+    clarity: str,
+    labels: List[str],
+    multi_label: bool,
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+) -> pd.DataFrame:
+    dataframe = generate_dataset(
+        system_prompt=system_prompt,
+        difficulty=difficulty,
+        clarity=clarity,
+        multi_label=multi_label,
+        labels=labels,
+        num_rows=num_rows,
+        temperature=temperature,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def validate_input_labels(labels: List[str]) -> List[str]:
     if (
         not labels
@@ -425,6 +453,32 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message,
+    }
+
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -544,10 +598,23 @@ with gr.Blocks() as app:
                     scale=1,
                 )
                 btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
+                btn_save_local = gr.Button(
+                    "Save locally", variant="primary", scale=2, visible=False
+                )
             with gr.Column(scale=3):
+                csv_file = gr.File(
+                    label="CSV",
+                    elem_classes="datasets",
+                    visible=False,
+                )
+                json_file = gr.File(
+                    label="JSON",
+                    elem_classes="datasets",
+                    visible=False,
+                )
                 success_message = gr.Markdown(
-                    visible=True,
-                    min_height=100,  # don't remove this otherwise progress is not visible
+                    visible=False,
+                    min_height=0,  # don't remove this otherwise progress is not visible
                 )
         with gr.Accordion(
             "Customize your pipeline with distilabel",
@@ -600,6 +667,9 @@ with gr.Blocks() as app:
         fn=validate_input_labels,
         inputs=[labels],
         outputs=[labels],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -644,6 +714,47 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            system_prompt,
+            difficulty,
+            clarity,
+            labels,
+            multi_label,
+            num_rows,
+            temperature,
+            repo_name,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            system_prompt,
+            difficulty,
+            clarity,
+            labels,
+            multi_label,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     gr.on(
         triggers=[clear_btn_part.click, clear_btn_full.click],
         fn=lambda _: (
@@ -660,3 +771,5 @@ with gr.Blocks() as app:
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
src/synthetic_dataset_generator/constants.py CHANGED
@@ -8,6 +8,9 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
 MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
 DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
 
+# Directory to locally save the generated data
+SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
+
 # Models
 MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
 TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -87,10 +87,17 @@ def _get_llm(
 ):
     model = MODEL_COMPLETION if is_completion else MODEL
     tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
-    if OPENAI_BASE_URL:
+    base_urls = {
+        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+        "huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
+        "vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+    }
+
+    if base_urls["openai"]:
         llm = OpenAILLM(
             model=model,
-            base_url=OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+            base_url=base_urls["openai"],
             api_key=_get_next_api_key(),
             structured_output=structured_output,
             **kwargs,
@@ -103,7 +110,7 @@ def _get_llm(
             del kwargs["generation_kwargs"]["stop_sequences"]
         if "do_sample" in kwargs["generation_kwargs"]:
             del kwargs["generation_kwargs"]["do_sample"]
-    elif OLLAMA_BASE_URL:
+    elif base_urls["ollama"]:
         if "generation_kwargs" in kwargs:
             if "max_new_tokens" in kwargs["generation_kwargs"]:
                 kwargs["generation_kwargs"]["num_predict"] = kwargs[
@@ -123,32 +130,28 @@ def _get_llm(
             kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
             model=model,
-            host=OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+            host=base_urls["ollama"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif HUGGINGFACE_BASE_URL:
+    elif base_urls["huggingface"]:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=(
-                HUGGINGFACE_BASE_URL_COMPLETION
-                if is_completion
-                else HUGGINGFACE_BASE_URL
-            ),
+            base_url=base_urls["huggingface"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif VLLM_BASE_URL:
+    elif base_urls["vllm"]:
         if "generation_kwargs" in kwargs:
             if "do_sample" in kwargs["generation_kwargs"]:
                 del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+            base_url=base_urls["vllm"],
             model=model,
             tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
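
The refactor above replaces repeated `X_COMPLETION if is_completion else X` expressions with a single base_urls dict that every branch consults. A toy sketch of the same selection logic, reading environment variables directly instead of the module's constants (names illustrative):

import os

def pick_base_url(provider: str, is_completion: bool = False):
    # Prefer the completion-specific URL when is_completion is set.
    suffix = "_BASE_URL_COMPLETION" if is_completion else "_BASE_URL"
    return os.getenv(provider.upper() + suffix)

def pick_provider(is_completion: bool = False):
    # Same branch order as _get_llm: openai, then ollama, huggingface, vllm.
    for provider in ("openai", "ollama", "huggingface", "vllm"):
        url = pick_base_url(provider, is_completion)
        if url:
            return provider, url
    return None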
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -109,7 +109,7 @@ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: b
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    llm = _get_llm(is_completion=True, generation_kwargs=generation_kwargs)
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,
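
With `is_completion=True` removed, and assuming `_get_llm` defaults `is_completion` to False, the labeller now uses the instruction-generation model (`MODEL`) and its base URL rather than the completion model, per the `model = MODEL_COMPLETION if is_completion else MODEL` selection shown in pipelines/base.py above.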