Add missing type annotations
Browse files
- src/synthetic_dataset_generator/apps/base.py +2 -2
- src/synthetic_dataset_generator/apps/chat.py +9 -4
- src/synthetic_dataset_generator/apps/rag.py +2 -4
- src/synthetic_dataset_generator/apps/textcat.py +15 -10
- src/synthetic_dataset_generator/pipelines/chat.py +6 -4
- src/synthetic_dataset_generator/pipelines/eval.py +13 -3
- src/synthetic_dataset_generator/pipelines/rag.py +3 -4
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -64,7 +64,7 @@ def push_pipeline_code_to_hub(
|
|
64 |
progress(1.0, desc="Pipeline code uploaded")
|
65 |
|
66 |
|
67 |
-
def validate_push_to_hub(org_name, repo_name):
|
68 |
repo_id = (
|
69 |
f"{org_name}/{repo_name}"
|
70 |
if repo_name is not None and org_name is not None
|
@@ -93,7 +93,7 @@ def combine_datasets(
|
|
93 |
return dataset
|
94 |
|
95 |
|
96 |
-
def show_success_message(org_name, repo_name) -> gr.Markdown:
|
97 |
client = get_argilla_client()
|
98 |
if client is None:
|
99 |
return gr.Markdown(
|
|
|
64 |
progress(1.0, desc="Pipeline code uploaded")
|
65 |
|
66 |
|
67 |
+
def validate_push_to_hub(org_name: str, repo_name: str):
|
68 |
repo_id = (
|
69 |
f"{org_name}/{repo_name}"
|
70 |
if repo_name is not None and org_name is not None
|
|
|
93 |
return dataset
|
94 |
|
95 |
|
96 |
+
def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
|
97 |
client = get_argilla_client()
|
98 |
if client is None:
|
99 |
return gr.Markdown(
|
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -60,7 +60,7 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
60 |
return dataframe
|
61 |
|
62 |
|
63 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
64 |
progress(0.1, desc="Initializing")
|
65 |
generate_description = get_prompt_generator()
|
66 |
progress(0.5, desc="Generating")
|
@@ -77,7 +77,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
77 |
return result
|
78 |
|
79 |
|
80 |
-
def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
|
81 |
progress(0.1, desc="Generating sample dataset")
|
82 |
dataframe = generate_dataset(
|
83 |
system_prompt=system_prompt,
|
@@ -109,7 +109,7 @@ def generate_dataset(
|
|
109 |
num_rows = test_max_num_rows(num_rows)
|
110 |
progress(0.0, desc="(1/2) Generating instructions")
|
111 |
magpie_generator = get_magpie_generator(
|
112 |
-
|
113 |
)
|
114 |
response_generator = get_response_generator(
|
115 |
system_prompt, num_turns, temperature, is_sample
|
@@ -267,7 +267,12 @@ def push_dataset(
|
|
267 |
temperature=temperature,
|
268 |
)
|
269 |
push_dataset_to_hub(
|
270 |
-
dataframe,
|
|
|
|
|
|
|
|
|
|
|
271 |
)
|
272 |
try:
|
273 |
progress(0.1, desc="Setting up user and workspace")
|
|
|
60 |
return dataframe
|
61 |
|
62 |
|
63 |
+
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
|
64 |
progress(0.1, desc="Initializing")
|
65 |
generate_description = get_prompt_generator()
|
66 |
progress(0.5, desc="Generating")
|
|
|
77 |
return result
|
78 |
|
79 |
|
80 |
+
def generate_sample_dataset(system_prompt: str, num_turns: int, progress=gr.Progress()):
|
81 |
progress(0.1, desc="Generating sample dataset")
|
82 |
dataframe = generate_dataset(
|
83 |
system_prompt=system_prompt,
|
|
|
109 |
num_rows = test_max_num_rows(num_rows)
|
110 |
progress(0.0, desc="(1/2) Generating instructions")
|
111 |
magpie_generator = get_magpie_generator(
|
112 |
+
num_turns, temperature, is_sample
|
113 |
)
|
114 |
response_generator = get_response_generator(
|
115 |
system_prompt, num_turns, temperature, is_sample
|
|
|
267 |
temperature=temperature,
|
268 |
)
|
269 |
push_dataset_to_hub(
|
270 |
+
dataframe=dataframe,
|
271 |
+
org_name=org_name,
|
272 |
+
repo_name=repo_name,
|
273 |
+
oauth_token=oauth_token,
|
274 |
+
private=private,
|
275 |
+
pipeline_code=pipeline_code,
|
276 |
)
|
277 |
try:
|
278 |
progress(0.1, desc="Setting up user and workspace")
|
src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -101,7 +101,7 @@ def _load_dataset_from_hub(
|
|
101 |
)
|
102 |
|
103 |
|
104 |
-
def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm=True)):
|
105 |
data = {}
|
106 |
total_chunks = 0
|
107 |
|
@@ -131,7 +131,7 @@ def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm
|
|
131 |
)
|
132 |
|
133 |
|
134 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
135 |
progress(0.1, desc="Initializing")
|
136 |
generate_description = get_prompt_generator()
|
137 |
progress(0.5, desc="Generating")
|
@@ -753,7 +753,6 @@ with gr.Blocks() as app:
|
|
753 |
) as pipeline_code_ui:
|
754 |
code = generate_pipeline_code(
|
755 |
repo_id=search_in.value,
|
756 |
-
file_paths=file_in.value,
|
757 |
input_type=input_type.value,
|
758 |
system_prompt=system_prompt.value,
|
759 |
document_column=document_column.value,
|
@@ -891,7 +890,6 @@ with gr.Blocks() as app:
|
|
891 |
fn=generate_pipeline_code,
|
892 |
inputs=[
|
893 |
search_in,
|
894 |
-
file_in,
|
895 |
input_type,
|
896 |
system_prompt,
|
897 |
document_column,
|
|
|
101 |
)
|
102 |
|
103 |
|
104 |
+
def _preprocess_input_data(file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)):
|
105 |
data = {}
|
106 |
total_chunks = 0
|
107 |
|
|
|
131 |
)
|
132 |
|
133 |
|
134 |
+
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
|
135 |
progress(0.1, desc="Initializing")
|
136 |
generate_description = get_prompt_generator()
|
137 |
progress(0.5, desc="Generating")
|
|
|
753 |
) as pipeline_code_ui:
|
754 |
code = generate_pipeline_code(
|
755 |
repo_id=search_in.value,
|
|
|
756 |
input_type=input_type.value,
|
757 |
system_prompt=system_prompt.value,
|
758 |
document_column=document_column.value,
|
|
|
890 |
fn=generate_pipeline_code,
|
891 |
inputs=[
|
892 |
search_in,
|
|
|
893 |
input_type,
|
894 |
system_prompt,
|
895 |
document_column,
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -49,7 +49,7 @@ def _get_dataframe():
|
|
49 |
)
|
50 |
|
51 |
|
52 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
53 |
progress(0.0, desc="Starting")
|
54 |
progress(0.3, desc="Initializing")
|
55 |
generate_description = get_prompt_generator()
|
@@ -71,7 +71,12 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
71 |
|
72 |
|
73 |
def generate_sample_dataset(
|
74 |
-
system_prompt
|
|
|
|
|
|
|
|
|
|
|
75 |
):
|
76 |
dataframe = generate_dataset(
|
77 |
system_prompt=system_prompt,
|
@@ -294,14 +299,14 @@ def push_dataset(
|
|
294 |
temperature=temperature,
|
295 |
)
|
296 |
push_dataset_to_hub(
|
297 |
-
dataframe,
|
298 |
-
org_name,
|
299 |
-
repo_name,
|
300 |
-
multi_label,
|
301 |
-
labels,
|
302 |
-
oauth_token,
|
303 |
-
private,
|
304 |
-
pipeline_code,
|
305 |
)
|
306 |
|
307 |
dataframe = dataframe[
|
|
|
49 |
)
|
50 |
|
51 |
|
52 |
+
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
|
53 |
progress(0.0, desc="Starting")
|
54 |
progress(0.3, desc="Initializing")
|
55 |
generate_description = get_prompt_generator()
|
|
|
71 |
|
72 |
|
73 |
def generate_sample_dataset(
|
74 |
+
system_prompt: str,
|
75 |
+
difficulty: str,
|
76 |
+
clarity: str,
|
77 |
+
labels: List[str],
|
78 |
+
multi_label: bool,
|
79 |
+
progress=gr.Progress(),
|
80 |
):
|
81 |
dataframe = generate_dataset(
|
82 |
system_prompt=system_prompt,
|
|
|
299 |
temperature=temperature,
|
300 |
)
|
301 |
push_dataset_to_hub(
|
302 |
+
dataframe=dataframe,
|
303 |
+
org_name=org_name,
|
304 |
+
repo_name=repo_name,
|
305 |
+
multi_label=multi_label,
|
306 |
+
labels=labels,
|
307 |
+
oauth_token=oauth_token,
|
308 |
+
private=private,
|
309 |
+
pipeline_code=pipeline_code,
|
310 |
)
|
311 |
|
312 |
dataframe = dataframe[
|
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -140,7 +140,7 @@ else:
|
|
140 |
]
|
141 |
|
142 |
|
143 |
-
def _get_output_mappings(num_turns):
|
144 |
if num_turns == 1:
|
145 |
return {"instruction": "prompt", "response": "completion"}
|
146 |
else:
|
@@ -162,7 +162,7 @@ def get_prompt_generator():
|
|
162 |
return prompt_generator
|
163 |
|
164 |
|
165 |
-
def get_magpie_generator(
|
166 |
input_mappings = _get_output_mappings(num_turns)
|
167 |
output_mappings = input_mappings.copy()
|
168 |
if num_turns == 1:
|
@@ -203,7 +203,9 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
|
|
203 |
return magpie_generator
|
204 |
|
205 |
|
206 |
-
def get_response_generator(
|
|
|
|
|
207 |
if num_turns == 1:
|
208 |
generation_kwargs = {
|
209 |
"temperature": temperature,
|
@@ -229,7 +231,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
|
|
229 |
return response_generator
|
230 |
|
231 |
|
232 |
-
def generate_pipeline_code(system_prompt, num_turns, num_rows):
|
233 |
input_mappings = _get_output_mappings(num_turns)
|
234 |
|
235 |
code = f"""
|
|
|
140 |
]
|
141 |
|
142 |
|
143 |
+
def _get_output_mappings(num_turns: int):
|
144 |
if num_turns == 1:
|
145 |
return {"instruction": "prompt", "response": "completion"}
|
146 |
else:
|
|
|
162 |
return prompt_generator
|
163 |
|
164 |
|
165 |
+
def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
|
166 |
input_mappings = _get_output_mappings(num_turns)
|
167 |
output_mappings = input_mappings.copy()
|
168 |
if num_turns == 1:
|
|
|
203 |
return magpie_generator
|
204 |
|
205 |
|
206 |
+
def get_response_generator(
|
207 |
+
system_prompt: str, num_turns: int, temperature: float, is_sample: bool
|
208 |
+
):
|
209 |
if num_turns == 1:
|
210 |
generation_kwargs = {
|
211 |
"temperature": temperature,
|
|
|
231 |
return response_generator
|
232 |
|
233 |
|
234 |
+
def generate_pipeline_code(system_prompt: str, num_turns: int, num_rows: int):
|
235 |
input_mappings = _get_output_mappings(num_turns)
|
236 |
|
237 |
code = f"""
|
src/synthetic_dataset_generator/pipelines/eval.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from datasets import get_dataset_config_names, get_dataset_split_names
|
2 |
from distilabel.models import InferenceEndpointsLLM
|
3 |
from distilabel.steps.tasks import (
|
@@ -10,7 +12,7 @@ from synthetic_dataset_generator.pipelines.base import _get_next_api_key
|
|
10 |
from synthetic_dataset_generator.utils import extract_column_names
|
11 |
|
12 |
|
13 |
-
def get_ultrafeedback_evaluator(aspect, is_sample):
|
14 |
ultrafeedback_evaluator = UltraFeedback(
|
15 |
llm=InferenceEndpointsLLM(
|
16 |
model_id=MODEL,
|
@@ -27,7 +29,9 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
|
|
27 |
return ultrafeedback_evaluator
|
28 |
|
29 |
|
30 |
-
def get_custom_evaluator(
|
|
|
|
|
31 |
custom_evaluator = TextGeneration(
|
32 |
llm=InferenceEndpointsLLM(
|
33 |
model_id=MODEL,
|
@@ -47,7 +51,13 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
|
|
47 |
|
48 |
|
49 |
def generate_ultrafeedback_pipeline_code(
|
50 |
-
repo_id
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
):
|
52 |
if len(aspects) == 1:
|
53 |
code = f"""
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
from datasets import get_dataset_config_names, get_dataset_split_names
|
4 |
from distilabel.models import InferenceEndpointsLLM
|
5 |
from distilabel.steps.tasks import (
|
|
|
12 |
from synthetic_dataset_generator.utils import extract_column_names
|
13 |
|
14 |
|
15 |
+
def get_ultrafeedback_evaluator(aspect: str, is_sample: bool):
|
16 |
ultrafeedback_evaluator = UltraFeedback(
|
17 |
llm=InferenceEndpointsLLM(
|
18 |
model_id=MODEL,
|
|
|
29 |
return ultrafeedback_evaluator
|
30 |
|
31 |
|
32 |
+
def get_custom_evaluator(
|
33 |
+
prompt_template: str, structured_output: dict, columns: List[str], is_sample: bool
|
34 |
+
):
|
35 |
custom_evaluator = TextGeneration(
|
36 |
llm=InferenceEndpointsLLM(
|
37 |
model_id=MODEL,
|
|
|
51 |
|
52 |
|
53 |
def generate_ultrafeedback_pipeline_code(
|
54 |
+
repo_id: str,
|
55 |
+
subset: str,
|
56 |
+
split: str,
|
57 |
+
aspects: List[str],
|
58 |
+
instruction_column: str,
|
59 |
+
response_columns: str,
|
60 |
+
num_rows: int,
|
61 |
):
|
62 |
if len(aspects) == 1:
|
63 |
code = f"""
|
src/synthetic_dataset_generator/pipelines/rag.py
CHANGED
@@ -87,7 +87,7 @@ def get_prompt_generator():
|
|
87 |
return text_generator
|
88 |
|
89 |
|
90 |
-
def get_chunks_generator(temperature, is_sample):
|
91 |
generation_kwargs = {
|
92 |
"temperature": temperature,
|
93 |
"max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
|
@@ -104,7 +104,7 @@ def get_chunks_generator(temperature, is_sample):
|
|
104 |
return text_generator
|
105 |
|
106 |
|
107 |
-
def get_sentence_pair_generator(action, triplet, temperature, is_sample):
|
108 |
generation_kwargs = {
|
109 |
"temperature": temperature,
|
110 |
"max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
|
@@ -119,7 +119,7 @@ def get_sentence_pair_generator(action, triplet, temperature, is_sample):
|
|
119 |
return sentence_pair_generator
|
120 |
|
121 |
|
122 |
-
def get_response_generator(temperature, is_sample):
|
123 |
generation_kwargs = {
|
124 |
"temperature": temperature,
|
125 |
"max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
|
@@ -138,7 +138,6 @@ def get_response_generator(temperature, is_sample):
|
|
138 |
|
139 |
def generate_pipeline_code(
|
140 |
repo_id: str,
|
141 |
-
file_paths: List[str],
|
142 |
input_type: str,
|
143 |
system_prompt: str,
|
144 |
document_column: str,
|
|
|
87 |
return text_generator
|
88 |
|
89 |
|
90 |
+
def get_chunks_generator(temperature: float, is_sample: bool):
|
91 |
generation_kwargs = {
|
92 |
"temperature": temperature,
|
93 |
"max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
|
|
|
104 |
return text_generator
|
105 |
|
106 |
|
107 |
+
def get_sentence_pair_generator(action: str, triplet: bool, temperature: float, is_sample: bool):
|
108 |
generation_kwargs = {
|
109 |
"temperature": temperature,
|
110 |
"max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
|
|
|
119 |
return sentence_pair_generator
|
120 |
|
121 |
|
122 |
+
def get_response_generator(temperature: float, is_sample: bool):
|
123 |
generation_kwargs = {
|
124 |
"temperature": temperature,
|
125 |
"max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
|
|
|
138 |
|
139 |
def generate_pipeline_code(
|
140 |
repo_id: str,
|
|
|
141 |
input_type: str,
|
142 |
system_prompt: str,
|
143 |
document_column: str,
|