Merge pull request #28 from argilla-io/bug/fix-bugs
- src/synthetic_dataset_generator/apps/base.py +2 -2
- src/synthetic_dataset_generator/apps/chat.py +81 -76
- src/synthetic_dataset_generator/apps/eval.py +1 -0
- src/synthetic_dataset_generator/apps/rag.py +6 -5
- src/synthetic_dataset_generator/apps/textcat.py +16 -10
- src/synthetic_dataset_generator/pipelines/chat.py +6 -4
- src/synthetic_dataset_generator/pipelines/eval.py +13 -3
- src/synthetic_dataset_generator/pipelines/rag.py +3 -4
- src/synthetic_dataset_generator/pipelines/textcat.py +4 -2
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -64,7 +64,7 @@ def push_pipeline_code_to_hub(
     progress(1.0, desc="Pipeline code uploaded")


-def validate_push_to_hub(org_name, repo_name):
+def validate_push_to_hub(org_name: str, repo_name: str):
     repo_id = (
         f"{org_name}/{repo_name}"
         if repo_name is not None and org_name is not None
@@ -93,7 +93,7 @@ def combine_datasets(
     return dataset


-def show_success_message(org_name, repo_name) -> gr.Markdown:
+def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
     client = get_argilla_client()
     if client is None:
         return gr.Markdown(
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -60,7 +60,7 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
     return dataframe


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.1, desc="Initializing")
     generate_description = get_prompt_generator()
     progress(0.5, desc="Generating")
@@ -77,7 +77,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
     return result


-def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
+def generate_sample_dataset(system_prompt: str, num_turns: int, progress=gr.Progress()):
     progress(0.1, desc="Generating sample dataset")
     dataframe = generate_dataset(
         system_prompt=system_prompt,
@@ -109,7 +109,7 @@ def generate_dataset(
     num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
     magpie_generator = get_magpie_generator(
-        system_prompt, num_turns, temperature, is_sample
+        num_turns, temperature, is_sample
     )
     response_generator = get_response_generator(
         system_prompt, num_turns, temperature, is_sample
@@ -267,7 +267,12 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
+        dataframe=dataframe,
+        org_name=org_name,
+        repo_name=repo_name,
+        oauth_token=oauth_token,
+        private=private,
+        pipeline_code=pipeline_code,
     )
     try:
         progress(0.1, desc="Setting up user and workspace")
@@ -524,77 +529,77 @@ with gr.Blocks() as app:
         label="Distilabel Pipeline Code",
     )

-    ... 72 lines of the previous event wiring (content not preserved) ...
+    load_btn.click(
+        fn=generate_system_prompt,
+        inputs=[dataset_description],
+        outputs=[system_prompt],
+        show_progress=True,
+    ).then(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )

+    btn_apply_to_sample_dataset.click(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )

+    btn_push_to_hub.click(
+        fn=validate_argilla_user_workspace_dataset,
+        inputs=[repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).then(
+        fn=validate_push_to_hub,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_success_message,
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+        show_progress=True,
+    ).success(
+        fn=push_dataset,
+        inputs=[
+            org_name,
+            repo_name,
+            system_prompt,
+            num_turns,
+            num_rows,
+            private,
+            temperature,
+            pipeline_code,
+        ],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=show_success_message,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[system_prompt, num_turns, num_rows],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    )
+    gr.on(
+        triggers=[clear_btn_part.click, clear_btn_full.click],
+        fn=lambda _: ("", "", 1, _get_dataframe()),
+        inputs=[dataframe],
+        outputs=[dataset_description, system_prompt, num_turns, dataframe],
+    )
+    app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
+    app.load(fn=swap_visibility, outputs=main_ui)
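The rewired block above uses Gradio's event-chaining API: .click() registers a handler, .then() queues the next step unconditionally, and .success() runs only when the previous step finished without raising, so a validator that raises gr.Error aborts everything after it. A minimal, self-contained sketch of the pattern (component and function names here are illustrative, not the app's):

import gradio as gr

def validate(repo: str) -> str:
    # gr.Error surfaces as a toast in the UI and stops the chain:
    # no later .success() callback fires.
    if not repo:
        raise gr.Error("Please provide a repo name")
    return f"validated: {repo}"

def generate(status: str) -> str:
    # Stand-in for the expensive step that should only run
    # once validation has passed.
    return status.upper()

with gr.Blocks() as demo:
    repo = gr.Textbox(label="Repo")
    status = gr.Textbox(label="Status")
    result = gr.Textbox(label="Result")
    run = gr.Button("Run")

    run.click(
        fn=validate, inputs=[repo], outputs=[status]
    ).success(
        fn=generate, inputs=[status], outputs=[result]
    )

demo.launch()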
src/synthetic_dataset_generator/apps/eval.py
CHANGED
@@ -889,6 +889,7 @@ with gr.Blocks() as app:
         outputs=[
             instruction_instruction_response,
             response_instruction_response,
+            dataframe,
         ],
     )

src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -76,7 +76,7 @@ def _load_dataset_from_hub(
     progress=gr.Progress(track_tqdm=True),
 ):
     if not repo_id:
-        raise gr.Error("Hub repo
+        raise gr.Error("Please provide a Hub repo ID")
     subsets = get_dataset_config_names(repo_id, token=token)
     splits = get_dataset_split_names(repo_id, subsets[0], token=token)
     ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
@@ -101,7 +101,10 @@ def _load_dataset_from_hub(
     )


-def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm=True)):
+def _preprocess_input_data(file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)):
+    if not file_paths:
+        raise gr.Error("Please provide an input file")
+
     data = {}
     total_chunks = 0

@@ -131,7 +134,7 @@ def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm
     )


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.1, desc="Initializing")
     generate_description = get_prompt_generator()
     progress(0.5, desc="Generating")
@@ -753,7 +756,6 @@ with gr.Blocks() as app:
     ) as pipeline_code_ui:
         code = generate_pipeline_code(
             repo_id=search_in.value,
-            file_paths=file_in.value,
             input_type=input_type.value,
             system_prompt=system_prompt.value,
             document_column=document_column.value,
@@ -891,7 +893,6 @@ with gr.Blocks() as app:
         fn=generate_pipeline_code,
         inputs=[
             search_in,
-            file_in,
             input_type,
             system_prompt,
             document_column,
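The new guard in _preprocess_input_data applies the same fail-fast pattern as the repo-ID check above it: validate user input at the top of the handler and raise gr.Error, which Gradio turns into a visible error toast, instead of letting an empty file list crash deeper in the chunking code as an opaque traceback. A minimal sketch under assumed names (this is not the app's real chunking logic):

import gradio as gr

def preprocess(file_paths: list[str], num_rows: int) -> str:
    # Fail fast with a user-facing message; an empty list would
    # otherwise only blow up later, e.g. as an IndexError.
    if not file_paths:
        raise gr.Error("Please provide an input file")
    return f"processing {len(file_paths)} file(s), up to {num_rows} rows"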
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -49,7 +49,7 @@ def _get_dataframe():
     )


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
     progress(0.0, desc="Starting")
     progress(0.3, desc="Initializing")
     generate_description = get_prompt_generator()
@@ -71,7 +71,12 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):


 def generate_sample_dataset(
-    system_prompt
+    system_prompt: str,
+    difficulty: str,
+    clarity: str,
+    labels: List[str],
+    multi_label: bool,
+    progress=gr.Progress(),
 ):
     dataframe = generate_dataset(
         system_prompt=system_prompt,
@@ -294,14 +299,14 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
-        org_name,
-        repo_name,
-        multi_label,
-        labels,
-        oauth_token,
-        private,
-        pipeline_code,
+        dataframe=dataframe,
+        org_name=org_name,
+        repo_name=repo_name,
+        multi_label=multi_label,
+        labels=labels,
+        oauth_token=oauth_token,
+        private=private,
+        pipeline_code=pipeline_code,
     )

     dataframe = dataframe[
@@ -657,6 +662,7 @@ with gr.Blocks() as app:
             "",
             "",
             [],
+            "",
             _get_dataframe(),
         ),
         inputs=[dataframe],
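Both here and in apps/chat.py, the push_dataset_to_hub calls switch from positional to keyword arguments. With eight parameters of similar shapes, positional calls are fragile: inserting, removing, or reordering a parameter in the callee silently re-binds every later argument. A toy illustration of the failure mode the keyword style rules out (the signature below is a stand-in, not the project's actual helper):

def push(dataframe, org_name, repo_name, multi_label=False,
         labels=None, oauth_token=None, private=False, pipeline_code=""):
    # Stand-in for a helper with many same-shaped parameters.
    return repo_name, private

df = object()  # placeholder for a pandas DataFrame

# Positional: if `multi_label` were ever removed from the signature,
# every argument after it would silently shift into the wrong slot.
# push(df, "my-org", "my-repo", True, ["pos", "neg"], "hf_xxx", True, "...")

# Keyword: each value stays bound to the parameter it was meant for,
# and any mismatch with the signature fails loudly as a TypeError.
push(
    dataframe=df,
    org_name="my-org",
    repo_name="my-repo",
    multi_label=True,
    labels=["pos", "neg"],
    oauth_token="hf_xxx",
    private=True,
    pipeline_code="...",
)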
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -140,7 +140,7 @@ else:
 ]


-def _get_output_mappings(num_turns):
+def _get_output_mappings(num_turns: int):
     if num_turns == 1:
         return {"instruction": "prompt", "response": "completion"}
     else:
@@ -162,7 +162,7 @@ def get_prompt_generator():
     return prompt_generator


-def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
+def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
@@ -203,7 +203,9 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     return magpie_generator


-def get_response_generator(system_prompt, num_turns, temperature, is_sample):
+def get_response_generator(
+    system_prompt: str, num_turns: int, temperature: float, is_sample: bool
+):
     if num_turns == 1:
         generation_kwargs = {
             "temperature": temperature,
@@ -229,7 +231,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     return response_generator


-def generate_pipeline_code(system_prompt, num_turns, num_rows):
+def generate_pipeline_code(system_prompt: str, num_turns: int, num_rows: int):
     input_mappings = _get_output_mappings(num_turns)

     code = f"""
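Dropping system_prompt from get_magpie_generator is safe only because its call site in apps/chat.py was updated in the same commit; a stale positional call fails loudly rather than silently. A toy reproduction (hypothetical names):

def get_generator(num_turns: int, temperature: float, is_sample: bool):
    return (num_turns, temperature, is_sample)

# An old-style call with the removed first argument now raises:
# TypeError: get_generator() takes 3 positional arguments but 4 were given
# get_generator("You are a helpful assistant.", 1, 0.9, True)

get_generator(1, 0.9, True)  # the updated call site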
src/synthetic_dataset_generator/pipelines/eval.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import List
+
 from datasets import get_dataset_config_names, get_dataset_split_names
 from distilabel.models import InferenceEndpointsLLM
 from distilabel.steps.tasks import (
@@ -10,7 +12,7 @@ from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 from synthetic_dataset_generator.utils import extract_column_names


-def get_ultrafeedback_evaluator(aspect, is_sample):
+def get_ultrafeedback_evaluator(aspect: str, is_sample: bool):
     ultrafeedback_evaluator = UltraFeedback(
         llm=InferenceEndpointsLLM(
             model_id=MODEL,
@@ -27,7 +29,9 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
     return ultrafeedback_evaluator


-def get_custom_evaluator(prompt_template, structured_output, columns, is_sample):
+def get_custom_evaluator(
+    prompt_template: str, structured_output: dict, columns: List[str], is_sample: bool
+):
     custom_evaluator = TextGeneration(
         llm=InferenceEndpointsLLM(
             model_id=MODEL,
@@ -47,7 +51,13 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)


 def generate_ultrafeedback_pipeline_code(
-    repo_id, subset, split, aspects, instruction_column, response_columns, num_rows
+    repo_id: str,
+    subset: str,
+    split: str,
+    aspects: List[str],
+    instruction_column: str,
+    response_columns: str,
+    num_rows: int,
 ):
     if len(aspects) == 1:
         code = f"""
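The added `from typing import List` is the substantive fix here: annotations in a def header are evaluated when the function object is created (absent `from __future__ import annotations`), so without the import this module would raise NameError: name 'List' is not defined as soon as it is imported. A minimal standalone reproduction:

from typing import List  # removing this line makes the def below fail

def evaluate(aspects: List[str]) -> int:
    # The List[str] annotation is evaluated at function definition
    # time, so a missing name here is an import-time error for the
    # whole module, not a call-time one.
    return len(aspects)

print(evaluate(["helpfulness", "honesty"]))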
src/synthetic_dataset_generator/pipelines/rag.py
CHANGED
@@ -87,7 +87,7 @@ def get_prompt_generator():
     return text_generator


-def get_chunks_generator(temperature, is_sample):
+def get_chunks_generator(temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
@@ -104,7 +104,7 @@ def get_chunks_generator(temperature, is_sample):
     return text_generator


-def get_sentence_pair_generator(action, triplet, temperature, is_sample):
+def get_sentence_pair_generator(action: str, triplet: bool, temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -119,7 +119,7 @@ def get_sentence_pair_generator(action, triplet, temperature, is_sample):
     return sentence_pair_generator


-def get_response_generator(temperature, is_sample):
+def get_response_generator(temperature: float, is_sample: bool):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
@@ -138,7 +138,6 @@ def get_response_generator(temperature, is_sample):

 def generate_pipeline_code(
     repo_id: str,
-    file_paths: List[str],
     input_type: str,
     system_prompt: str,
     document_column: str,
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -85,7 +85,9 @@ def get_prompt_generator():
     return prompt_generator


-def get_textcat_generator(difficulty, clarity, temperature, is_sample):
+def get_textcat_generator(
+    difficulty: str, clarity: str, temperature: float, is_sample: bool
+):
     generation_kwargs = {
         "temperature": temperature,
         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -102,7 +104,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
     return textcat_generator


-def get_labeller_generator(system_prompt, labels, multi_label):
+def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: bool):
     generation_kwargs = {
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
|