Merge branch 'main' into main
- README.md +4 -0
- examples/fine-tune-deepseek-reasoning-sft.ipynb +0 -0
- examples/ollama-different-model-for-completion.py +2 -2
- pyproject.toml +2 -1
- src/synthetic_dataset_generator/app.py +5 -4
- src/synthetic_dataset_generator/apps/base.py +5 -1
- src/synthetic_dataset_generator/apps/chat.py +135 -6
- src/synthetic_dataset_generator/apps/rag.py +139 -4
- src/synthetic_dataset_generator/apps/textcat.py +117 -4
- src/synthetic_dataset_generator/constants.py +3 -0
- src/synthetic_dataset_generator/pipelines/base.py +15 -12
- src/synthetic_dataset_generator/pipelines/textcat.py +1 -1
README.md
CHANGED
@@ -104,6 +104,10 @@ Optionally, you can also push your datasets to Argilla for further curation by s
 - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
 - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
 
+To save the generated datasets to a local directory instead of pushing them to the Hugging Face Hub, set the following environment variable:
+
+- `SAVE_LOCAL_DIR`: The local directory to save the generated datasets to.
+
 ### Argilla integration
 
 Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
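
For reference, a minimal launch script exercising the new option, in the style of the repository's examples (the `launch()` entry point is taken from those examples; the directory name is illustrative):

import os

os.environ["SAVE_LOCAL_DIR"] = "generated-datasets"  # illustrative path; created automatically if missing

from synthetic_dataset_generator import launch

launch()

With this set, each generated dataset is written to `<SAVE_LOCAL_DIR>/<repo_name>.csv` and `<SAVE_LOCAL_DIR>/<repo_name>.json`.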
examples/fine-tune-deepseek-reasoning-sft.ipynb
ADDED
The diff for this file is too large to render.
examples/ollama-different-model-for-completion.py
CHANGED
@@ -18,8 +18,8 @@ os.environ["OLLAMA_BASE_URL"] = (
 os.environ["MODEL"] = "llama3.2"  # model for instruction generation
 os.environ["MODEL_COMPLETION"] = "llama3.2:1b"  # model for completion generation
 
-os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-
-os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-3B-Instruct"  # tokenizer for instruction generation
+os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-1B-Instruct"  # tokenizer for completion generation
 
 os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template required for instruction generation
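
The fixed lines pair each Ollama model with the matching Hugging Face tokenizer, which distilabel needs to apply the magpie pre-query template. A quick sanity check that both tokenizers resolve, sketched with the standard transformers API (the meta-llama repos are gated, so this assumes a logged-in Hugging Face account):

from transformers import AutoTokenizer

# Both repos are gated; run `huggingface-cli login` first.
for repo_id in ("meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct"):
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    print(repo_id, "->", type(tokenizer).__name__)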
pyproject.toml
CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "synthetic-dataset-generator"
-version = "0.1.
+version = "0.1.8"
 description = "Build datasets using natural language"
 authors = [
     {name = "davidberenstein1957", email = "[email protected]"},
@@ -22,6 +22,7 @@ dependencies = [
     "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm,vision]>=1.5.0,<2.00",
     "gradio[oauth]>=5.4.0,<6.0.0",
     "gradio-huggingfacehub-search>=0.0.12,<1.0.0",
+    "huggingface-hub>=0.26.0,<0.28.0",
     "model2vec>=0.2.4,<1.0.0",
     "nltk>=3.9.1,<4.0.0",
     "pydantic>=2.10.5,<3.0.0",
src/synthetic_dataset_generator/app.py
CHANGED
@@ -12,15 +12,16 @@ css = """
 .main_ui_logged_out{opacity: 0.3; pointer-events: none}
 button[role="tab"][aria-selected="true"] { border: 0; background: var(--button-primary-background-fill); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
 button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill); background: var(var(--button-primary-background-fill-hover))}
-.tabitem {
+.tabitem {border: 0; padding-inline: 0}
 .gallery-item {background: var(--background-fill-secondary); text-align: left}
-.table-wrap .tbody td {
+.table-wrap .tbody td {vertical-align: top}
-#system_prompt_examples {
+#system_prompt_examples {color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
 .container {padding-inline: 0 !important}
-#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
 .gradio-container { width: 100% !important; }
 .gradio-row { display: flex !important; flex-direction: row !important; }
 .gradio-column { flex: 1 !important; min-width: 0 !important; }
+#sign_in_button {flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto;}
+.datasets {height: 70px;}
 """
 
 image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -12,9 +12,13 @@ from huggingface_hub import HfApi, upload_file, repo_exists
 from unstructured.chunking.title import chunk_by_title
 from unstructured.partition.auto import partition
 
-from synthetic_dataset_generator.constants import MAX_NUM_ROWS
+from synthetic_dataset_generator.constants import MAX_NUM_ROWS, SAVE_LOCAL_DIR
 from synthetic_dataset_generator.utils import get_argilla_client
 
+if SAVE_LOCAL_DIR is not None:
+    import os
+    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
+
 
 def validate_argilla_user_workspace_dataset(
     dataset_name: str,
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -1,5 +1,6 @@
 import ast
 import json
+import os
 import random
 import uuid
 from typing import Dict, List, Union
@@ -29,6 +30,7 @@ from synthetic_dataset_generator.constants import (
     DEFAULT_BATCH_SIZE,
     MODEL,
     MODEL_COMPLETION,
+    SAVE_LOCAL_DIR,
     SFT_AVAILABLE,
 )
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
@@ -309,7 +311,7 @@ def generate_dataset_from_seed(
         progress(
             step_progress * n_processed / num_rows,
             total=total_steps,
-            desc="Generating
+            desc="Generating instructions",
         )
         remaining_rows = num_rows - n_processed
         batch_size = min(batch_size, remaining_rows)
@@ -368,7 +370,9 @@ def generate_dataset_from_seed(
         follow_up_instructions = list(
             follow_up_generator_instruction.process(inputs=conversations_batch)
         )
-        for conv, follow_up in zip(
+        for conv, follow_up in zip(
+            conversations_batch, follow_up_instructions[0]
+        ):
             conv["messages"].append(
                 {"role": "user", "content": follow_up["generation"]}
             )
@@ -506,7 +510,7 @@ def push_dataset(
         num_turns=num_turns,
         num_rows=num_rows,
         temperature=temperature,
-        temperature_completion=temperature_completion
+        temperature_completion=temperature_completion,
     )
     push_dataset_to_hub(
         dataframe=dataframe,
@@ -637,6 +641,45 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    repo_id: str,
+    file_paths: list[str],
+    input_type: str,
+    system_prompt: str,
+    document_column: str,
+    num_turns: int,
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+    temperature_completion: Union[float, None] = None,
+) -> pd.DataFrame:
+    if input_type == "prompt-input":
+        dataframe = _get_dataframe()
+    else:
+        dataframe, _ = load_dataset_file(
+            repo_id=repo_id,
+            file_paths=file_paths,
+            input_type=input_type,
+            num_rows=num_rows,
+        )
+    dataframe = generate_dataset(
+        input_type=input_type,
+        dataframe=dataframe,
+        system_prompt=system_prompt,
+        document_column=document_column,
+        num_turns=num_turns,
+        num_rows=num_rows,
+        temperature=temperature,
+        temperature_completion=temperature_completion,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def show_system_prompt_visibility():
     return {system_prompt: gr.Textbox(visible=True)}
 
@@ -672,6 +715,31 @@ def show_temperature_completion():
     return {temperature_completion: gr.Slider(value=0.9, visible=True)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message
+    }
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -781,7 +849,7 @@ with gr.Blocks() as app:
                 )
                 document_column = gr.Dropdown(
                     label="Document Column",
-                    info="Select the document column to generate the
+                    info="Select the document column to generate the chat data",
                     choices=["Load your data first in step 1."],
                     value="Load your data first in step 1.",
                     interactive=False,
@@ -852,10 +920,23 @@ with gr.Blocks() as app:
                     btn_push_to_hub = gr.Button(
                         "Push to Hub", variant="primary", scale=2
                     )
+                    btn_save_local = gr.Button(
+                        "Save locally", variant="primary", scale=2, visible=False
+                    )
                 with gr.Column(scale=3):
+                    csv_file = gr.File(
+                        label="CSV",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
+                    json_file = gr.File(
+                        label="JSON",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
                     success_message = gr.Markdown(
-                        visible=
-                        min_height=
+                        visible=False,
+                        min_height=0  # don't remove this otherwise progress is not visible
                     )
             with gr.Accordion(
                 "Customize your pipeline with distilabel",
@@ -953,6 +1034,9 @@ with gr.Blocks() as app:
         fn=validate_push_to_hub,
         inputs=[org_name, repo_name],
         outputs=[success_message],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -999,6 +1083,49 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            search_in,
+            file_in,
+            input_type,
+            system_prompt,
+            document_column,
+            num_turns,
+            num_rows,
+            temperature,
+            repo_name,
+            temperature_completion,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            search_in,
+            input_type,
+            system_prompt,
+            document_column,
+            num_turns,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
     clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
     clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
@@ -1011,3 +1138,5 @@ with gr.Blocks() as app:
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=show_temperature_completion, outputs=[temperature_completion])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
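
The "Save locally" wiring above relies on Gradio's event chaining: each `.success(...)` handler runs only if the previous step completed without raising, so the file components are revealed before the (potentially slow) `save_local` call fills them. A self-contained sketch of the same pattern, with illustrative component names:

import gradio as gr

def reveal():
    # Returning a component instance acts as an update in Gradio 4+.
    return gr.File(visible=True)

def produce():
    path = "example.csv"  # illustrative; the app's save_local returns real output paths
    with open(path, "w") as f:
        f.write("a,b\n1,2\n")
    return path

with gr.Blocks() as demo:
    btn = gr.Button("Save locally")
    out_file = gr.File(visible=False)
    # The second step only fires if the first finished without an error.
    btn.click(fn=reveal, outputs=[out_file]).success(fn=produce, outputs=[out_file])

demo.launch()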
src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -24,7 +24,12 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import
+from synthetic_dataset_generator.constants import (
+    DEFAULT_BATCH_SIZE,
+    MODEL,
+    MODEL_COMPLETION,
+    SAVE_LOCAL_DIR,
+)
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -156,7 +161,7 @@ def generate_dataset(
         is_sample=is_sample,
     )
     response_generator = get_response_generator(
-        temperature
+        temperature=temperature_completion or temperature, is_sample=is_sample
     )
     if reranking:
         reranking_generator = get_sentence_pair_generator(
@@ -486,6 +491,49 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    repo_id: str,
+    file_paths: list[str],
+    input_type: str,
+    system_prompt: str,
+    document_column: str,
+    retrieval_reranking: list[str],
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+    temperature_completion: float,
+) -> pd.DataFrame:
+    retrieval = "Retrieval" in retrieval_reranking
+    reranking = "Reranking" in retrieval_reranking
+
+    if input_type == "prompt-input":
+        dataframe = pd.DataFrame(columns=["context", "question", "response"])
+    else:
+        dataframe, _ = load_dataset_file(
+            repo_id=repo_id,
+            file_paths=file_paths,
+            input_type=input_type,
+            num_rows=num_rows,
+        )
+    dataframe = generate_dataset(
+        input_type=input_type,
+        dataframe=dataframe,
+        system_prompt=system_prompt,
+        document_column=document_column,
+        retrieval=retrieval,
+        reranking=reranking,
+        num_rows=num_rows,
+        temperature=temperature,
+        temperature_completion=temperature_completion,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def show_system_prompt_visibility():
     return {system_prompt: gr.Textbox(visible=True)}
 
@@ -521,6 +569,32 @@ def show_temperature_completion():
     return {temperature_completion: gr.Slider(value=0.9, visible=True)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message,
+    }
+
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -675,10 +749,23 @@ with gr.Blocks() as app:
                         scale=1,
                     )
                     btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
+                    btn_save_local = gr.Button(
+                        "Save locally", variant="primary", scale=2, visible=False
+                    )
                 with gr.Column(scale=3):
+                    csv_file = gr.File(
+                        label="CSV",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
+                    json_file = gr.File(
+                        label="JSON",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
                     success_message = gr.Markdown(
-                        visible=
-                        min_height=
+                        visible=False,
+                        min_height=0,  # don't remove this otherwise progress is not visible
                     )
             with gr.Accordion(
                 "Customize your pipeline with distilabel",
@@ -776,6 +863,9 @@ with gr.Blocks() as app:
         fn=validate_push_to_hub,
         inputs=[org_name, repo_name],
         outputs=[success_message],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -822,6 +912,49 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            search_in,
+            file_in,
+            input_type,
+            system_prompt,
+            document_column,
+            retrieval_reranking,
+            num_rows,
+            temperature,
+            repo_name,
+            temperature_completion,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            search_in,
+            input_type,
+            system_prompt,
+            document_column,
+            retrieval_reranking,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
     clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
     clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
@@ -835,3 +968,5 @@ with gr.Blocks() as app:
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=show_temperature_completion, outputs=[temperature_completion])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
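
One detail worth noting in the `get_response_generator` fix above: `temperature_completion or temperature` falls back to the instruction temperature when no completion temperature is set, but Python's `or` also treats `0.0` as unset, so a deliberately greedy completion temperature of `0.0` would silently fall back too:

# Behavior of the fallback used in get_response_generator.
temperature = 0.9
print(None or temperature)  # 0.9: fallback when unset
print(0.0 or temperature)   # 0.9: 0.0 is falsy, so it also falls back
print(0.3 or temperature)   # 0.3: any non-zero value wins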
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import os
 import random
 import uuid
 from typing import List, Union
@@ -19,7 +20,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SAVE_LOCAL_DIR
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -195,7 +196,7 @@ def generate_dataset(
             set(
                 label.lower().strip()
                 for label in x
-                if label.lower().strip() in labels
+                if isinstance(label, str) and label.lower().strip() in labels
             )
         )
     else:
@@ -406,6 +407,33 @@ def push_dataset(
     return ""
 
 
+def save_local(
+    system_prompt: str,
+    difficulty: str,
+    clarity: str,
+    labels: List[str],
+    multi_label: bool,
+    num_rows: int,
+    temperature: float,
+    repo_name: str,
+) -> pd.DataFrame:
+    dataframe = generate_dataset(
+        system_prompt=system_prompt,
+        difficulty=difficulty,
+        clarity=clarity,
+        multi_label=multi_label,
+        labels=labels,
+        num_rows=num_rows,
+        temperature=temperature,
+    )
+    local_dataset = Dataset.from_pandas(dataframe)
+    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
+    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
+    local_dataset.to_csv(output_csv, index=False)
+    local_dataset.to_json(output_json, index=False)
+    return output_csv, output_json
+
+
 def validate_input_labels(labels: List[str]) -> List[str]:
     if (
         not labels
@@ -425,6 +453,32 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_save_local_button():
+    return {btn_save_local: gr.Button(visible=True)}
+
+
+def hide_save_local_button():
+    return {btn_save_local: gr.Button(visible=False)}
+
+
+def show_save_local():
+    gr.update(success_message, min_height=0)
+    return {
+        csv_file: gr.File(visible=True),
+        json_file: gr.File(visible=True),
+        success_message: success_message,
+    }
+
+
+def hide_save_local():
+    gr.update(success_message, min_height=100)
+    return {
+        csv_file: gr.File(visible=False),
+        json_file: gr.File(visible=False),
+        success_message: success_message,
+    }
+
+
 ######################
 # Gradio UI
 ######################
@@ -544,10 +598,23 @@ with gr.Blocks() as app:
                         scale=1,
                     )
                     btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
+                    btn_save_local = gr.Button(
+                        "Save locally", variant="primary", scale=2, visible=False
+                    )
                 with gr.Column(scale=3):
+                    csv_file = gr.File(
+                        label="CSV",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
+                    json_file = gr.File(
+                        label="JSON",
+                        elem_classes="datasets",
+                        visible=False,
+                    )
                     success_message = gr.Markdown(
-                        visible=
-                        min_height=
+                        visible=False,
+                        min_height=0,  # don't remove this otherwise progress is not visible
                     )
             with gr.Accordion(
                 "Customize your pipeline with distilabel",
@@ -600,6 +667,9 @@ with gr.Blocks() as app:
         fn=validate_input_labels,
         inputs=[labels],
         outputs=[labels],
+    ).success(
+        fn=hide_save_local,
+        outputs=[csv_file, json_file, success_message],
     ).success(
         fn=hide_success_message,
         outputs=[success_message],
@@ -644,6 +714,47 @@ with gr.Blocks() as app:
         outputs=[pipeline_code_ui],
     )
 
+    btn_save_local.click(
+        fn=hide_success_message,
+        outputs=[success_message],
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=show_save_local,
+        inputs=[],
+        outputs=[csv_file, json_file, success_message],
+    ).success(
+        save_local,
+        inputs=[
+            system_prompt,
+            difficulty,
+            clarity,
+            labels,
+            multi_label,
+            num_rows,
+            temperature,
+            repo_name,
+        ],
+        outputs=[csv_file, json_file],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            system_prompt,
+            difficulty,
+            clarity,
+            labels,
+            multi_label,
+            num_rows,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+
     gr.on(
         triggers=[clear_btn_part.click, clear_btn_full.click],
         fn=lambda _: (
@@ -660,3 +771,5 @@ with gr.Blocks() as app:
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
+    if SAVE_LOCAL_DIR is not None:
+        app.load(fn=show_save_local_button, outputs=btn_save_local)
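
The one-line filter change in `generate_dataset` is a robustness fix: the label list parsed from model output can contain non-strings (`None`, numbers), and calling `.lower()` on those raised `AttributeError`. A sketch of the before/after behavior on made-up data:

labels = {"positive", "negative"}
x = ["Positive", None, 3, " negative "]

# Old filter: crashes on the first non-string entry.
# [label for label in x if label.lower().strip() in labels]
# -> AttributeError: 'NoneType' object has no attribute 'lower'

# New filter: non-string entries are skipped.
cleaned = list(
    set(
        label.lower().strip()
        for label in x
        if isinstance(label, str) and label.lower().strip() in labels
    )
)
print(cleaned)  # 'positive' and 'negative', in set order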
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -8,6 +8,9 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
 MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
 DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
 
+# Directory to locally save the generated data
+SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
+
 # Models
 MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
 TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
src/synthetic_dataset_generator/pipelines/base.py
CHANGED
@@ -87,10 +87,17 @@ def _get_llm(
 ):
     model = MODEL_COMPLETION if is_completion else MODEL
     tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
-
+    base_urls = {
+        "openai": OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
+        "ollama": OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+        "huggingface": HUGGINGFACE_BASE_URL_COMPLETION if is_completion else HUGGINGFACE_BASE_URL,
+        "vllm": VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+    }
+
+    if base_urls["openai"]:
         llm = OpenAILLM(
             model=model,
-            base_url=
+            base_url=base_urls["openai"],
             api_key=_get_next_api_key(),
             structured_output=structured_output,
             **kwargs,
@@ -103,7 +110,7 @@ def _get_llm(
             del kwargs["generation_kwargs"]["stop_sequences"]
         if "do_sample" in kwargs["generation_kwargs"]:
             del kwargs["generation_kwargs"]["do_sample"]
-    elif
+    elif base_urls["ollama"]:
         if "generation_kwargs" in kwargs:
             if "max_new_tokens" in kwargs["generation_kwargs"]:
                 kwargs["generation_kwargs"]["num_predict"] = kwargs[
@@ -123,32 +130,28 @@ def _get_llm(
             kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
             model=model,
-            host=
+            host=base_urls["ollama"],
            tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif
+    elif base_urls["huggingface"]:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=(
-                HUGGINGFACE_BASE_URL_COMPLETION
-                if is_completion
-                else HUGGINGFACE_BASE_URL
-            ),
+            base_url=base_urls["huggingface"],
             tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
             structured_output=structured_output,
             **kwargs,
         )
-    elif
+    elif base_urls["vllm"]:
         if "generation_kwargs" in kwargs:
             if "do_sample" in kwargs["generation_kwargs"]:
                 del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=
+            base_url=base_urls["vllm"],
             model=model,
             tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
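
The refactor in `_get_llm` replaces per-branch URL expressions with a single `base_urls` dict, so the completion-versus-instruction choice is made once and each provider branch only tests truthiness. A minimal standalone sketch of that dispatch (URL values are stand-ins; the final fall-through branch is an assumption mirroring the default provider, which this hunk does not show):

def pick_provider(is_completion: bool) -> str:
    # Stand-in values; in _get_llm these come from the *_BASE_URL[_COMPLETION] constants.
    urls = {"openai": None, "ollama": "http://localhost:11434", "huggingface": None, "vllm": None}
    completion_urls = dict(urls, ollama="http://localhost:11435")
    base_urls = {
        name: completion_urls[name] if is_completion else urls[name]
        for name in ("openai", "ollama", "huggingface", "vllm")
    }
    # dict insertion order fixes precedence: openai, then ollama, huggingface, vllm.
    for name, url in base_urls.items():
        if url:
            return f"{name} @ {url}"
    return "default provider"  # assumption: mirrors the fall-through branch

print(pick_provider(is_completion=False))  # ollama @ http://localhost:11434
print(pick_provider(is_completion=True))   # ollama @ http://localhost:11435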
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -109,7 +109,7 @@ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: b
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    llm = _get_llm(
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,