Sara Han committed on
Commit
b2669f7
·
unverified ·
1 Parent(s): be0e284

feat: add seed data for chat data (#32)

Browse files

* add similar ui logic for seed data in chat

* add logic for seed generation

* remove todos

* apply feedback

* small fix

src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -1,12 +1,16 @@
1
  import io
2
  import uuid
 
3
  from typing import Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
- from datasets import Dataset, concatenate_datasets, load_dataset
 
8
  from gradio import OAuthToken
9
  from huggingface_hub import HfApi, upload_file, repo_exists
 
 
10
 
11
  from synthetic_dataset_generator.constants import MAX_NUM_ROWS
12
  from synthetic_dataset_generator.utils import get_argilla_client
@@ -179,3 +183,81 @@ def get_iframe(hub_repo_id: str) -> str:
179
  ></iframe>
180
  """
181
  return iframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
  import uuid
3
+ from tqdm import tqdm
4
  from typing import Union
5
 
6
  import argilla as rg
7
  import gradio as gr
8
+ import pandas as pd
9
+ from datasets import Dataset, concatenate_datasets, get_dataset_config_names, get_dataset_split_names, load_dataset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file, repo_exists
12
+ from unstructured.chunking.title import chunk_by_title
13
+ from unstructured.partition.auto import partition
14
 
15
  from synthetic_dataset_generator.constants import MAX_NUM_ROWS
16
  from synthetic_dataset_generator.utils import get_argilla_client
 
183
  ></iframe>
184
  """
185
  return iframe
186
+
187
+
188
+ def _get_valid_columns(dataframe: pd.DataFrame):
189
+ doc_valid_columns = []
190
+
191
+ for col in dataframe.columns:
192
+ sample_val = dataframe[col].iloc[0]
193
+ if isinstance(sample_val, str):
194
+ doc_valid_columns.append(col)
195
+
196
+ return doc_valid_columns
197
+
198
+
199
def load_dataset_from_hub(
    repo_id: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream the first `num_rows` rows of a Hub dataset into a DataFrame.

    Loads the first split of the first config of `repo_id` in streaming
    mode, materializes at most `num_rows` rows, and returns the data
    together with a dropdown pre-populated with the string-typed columns
    usable as documents.

    Args:
        repo_id: Hugging Face Hub dataset repo id.
        num_rows: maximum number of rows to load.
        token: OAuth token forwarded to the Hub APIs.
        progress: Gradio progress tracker (tracks the tqdm bar).

    Returns:
        Tuple of (pandas.DataFrame, gr.Dropdown of valid document columns).

    Raises:
        gr.Error: if `repo_id` is empty.
    """
    if not repo_id:
        raise gr.Error("Please provide a Hub repo ID")
    subsets = get_dataset_config_names(repo_id, token=token)
    splits = get_dataset_split_names(repo_id, subsets[0], token=token)
    ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
    rows = []
    for row in tqdm(ds, desc="Loading the dataset", total=num_rows):
        # Check BEFORE appending: the previous append-then-break logic
        # collected num_rows + 1 rows (off-by-one).
        if len(rows) >= num_rows:
            break
        rows.append(row)
    ds = Dataset.from_list(rows)
    dataframe = ds.to_pandas()
    doc_valid_columns = _get_valid_columns(dataframe)
    col_doc = doc_valid_columns[0] if doc_valid_columns else ""
    return (
        dataframe,
        gr.Dropdown(
            choices=doc_valid_columns,
            label="Documents column",
            value=col_doc,
            interactive=(False if col_doc == "" else True),
            multiselect=False,
        ),
    )
229
+
230
+
231
def preprocess_input_data(
    file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)
):
    """Chunk uploaded files into text passages for seed-data generation.

    Each file is partitioned with `unstructured` and split into
    title-based chunks; processing stops once at least `num_rows`
    chunks have been collected across files.

    Args:
        file_paths: paths of the uploaded files to process.
        num_rows: target number of chunks to collect.
        progress: Gradio progress tracker (tracks the tqdm bar).

    Returns:
        Tuple of (DataFrame with "filename"/"chunks" columns,
        gr.Dropdown fixed to the "chunks" column).

    Raises:
        gr.Error: if no files were provided.
    """
    if not file_paths:
        raise gr.Error("Please provide an input file")

    chunks_per_file = {}
    collected = 0
    for path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
        elements = partition(filename=path)
        file_chunks = [str(chunk) for chunk in chunk_by_title(elements)]
        chunks_per_file[path] = file_chunks
        collected += len(file_chunks)
        # Enough material gathered; skip any remaining files.
        if collected >= num_rows:
            break

    records = [
        (filename, chunk)
        for filename, file_chunks in chunks_per_file.items()
        for chunk in file_chunks
    ]
    dataframe = pd.DataFrame.from_records(records, columns=["filename", "chunks"])
    col_doc = "chunks"

    return (
        dataframe,
        gr.Dropdown(
            choices=["chunks"],
            label="Documents column",
            value=col_doc,
            interactive=(False if col_doc == "" else True),
            multiselect=False,
        ),
    )
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -1,4 +1,5 @@
1
  import ast
 
2
  import random
3
  import uuid
4
  from typing import Dict, List, Union
@@ -8,11 +9,15 @@ import gradio as gr
8
  import pandas as pd
9
  from datasets import Dataset
10
  from distilabel.distiset import Distiset
 
 
11
  from huggingface_hub import HfApi
12
 
13
  from synthetic_dataset_generator.apps.base import (
14
  combine_datasets,
15
  hide_success_message,
 
 
16
  push_pipeline_code_to_hub,
17
  show_success_message,
18
  test_max_num_rows,
@@ -29,15 +34,18 @@ from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
29
  from synthetic_dataset_generator.pipelines.chat import (
30
  DEFAULT_DATASET_DESCRIPTIONS,
31
  generate_pipeline_code,
 
32
  get_magpie_generator,
33
  get_prompt_generator,
34
  get_response_generator,
 
35
  )
36
  from synthetic_dataset_generator.pipelines.embeddings import (
37
  get_embeddings,
38
  get_sentence_embedding_dimensions,
39
  )
40
  from synthetic_dataset_generator.utils import (
 
41
  get_argilla_client,
42
  get_org_dropdown,
43
  get_random_repo_name,
@@ -45,6 +53,14 @@ from synthetic_dataset_generator.utils import (
45
  )
46
 
47
 
 
 
 
 
 
 
 
 
48
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
49
  def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
50
  return ast.literal_eval(
@@ -77,28 +93,57 @@ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
77
  return result
78
 
79
 
80
- def generate_sample_dataset(system_prompt: str, num_turns: int, progress=gr.Progress()):
81
- progress(0.1, desc="Generating sample dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  dataframe = generate_dataset(
 
 
83
  system_prompt=system_prompt,
 
84
  num_turns=num_turns,
85
- num_rows=10,
86
- progress=progress,
87
  is_sample=True,
88
  )
89
  progress(1.0, desc="Sample dataset generated")
90
  return dataframe
91
 
92
 
93
- def _get_dataframe():
94
- return gr.Dataframe(
95
- headers=["prompt", "completion"],
96
- wrap=True,
97
- interactive=False,
98
- )
99
-
100
-
101
- def generate_dataset(
102
  system_prompt: str,
103
  num_turns: int = 1,
104
  num_rows: int = 10,
@@ -108,9 +153,7 @@ def generate_dataset(
108
  ) -> pd.DataFrame:
109
  num_rows = test_max_num_rows(num_rows)
110
  progress(0.0, desc="(1/2) Generating instructions")
111
- magpie_generator = get_magpie_generator(
112
- num_turns, temperature, is_sample
113
- )
114
  response_generator = get_response_generator(
115
  system_prompt, num_turns, temperature, is_sample
116
  )
@@ -217,6 +260,171 @@ def generate_dataset(
217
  return dataframe
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def push_dataset_to_hub(
221
  dataframe: pd.DataFrame,
222
  org_name: str,
@@ -251,17 +459,35 @@ def push_dataset_to_hub(
251
  def push_dataset(
252
  org_name: str,
253
  repo_name: str,
 
 
 
 
254
  system_prompt: str,
 
255
  num_turns: int = 1,
256
  num_rows: int = 10,
257
- private: bool = False,
258
  temperature: float = 0.9,
259
  pipeline_code: str = "",
260
  oauth_token: Union[gr.OAuthToken, None] = None,
261
  progress=gr.Progress(),
262
  ) -> pd.DataFrame:
 
 
 
 
 
 
 
 
 
 
 
263
  dataframe = generate_dataset(
 
 
264
  system_prompt=system_prompt,
 
265
  num_turns=num_turns,
266
  num_rows=num_rows,
267
  temperature=temperature,
@@ -395,6 +621,28 @@ def push_dataset(
395
  return ""
396
 
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  def show_pipeline_code_visibility():
399
  return {pipeline_code_ui: gr.Accordion(visible=True)}
400
 
@@ -422,29 +670,85 @@ with gr.Blocks() as app:
422
  )
423
  )
424
  else:
425
- gr.Markdown(value="## 1. Describe the dataset you want")
426
- with gr.Row():
427
  with gr.Column(scale=2):
428
- dataset_description = gr.Textbox(
429
- label="Dataset description",
430
- placeholder="Give a precise description of your desired dataset.",
431
- )
432
- with gr.Row():
433
- clear_btn_part = gr.Button(
434
- "Clear",
435
- variant="secondary",
436
- )
437
- load_btn = gr.Button(
438
- "Create",
439
- variant="primary",
440
- )
441
- with gr.Column(scale=3):
442
- examples = gr.Examples(
443
- examples=DEFAULT_DATASET_DESCRIPTIONS,
444
- inputs=[dataset_description],
445
- cache_examples=False,
446
- label="Examples",
447
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
  gr.HTML(value="<hr>")
450
  gr.Markdown(value="## 2. Configure your dataset")
@@ -454,6 +758,16 @@ with gr.Blocks() as app:
454
  label="System prompt",
455
  placeholder="You are a helpful assistant.",
456
  )
 
 
 
 
 
 
 
 
 
 
457
  num_turns = gr.Number(
458
  value=1,
459
  label="Number of turns in the conversation",
@@ -519,7 +833,10 @@ with gr.Blocks() as app:
519
  visible=False,
520
  ) as pipeline_code_ui:
521
  code = generate_pipeline_code(
 
 
522
  system_prompt=system_prompt.value,
 
523
  num_turns=num_turns.value,
524
  num_rows=num_rows.value,
525
  )
@@ -529,77 +846,137 @@ with gr.Blocks() as app:
529
  label="Distilabel Pipeline Code",
530
  )
531
 
532
- load_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  fn=generate_system_prompt,
534
  inputs=[dataset_description],
535
  outputs=[system_prompt],
536
- show_progress=True,
537
- ).then(
538
  fn=generate_sample_dataset,
539
- inputs=[system_prompt, num_turns],
540
- outputs=[dataframe],
541
- show_progress=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  )
543
 
544
  btn_apply_to_sample_dataset.click(
545
  fn=generate_sample_dataset,
546
- inputs=[system_prompt, num_turns],
547
- outputs=[dataframe],
548
- show_progress=True,
 
 
 
 
 
 
 
549
  )
550
 
551
  btn_push_to_hub.click(
552
  fn=validate_argilla_user_workspace_dataset,
553
  inputs=[repo_name],
554
  outputs=[success_message],
555
- show_progress=True,
556
  ).then(
557
  fn=validate_push_to_hub,
558
  inputs=[org_name, repo_name],
559
  outputs=[success_message],
560
- show_progress=True,
561
  ).success(
562
  fn=hide_success_message,
563
  outputs=[success_message],
564
- show_progress=True,
565
  ).success(
566
  fn=hide_pipeline_code_visibility,
567
  inputs=[],
568
  outputs=[pipeline_code_ui],
569
- show_progress=True,
570
  ).success(
571
  fn=push_dataset,
572
  inputs=[
573
  org_name,
574
  repo_name,
 
 
 
 
575
  system_prompt,
 
576
  num_turns,
577
  num_rows,
578
- private,
579
  temperature,
580
  pipeline_code,
581
  ],
582
  outputs=[success_message],
583
- show_progress=True,
584
  ).success(
585
  fn=show_success_message,
586
  inputs=[org_name, repo_name],
587
  outputs=[success_message],
588
  ).success(
589
  fn=generate_pipeline_code,
590
- inputs=[system_prompt, num_turns, num_rows],
 
 
 
 
 
 
 
591
  outputs=[pipeline_code],
592
  ).success(
593
  fn=show_pipeline_code_visibility,
594
  inputs=[],
595
  outputs=[pipeline_code_ui],
596
  )
597
- gr.on(
598
- triggers=[clear_btn_part.click, clear_btn_full.click],
599
- fn=lambda _: ("", "", 1, _get_dataframe()),
 
 
 
600
  inputs=[dataframe],
601
- outputs=[dataset_description, system_prompt, num_turns, dataframe],
602
  )
 
 
603
  app.load(fn=get_org_dropdown, outputs=[org_name])
604
  app.load(fn=get_random_repo_name, outputs=[repo_name])
605
- app.load(fn=swap_visibility, outputs=main_ui)
 
1
  import ast
2
+ import json
3
  import random
4
  import uuid
5
  from typing import Dict, List, Union
 
9
  import pandas as pd
10
  from datasets import Dataset
11
  from distilabel.distiset import Distiset
12
+ from gradio.oauth import OAuthToken
13
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
14
  from huggingface_hub import HfApi
15
 
16
  from synthetic_dataset_generator.apps.base import (
17
  combine_datasets,
18
  hide_success_message,
19
+ load_dataset_from_hub,
20
+ preprocess_input_data,
21
  push_pipeline_code_to_hub,
22
  show_success_message,
23
  test_max_num_rows,
 
34
  from synthetic_dataset_generator.pipelines.chat import (
35
  DEFAULT_DATASET_DESCRIPTIONS,
36
  generate_pipeline_code,
37
+ get_follow_up_generator,
38
  get_magpie_generator,
39
  get_prompt_generator,
40
  get_response_generator,
41
+ get_sentence_pair_generator,
42
  )
43
  from synthetic_dataset_generator.pipelines.embeddings import (
44
  get_embeddings,
45
  get_sentence_embedding_dimensions,
46
  )
47
  from synthetic_dataset_generator.utils import (
48
+ column_to_list,
49
  get_argilla_client,
50
  get_org_dropdown,
51
  get_random_repo_name,
 
53
  )
54
 
55
 
56
def _get_dataframe():
    """Build the empty, read-only prompt/completion dataframe component."""
    component = gr.Dataframe(
        headers=["prompt", "completion"],
        interactive=False,
        wrap=True,
    )
    return component
62
+
63
+
64
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
65
  def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
66
  return ast.literal_eval(
 
93
  return result
94
 
95
 
96
def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    """Dispatch source-data loading to the Hub or local-file loader.

    "dataset-input" loads from the Hub repo; any other input type goes
    through the uploaded-file preprocessing path.
    """
    progress(0.1, desc="Loading the source data")
    if input_type == "dataset-input":
        return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
    return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
109
+
110
+
111
def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Produce a preview dataset for the configured input source.

    For "prompt-input" an empty prompt/completion frame seeds the run;
    otherwise the source data (Hub dataset or uploaded files) is loaded
    first, then passed through the generation pipeline with
    ``is_sample=True``.
    """
    if input_type != "prompt-input":
        seed_dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    else:
        seed_dataframe = pd.DataFrame(columns=["prompt", "completion"])
    progress(0.5, desc="Generating sample dataset")
    sample = generate_dataset(
        input_type=input_type,
        dataframe=seed_dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return sample
144
 
145
 
146
+ def generate_dataset_from_prompt(
 
 
 
 
 
 
 
 
147
  system_prompt: str,
148
  num_turns: int = 1,
149
  num_rows: int = 10,
 
153
  ) -> pd.DataFrame:
154
  num_rows = test_max_num_rows(num_rows)
155
  progress(0.0, desc="(1/2) Generating instructions")
156
+ magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
 
 
157
  response_generator = get_response_generator(
158
  system_prompt, num_turns, temperature, is_sample
159
  )
 
260
  return dataframe
261
 
262
 
263
def generate_dataset_from_seed(
    dataframe: pd.DataFrame,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a chat dataset seeded from documents in `dataframe`.

    Pipeline: documents -> question generation (sentence-pair generator)
    -> answer generation -> optional multi-turn follow-ups. For a single
    turn the result has "prompt"/"completion" columns; for multiple turns
    a "messages" column with JSON-serialized conversations.

    Args:
        dataframe: seed data holding the documents.
        document_column: column of `dataframe` to read documents from.
        num_turns: conversation turns; > 1 enables follow-up generation.
        num_rows: number of rows to produce (capped by test_max_num_rows).
        temperature: sampling temperature for all generators.
        is_sample: sample-mode flag forwarded to the generator factories.
        progress: Gradio progress tracker.

    Returns:
        pandas.DataFrame with the generated dataset.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    document_data = column_to_list(dataframe, document_column)
    # Oversample (with replacement) when there are fewer documents than
    # requested rows so every row has a seed document.
    if len(document_data) < num_rows:
        document_data += random.choices(document_data, k=num_rows - len(document_data))
    instruction_generator = get_sentence_pair_generator(
        temperature=temperature, is_sample=is_sample
    )
    response_generator = get_response_generator(
        system_prompt=None, num_turns=1, temperature=temperature, is_sample=is_sample
    )
    follow_up_generator_instruction = get_follow_up_generator(
        type="instruction", temperature=temperature, is_sample=is_sample
    )
    follow_up_generator_response = get_follow_up_generator(
        type="response", temperature=temperature, is_sample=is_sample
    )
    # Progress bookkeeping: two generation passes per turn.
    steps = 2 * num_turns
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    # create instructions
    n_processed = 0
    instruction_results = []
    while n_processed < num_rows:
        progress(
            step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating questions",
        )
        remaining_rows = num_rows - n_processed
        batch_size = min(batch_size, remaining_rows)
        # The sentence-pair generator expects each input as {"anchor": doc}.
        batch = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + batch_size]
        ]
        questions = list(instruction_generator.process(inputs=batch))
        instruction_results.extend(questions[0])
        n_processed += batch_size
    # Expose the generated question under both keys downstream consumers use.
    for result in instruction_results:
        result["instruction"] = result["positive"]
        result["prompt"] = result.pop("positive")

    progress(step_progress, desc="Generating instructions")

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            step_progress + step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating responses",
        )
        batch = instruction_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["completion"] = result.pop("generation")

    # generate follow-ups
    if num_turns > 1:
        n_processed = 0
        final_conversations = []

        while n_processed < num_rows:
            progress(
                step_progress + step_progress * n_processed / num_rows,
                total=total_steps,
                desc="Generating follow-ups",
            )
            batch = response_results[n_processed : n_processed + batch_size]
            # Seed each conversation with the first user/assistant exchange.
            conversations_batch = [
                {
                    "messages": [
                        {"role": "user", "content": result["prompt"]},
                        {"role": "assistant", "content": result["completion"]},
                    ]
                }
                for result in batch
            ]

            # Alternate follow-up question / follow-up answer per extra turn.
            for _ in range(num_turns - 1):
                follow_up_instructions = list(
                    follow_up_generator_instruction.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(conversations_batch, follow_up_instructions[0]):
                    conv["messages"].append(
                        {"role": "user", "content": follow_up["generation"]}
                    )

                follow_up_responses = list(
                    follow_up_generator_response.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
                    conv["messages"].append(
                        {"role": "assistant", "content": follow_up["generation"]}
                    )

            final_conversations.extend(
                [{"messages": conv["messages"]} for conv in conversations_batch]
            )
            n_processed += batch_size

    # create distiset
    distiset_results = []
    if num_turns == 1:
        # Single-turn output: keep only the prompt/completion pair.
        for result in response_results:
            record = {}
            for relevant_keys in ["prompt", "completion"]:
                if relevant_keys in result:
                    record[relevant_keys] = result[relevant_keys]
            distiset_results.append(record)
        dataframe = pd.DataFrame(distiset_results)
    else:
        distiset_results = final_conversations
        dataframe = pd.DataFrame(distiset_results)
        # Serialize message lists so the frame is displayable/pushable.
        dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))

    progress(1.0, desc="Dataset generation completed")
    return dataframe
395
+
396
+
397
def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Route generation to the prompt-based or seed-based pipeline.

    "prompt-input" uses only the system prompt; any other input type
    generates from the documents in `dataframe[document_column]`.
    """
    if input_type == "prompt-input":
        return generate_dataset_from_prompt(
            system_prompt=system_prompt,
            num_turns=num_turns,
            num_rows=num_rows,
            temperature=temperature,
            is_sample=is_sample,
        )
    return generate_dataset_from_seed(
        dataframe=dataframe,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        is_sample=is_sample,
    )
426
+
427
+
428
  def push_dataset_to_hub(
429
  dataframe: pd.DataFrame,
430
  org_name: str,
 
459
  def push_dataset(
460
  org_name: str,
461
  repo_name: str,
462
+ private: bool,
463
+ original_repo_id: str,
464
+ file_paths: list[str],
465
+ input_type: str,
466
  system_prompt: str,
467
+ document_column: str,
468
  num_turns: int = 1,
469
  num_rows: int = 10,
 
470
  temperature: float = 0.9,
471
  pipeline_code: str = "",
472
  oauth_token: Union[gr.OAuthToken, None] = None,
473
  progress=gr.Progress(),
474
  ) -> pd.DataFrame:
475
+ if input_type == "prompt-input":
476
+ dataframe = _get_dataframe()
477
+ else:
478
+ dataframe, _ = load_dataset_file(
479
+ repo_id=original_repo_id,
480
+ file_paths=file_paths,
481
+ input_type=input_type,
482
+ num_rows=num_rows,
483
+ token=oauth_token,
484
+ )
485
+ progress(0.5, desc="Generating dataset")
486
  dataframe = generate_dataset(
487
+ input_type=input_type,
488
+ dataframe=dataframe,
489
  system_prompt=system_prompt,
490
+ document_column=document_column,
491
  num_turns=num_turns,
492
  num_rows=num_rows,
493
  temperature=temperature,
 
621
  return ""
622
 
623
 
624
def show_system_prompt_visibility():
    """Make the system-prompt textbox visible (prompt-input mode)."""
    return {system_prompt: gr.Textbox(visible=True)}


def hide_system_prompt_visibility():
    """Hide the system-prompt textbox (dataset/file-input modes)."""
    return {system_prompt: gr.Textbox(visible=False)}


def show_document_column_visibility():
    """Show the document-column dropdown (dataset/file-input modes)."""
    return {document_column: gr.Dropdown(visible=True)}


def hide_document_column_visibility():
    """Hide the document-column dropdown and reset it to its placeholder."""
    return {
        document_column: gr.Dropdown(
            choices=["Load your data first in step 1."],
            value="Load your data first in step 1.",
            visible=False,
        )
    }
644
+
645
+
646
  def show_pipeline_code_visibility():
647
  return {pipeline_code_ui: gr.Accordion(visible=True)}
648
 
 
670
  )
671
  )
672
  else:
673
+ gr.Markdown("## 1. Select your input")
674
+ with gr.Row(equal_height=False):
675
  with gr.Column(scale=2):
676
+ input_type = gr.Dropdown(
677
+ label="Input type",
678
+ choices=["prompt-input", "dataset-input", "file-input"],
679
+ value="prompt-input",
680
+ multiselect=False,
681
+ visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  )
683
+ with gr.Tab("Generate from prompt") as tab_prompt_input:
684
+ with gr.Row(equal_height=False):
685
+ with gr.Column(scale=2):
686
+ dataset_description = gr.Textbox(
687
+ label="Dataset description",
688
+ placeholder="Give a precise description of your desired dataset.",
689
+ )
690
+ with gr.Row():
691
+ clear_prompt_btn_part = gr.Button(
692
+ "Clear", variant="secondary"
693
+ )
694
+ load_prompt_btn = gr.Button(
695
+ "Create", variant="primary"
696
+ )
697
+ with gr.Column(scale=3):
698
+ examples = gr.Examples(
699
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
700
+ inputs=[dataset_description],
701
+ cache_examples=False,
702
+ label="Examples",
703
+ )
704
+ with gr.Tab("Load from Hub") as tab_dataset_input:
705
+ with gr.Row(equal_height=False):
706
+ with gr.Column(scale=2):
707
+ search_in = HuggingfaceHubSearch(
708
+ label="Search",
709
+ placeholder="Search for a dataset",
710
+ search_type="dataset",
711
+ sumbit_on_select=True,
712
+ )
713
+ with gr.Row():
714
+ clear_dataset_btn_part = gr.Button(
715
+ "Clear", variant="secondary"
716
+ )
717
+ load_dataset_btn = gr.Button(
718
+ "Load", variant="primary"
719
+ )
720
+ with gr.Column(scale=3):
721
+ examples = gr.Examples(
722
+ examples=[
723
+ "charris/wikipedia_sample",
724
+ "plaguss/argilla_sdk_docs_raw_unstructured",
725
+ "BeIR/hotpotqa-generated-queries",
726
+ ],
727
+ label="Example datasets",
728
+ fn=lambda x: x,
729
+ inputs=[search_in],
730
+ run_on_click=True,
731
+ )
732
+ search_out = gr.HTML(
733
+ label="Dataset preview", visible=False
734
+ )
735
+ with gr.Tab("Load your file") as tab_file_input:
736
+ with gr.Row(equal_height=False):
737
+ with gr.Column(scale=2):
738
+ file_in = gr.File(
739
+ label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
740
+ file_count="multiple",
741
+ file_types=[".md", ".txt", ".docx", ".pdf"],
742
+ )
743
+ with gr.Row():
744
+ clear_file_btn_part = gr.Button(
745
+ "Clear", variant="secondary"
746
+ )
747
+ load_file_btn = gr.Button("Load", variant="primary")
748
+ with gr.Column(scale=3):
749
+ file_out = gr.HTML(
750
+ label="Dataset preview", visible=False
751
+ )
752
 
753
  gr.HTML(value="<hr>")
754
  gr.Markdown(value="## 2. Configure your dataset")
 
758
  label="System prompt",
759
  placeholder="You are a helpful assistant.",
760
  )
761
+ document_column = gr.Dropdown(
762
+ label="Document Column",
763
+ info="Select the document column to generate the RAG dataset",
764
+ choices=["Load your data first in step 1."],
765
+ value="Load your data first in step 1.",
766
+ interactive=False,
767
+ multiselect=False,
768
+ allow_custom_value=False,
769
+ visible=False,
770
+ )
771
  num_turns = gr.Number(
772
  value=1,
773
  label="Number of turns in the conversation",
 
833
  visible=False,
834
  ) as pipeline_code_ui:
835
  code = generate_pipeline_code(
836
+ repo_id=search_in.value,
837
+ input_type=input_type.value,
838
  system_prompt=system_prompt.value,
839
+ document_column=document_column.value,
840
  num_turns=num_turns.value,
841
  num_rows=num_rows.value,
842
  )
 
846
  label="Distilabel Pipeline Code",
847
  )
848
 
849
+ tab_prompt_input.select(
850
+ fn=lambda: "prompt-input",
851
+ inputs=[],
852
+ outputs=[input_type],
853
+ ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
854
+ fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
855
+ )
856
+
857
+ tab_dataset_input.select(
858
+ fn=lambda: "dataset-input",
859
+ inputs=[],
860
+ outputs=[input_type],
861
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
862
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
863
+ )
864
+
865
+ tab_file_input.select(
866
+ fn=lambda: "file-input",
867
+ inputs=[],
868
+ outputs=[input_type],
869
+ ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
870
+ fn=show_document_column_visibility, inputs=[], outputs=[document_column]
871
+ )
872
+
873
+ search_in.submit(
874
+ fn=lambda df: pd.DataFrame(columns=df.columns),
875
+ inputs=[dataframe],
876
+ outputs=[dataframe],
877
+ )
878
+
879
+ load_prompt_btn.click(
880
  fn=generate_system_prompt,
881
  inputs=[dataset_description],
882
  outputs=[system_prompt],
883
+ ).success(
 
884
  fn=generate_sample_dataset,
885
+ inputs=[
886
+ search_in,
887
+ file_in,
888
+ input_type,
889
+ system_prompt,
890
+ document_column,
891
+ num_turns,
892
+ num_rows,
893
+ ],
894
+ outputs=dataframe,
895
+ )
896
+
897
+ gr.on(
898
+ triggers=[load_dataset_btn.click, load_file_btn.click],
899
+ fn=load_dataset_file,
900
+ inputs=[search_in, file_in, input_type],
901
+ outputs=[dataframe, document_column],
902
  )
903
 
904
  btn_apply_to_sample_dataset.click(
905
  fn=generate_sample_dataset,
906
+ inputs=[
907
+ search_in,
908
+ file_in,
909
+ input_type,
910
+ system_prompt,
911
+ document_column,
912
+ num_turns,
913
+ num_rows,
914
+ ],
915
+ outputs=dataframe,
916
  )
917
 
918
  btn_push_to_hub.click(
919
  fn=validate_argilla_user_workspace_dataset,
920
  inputs=[repo_name],
921
  outputs=[success_message],
 
922
  ).then(
923
  fn=validate_push_to_hub,
924
  inputs=[org_name, repo_name],
925
  outputs=[success_message],
 
926
  ).success(
927
  fn=hide_success_message,
928
  outputs=[success_message],
 
929
  ).success(
930
  fn=hide_pipeline_code_visibility,
931
  inputs=[],
932
  outputs=[pipeline_code_ui],
 
933
  ).success(
934
  fn=push_dataset,
935
  inputs=[
936
  org_name,
937
  repo_name,
938
+ private,
939
+ search_in,
940
+ file_in,
941
+ input_type,
942
  system_prompt,
943
+ document_column,
944
  num_turns,
945
  num_rows,
 
946
  temperature,
947
  pipeline_code,
948
  ],
949
  outputs=[success_message],
 
950
  ).success(
951
  fn=show_success_message,
952
  inputs=[org_name, repo_name],
953
  outputs=[success_message],
954
  ).success(
955
  fn=generate_pipeline_code,
956
+ inputs=[
957
+ search_in,
958
+ input_type,
959
+ system_prompt,
960
+ document_column,
961
+ num_turns,
962
+ num_rows,
963
+ ],
964
  outputs=[pipeline_code],
965
  ).success(
966
  fn=show_pipeline_code_visibility,
967
  inputs=[],
968
  outputs=[pipeline_code_ui],
969
  )
970
+
971
+ clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
972
+ clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
973
+ clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
974
+ clear_btn_full.click(
975
+ fn=lambda df: ("", "", [], _get_dataframe()),
976
  inputs=[dataframe],
977
+ outputs=[system_prompt, document_column, num_turns, dataframe],
978
  )
979
+
980
+ app.load(fn=swap_visibility, outputs=main_ui)
981
  app.load(fn=get_org_dropdown, outputs=[org_name])
982
  app.load(fn=get_random_repo_name, outputs=[repo_name])
 
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -15,7 +15,7 @@ from datasets import (
15
  from distilabel.distiset import Distiset
16
  from gradio.oauth import OAuthToken #
17
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
18
- from huggingface_hub import HfApi, repo_exists
19
 
20
  from synthetic_dataset_generator.apps.base import (
21
  combine_datasets,
@@ -130,9 +130,9 @@ def load_dataset_from_hub(
130
  choices=response_valid_columns,
131
  label="Response column",
132
  value=col_response,
133
- interactive=False
134
- if col_response == "No valid response columns found."
135
- else True,
136
  ),
137
  prompt_template,
138
  structured_output,
@@ -831,16 +831,13 @@ with gr.Blocks() as app:
831
  fn=validate_argilla_user_workspace_dataset,
832
  inputs=[repo_name],
833
  outputs=[success_message],
834
- show_progress=True,
835
  ).then(
836
  fn=validate_push_to_hub,
837
  inputs=[org_name, repo_name],
838
  outputs=[success_message],
839
- show_progress=True,
840
  ).success(
841
  fn=hide_success_message,
842
  outputs=[success_message],
843
- show_progress=True,
844
  ).success(
845
  fn=hide_pipeline_code_visibility,
846
  inputs=[],
@@ -862,7 +859,6 @@ with gr.Blocks() as app:
862
  pipeline_code,
863
  ],
864
  outputs=[success_message],
865
- show_progress=True,
866
  ).success(
867
  fn=show_success_message,
868
  inputs=[org_name, repo_name],
@@ -882,14 +878,14 @@ with gr.Blocks() as app:
882
  outputs=[pipeline_code_ui],
883
  )
884
 
885
- clear_btn_part.click(fn=lambda : "", inputs=[], outputs=[search_in])
886
  clear_btn_full.click(
887
  fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
888
  inputs=[dataframe],
889
  outputs=[
890
  instruction_instruction_response,
891
  response_instruction_response,
892
- dataframe
893
  ],
894
  )
895
 
 
15
  from distilabel.distiset import Distiset
16
  from gradio.oauth import OAuthToken #
17
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
18
+ from huggingface_hub import HfApi
19
 
20
  from synthetic_dataset_generator.apps.base import (
21
  combine_datasets,
 
130
  choices=response_valid_columns,
131
  label="Response column",
132
  value=col_response,
133
+ interactive=(
134
+ False if col_response == "No valid response columns found." else True
135
+ ),
136
  ),
137
  prompt_template,
138
  structured_output,
 
831
  fn=validate_argilla_user_workspace_dataset,
832
  inputs=[repo_name],
833
  outputs=[success_message],
 
834
  ).then(
835
  fn=validate_push_to_hub,
836
  inputs=[org_name, repo_name],
837
  outputs=[success_message],
 
838
  ).success(
839
  fn=hide_success_message,
840
  outputs=[success_message],
 
841
  ).success(
842
  fn=hide_pipeline_code_visibility,
843
  inputs=[],
 
859
  pipeline_code,
860
  ],
861
  outputs=[success_message],
 
862
  ).success(
863
  fn=show_success_message,
864
  inputs=[org_name, repo_name],
 
878
  outputs=[pipeline_code_ui],
879
  )
880
 
881
+ clear_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
882
  clear_btn_full.click(
883
  fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
884
  inputs=[dataframe],
885
  outputs=[
886
  instruction_instruction_response,
887
  response_instruction_response,
888
+ dataframe,
889
  ],
890
  )
891
 
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -1,30 +1,23 @@
1
  import os
2
  import random
3
  import uuid
4
- from tqdm import tqdm
5
  from typing import Union
6
 
7
  import argilla as rg
8
  import gradio as gr
9
  import nltk
10
  import pandas as pd
11
- from datasets import (
12
- Dataset,
13
- get_dataset_config_names,
14
- get_dataset_split_names,
15
- load_dataset,
16
- )
17
  from distilabel.distiset import Distiset
18
  from gradio.oauth import OAuthToken
19
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
20
  from huggingface_hub import HfApi
21
- from unstructured.chunking.title import chunk_by_title
22
- from unstructured.partition.auto import partition
23
 
24
  from synthetic_dataset_generator.apps.base import (
25
  combine_datasets,
26
- get_iframe,
27
  hide_success_message,
 
 
28
  push_pipeline_code_to_hub,
29
  show_success_message,
30
  test_max_num_rows,
@@ -39,11 +32,11 @@ from synthetic_dataset_generator.pipelines.embeddings import (
39
  )
40
  from synthetic_dataset_generator.pipelines.rag import (
41
  DEFAULT_DATASET_DESCRIPTIONS,
 
42
  get_chunks_generator,
43
  get_prompt_generator,
44
- generate_pipeline_code,
45
- get_sentence_pair_generator,
46
  get_response_generator,
 
47
  )
48
  from synthetic_dataset_generator.utils import (
49
  column_to_list,
@@ -58,81 +51,6 @@ nltk.data.path.append("./nltk_data")
58
  nltk.download("punkt_tab", download_dir="./nltk_data")
59
  nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")
60
 
61
- def _get_valid_columns(dataframe: pd.DataFrame):
62
- doc_valid_columns = []
63
-
64
- for col in dataframe.columns:
65
- sample_val = dataframe[col].iloc[0]
66
- if isinstance(sample_val, str):
67
- doc_valid_columns.append(col)
68
-
69
- return doc_valid_columns
70
-
71
-
72
- def _load_dataset_from_hub(
73
- repo_id: str,
74
- num_rows: int = 10,
75
- token: Union[OAuthToken, None] = None,
76
- progress=gr.Progress(track_tqdm=True),
77
- ):
78
- if not repo_id:
79
- raise gr.Error("Please provide a Hub repo ID")
80
- subsets = get_dataset_config_names(repo_id, token=token)
81
- splits = get_dataset_split_names(repo_id, subsets[0], token=token)
82
- ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
83
- rows = []
84
- for idx, row in enumerate(tqdm(ds, desc="Loading the dataset", total=num_rows)):
85
- rows.append(row)
86
- if idx == num_rows:
87
- break
88
- ds = Dataset.from_list(rows)
89
- dataframe = ds.to_pandas()
90
- doc_valid_columns = _get_valid_columns(dataframe)
91
- col_doc = doc_valid_columns[0] if doc_valid_columns else ""
92
- return (
93
- dataframe,
94
- gr.Dropdown(
95
- choices=doc_valid_columns,
96
- label="Documents column",
97
- value=col_doc,
98
- interactive=(False if col_doc == "" else True),
99
- multiselect=False,
100
- ),
101
- )
102
-
103
-
104
- def _preprocess_input_data(file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)):
105
- if not file_paths:
106
- raise gr.Error("Please provide an input file")
107
-
108
- data = {}
109
- total_chunks = 0
110
-
111
- for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
112
- partitioned_file = partition(filename=file_path)
113
- chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
114
- data[file_path] = chunks
115
- total_chunks += len(chunks)
116
- if total_chunks >= num_rows:
117
- break
118
-
119
- dataframe = pd.DataFrame.from_records(
120
- [(k, v) for k, values in data.items() for v in values],
121
- columns=["filename", "chunks"],
122
- )
123
- col_doc = "chunks"
124
-
125
- return (
126
- dataframe,
127
- gr.Dropdown(
128
- choices=["chunks"],
129
- label="Documents column",
130
- value=col_doc,
131
- interactive=(False if col_doc == "" else True),
132
- multiselect=False,
133
- ),
134
- )
135
-
136
 
137
  def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
138
  progress(0.1, desc="Initializing")
@@ -161,9 +79,48 @@ def load_dataset_file(
161
  ):
162
  progress(0.1, desc="Loading the source data")
163
  if input_type == "dataset-input":
164
- return _load_dataset_from_hub(repo_id, num_rows, token)
165
  else:
166
- return _preprocess_input_data(file_paths, num_rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
 
169
  def generate_dataset(
@@ -323,44 +280,6 @@ def generate_dataset(
323
  return dataframe
324
 
325
 
326
- def generate_sample_dataset(
327
- repo_id: str,
328
- file_paths: list[str],
329
- input_type: str,
330
- system_prompt: str,
331
- document_column: str,
332
- retrieval_reranking: list[str],
333
- num_rows: str,
334
- oauth_token: Union[OAuthToken, None],
335
- progress=gr.Progress(),
336
- ):
337
- retrieval = "Retrieval" in retrieval_reranking
338
- reranking = "Reranking" in retrieval_reranking
339
-
340
- if input_type == "prompt-input":
341
- dataframe = pd.DataFrame(columns=["context", "question", "response"])
342
- else:
343
- dataframe, _ = load_dataset_file(
344
- repo_id=repo_id,
345
- file_paths=file_paths,
346
- input_type=input_type,
347
- num_rows=num_rows,
348
- token=oauth_token,
349
- )
350
- progress(0.5, desc="Generating dataset")
351
- dataframe = generate_dataset(
352
- input_type=input_type,
353
- dataframe=dataframe,
354
- system_prompt=system_prompt,
355
- document_column=document_column,
356
- retrieval=retrieval,
357
- reranking=reranking,
358
- num_rows=10,
359
- is_sample=True,
360
- )
361
- return dataframe
362
-
363
-
364
  def push_dataset_to_hub(
365
  dataframe: pd.DataFrame,
366
  org_name: str,
@@ -428,15 +347,12 @@ def push_dataset(
428
  reranking=reranking,
429
  num_rows=num_rows,
430
  temperature=temperature,
431
- is_sample=True,
432
  )
433
  push_dataset_to_hub(
434
  dataframe, org_name, repo_name, oauth_token, private, pipeline_code
435
  )
436
  dataframe = dataframe[
437
- dataframe.applymap(
438
- lambda x: str(x).strip() if pd.notna(x) else x
439
- ).apply(
440
  lambda row: row.notna().all() and (row != "").all(), axis=1
441
  )
442
  ]
@@ -677,7 +593,7 @@ with gr.Blocks() as app:
677
 
678
  gr.HTML(value="<hr>")
679
  gr.Markdown(value="## 2. Configure your task")
680
- with gr.Row(equal_height=True):
681
  with gr.Column(scale=2):
682
  system_prompt = gr.Textbox(
683
  label="System prompt",
@@ -701,9 +617,7 @@ with gr.Blocks() as app:
701
  )
702
  with gr.Row():
703
  clear_btn_full = gr.Button("Clear", variant="secondary")
704
- btn_apply_to_sample_dataset = gr.Button(
705
- "Save", variant="primary"
706
- )
707
  with gr.Column(scale=3):
708
  dataframe = gr.Dataframe(
709
  headers=["context", "question", "response"],
@@ -791,35 +705,23 @@ with gr.Blocks() as app:
791
  fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
792
  )
793
 
794
- search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
795
  fn=lambda df: pd.DataFrame(columns=df.columns),
796
  inputs=[dataframe],
797
  outputs=[dataframe],
798
  )
799
 
800
- load_dataset_btn.click(
 
801
  fn=load_dataset_file,
802
  inputs=[search_in, file_in, input_type],
803
- outputs=[
804
- dataframe,
805
- document_column,
806
- ],
807
- )
808
-
809
- load_file_btn.click(
810
- fn=load_dataset_file,
811
- inputs=[search_in, file_in, input_type],
812
- outputs=[
813
- dataframe,
814
- document_column,
815
- ],
816
  )
817
 
818
  load_prompt_btn.click(
819
  fn=generate_system_prompt,
820
  inputs=[dataset_description],
821
  outputs=[system_prompt],
822
- show_progress=True,
823
  ).success(
824
  fn=generate_sample_dataset,
825
  inputs=[
@@ -852,16 +754,13 @@ with gr.Blocks() as app:
852
  fn=validate_argilla_user_workspace_dataset,
853
  inputs=[repo_name],
854
  outputs=[success_message],
855
- show_progress=True,
856
  ).then(
857
  fn=validate_push_to_hub,
858
  inputs=[org_name, repo_name],
859
  outputs=[success_message],
860
- show_progress=True,
861
  ).success(
862
  fn=hide_success_message,
863
  outputs=[success_message],
864
- show_progress=True,
865
  ).success(
866
  fn=hide_pipeline_code_visibility,
867
  inputs=[],
@@ -883,7 +782,6 @@ with gr.Blocks() as app:
883
  pipeline_code,
884
  ],
885
  outputs=[success_message],
886
- show_progress=True,
887
  ).success(
888
  fn=show_success_message,
889
  inputs=[org_name, repo_name],
@@ -905,11 +803,9 @@ with gr.Blocks() as app:
905
  outputs=[pipeline_code_ui],
906
  )
907
 
908
- clear_dataset_btn_part.click(fn=lambda : "", inputs=[], outputs=[search_in])
909
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
910
- clear_prompt_btn_part.click(
911
- fn=lambda : "", inputs=[], outputs=[dataset_description]
912
- )
913
  clear_btn_full.click(
914
  fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
915
  inputs=[dataframe],
 
1
  import os
2
  import random
3
  import uuid
 
4
  from typing import Union
5
 
6
  import argilla as rg
7
  import gradio as gr
8
  import nltk
9
  import pandas as pd
10
+ from datasets import Dataset
 
 
 
 
 
11
  from distilabel.distiset import Distiset
12
  from gradio.oauth import OAuthToken
13
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
14
  from huggingface_hub import HfApi
 
 
15
 
16
  from synthetic_dataset_generator.apps.base import (
17
  combine_datasets,
 
18
  hide_success_message,
19
+ load_dataset_from_hub,
20
+ preprocess_input_data,
21
  push_pipeline_code_to_hub,
22
  show_success_message,
23
  test_max_num_rows,
 
32
  )
33
  from synthetic_dataset_generator.pipelines.rag import (
34
  DEFAULT_DATASET_DESCRIPTIONS,
35
+ generate_pipeline_code,
36
  get_chunks_generator,
37
  get_prompt_generator,
 
 
38
  get_response_generator,
39
+ get_sentence_pair_generator,
40
  )
41
  from synthetic_dataset_generator.utils import (
42
  column_to_list,
 
51
  nltk.download("punkt_tab", download_dir="./nltk_data")
52
  nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
56
  progress(0.1, desc="Initializing")
 
79
  ):
80
  progress(0.1, desc="Loading the source data")
81
  if input_type == "dataset-input":
82
+ return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
83
  else:
84
+ return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
85
+
86
+
87
def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,  # fixed: was annotated `str`, but it is used as a row count
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Generate a small sample dataset for previewing in the UI.

    Depending on ``input_type`` the seed data is either an empty frame
    (prompt input) or loaded from a Hub dataset / uploaded files via
    ``load_dataset_file``.

    Args:
        repo_id: Hub dataset repo id (used for dataset input).
        file_paths: local files to chunk (used for file input).
        input_type: "prompt-input", "dataset-input", or file input.
        system_prompt: system prompt guiding the generation.
        document_column: column containing the seed documents.
        retrieval_reranking: subset of ["Retrieval", "Reranking"] to enable.
        num_rows: number of seed rows to load from the source data.
        oauth_token: Hub token used when loading a private dataset.
        progress: Gradio progress tracker.

    Returns:
        pd.DataFrame with the generated sample rows.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        # No seed data: generation starts from the system prompt alone.
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    # The preview is always capped at 10 generated rows regardless of num_rows.
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=10,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return dataframe
124
 
125
 
126
  def generate_dataset(
 
280
  return dataframe
281
 
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  def push_dataset_to_hub(
284
  dataframe: pd.DataFrame,
285
  org_name: str,
 
347
  reranking=reranking,
348
  num_rows=num_rows,
349
  temperature=temperature,
 
350
  )
351
  push_dataset_to_hub(
352
  dataframe, org_name, repo_name, oauth_token, private, pipeline_code
353
  )
354
  dataframe = dataframe[
355
+ dataframe.applymap(lambda x: str(x).strip() if pd.notna(x) else x).apply(
 
 
356
  lambda row: row.notna().all() and (row != "").all(), axis=1
357
  )
358
  ]
 
593
 
594
  gr.HTML(value="<hr>")
595
  gr.Markdown(value="## 2. Configure your task")
596
+ with gr.Row(equal_height=False):
597
  with gr.Column(scale=2):
598
  system_prompt = gr.Textbox(
599
  label="System prompt",
 
617
  )
618
  with gr.Row():
619
  clear_btn_full = gr.Button("Clear", variant="secondary")
620
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
 
 
621
  with gr.Column(scale=3):
622
  dataframe = gr.Dataframe(
623
  headers=["context", "question", "response"],
 
705
  fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
706
  )
707
 
708
+ search_in.submit(
709
  fn=lambda df: pd.DataFrame(columns=df.columns),
710
  inputs=[dataframe],
711
  outputs=[dataframe],
712
  )
713
 
714
+ gr.on(
715
+ triggers=[load_dataset_btn.click, load_file_btn.click],
716
  fn=load_dataset_file,
717
  inputs=[search_in, file_in, input_type],
718
+ outputs=[dataframe, document_column],
 
 
 
 
 
 
 
 
 
 
 
 
719
  )
720
 
721
  load_prompt_btn.click(
722
  fn=generate_system_prompt,
723
  inputs=[dataset_description],
724
  outputs=[system_prompt],
 
725
  ).success(
726
  fn=generate_sample_dataset,
727
  inputs=[
 
754
  fn=validate_argilla_user_workspace_dataset,
755
  inputs=[repo_name],
756
  outputs=[success_message],
 
757
  ).then(
758
  fn=validate_push_to_hub,
759
  inputs=[org_name, repo_name],
760
  outputs=[success_message],
 
761
  ).success(
762
  fn=hide_success_message,
763
  outputs=[success_message],
 
764
  ).success(
765
  fn=hide_pipeline_code_visibility,
766
  inputs=[],
 
782
  pipeline_code,
783
  ],
784
  outputs=[success_message],
 
785
  ).success(
786
  fn=show_success_message,
787
  inputs=[org_name, repo_name],
 
803
  outputs=[pipeline_code_ui],
804
  )
805
 
806
+ clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
807
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
808
+ clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
 
 
809
  clear_btn_full.click(
810
  fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
811
  inputs=[dataframe],
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -458,7 +458,7 @@ with gr.Blocks() as app:
458
 
459
  gr.HTML("<hr>")
460
  gr.Markdown("## 2. Configure your dataset")
461
- with gr.Row(equal_height=True):
462
  with gr.Column(scale=2):
463
  system_prompt = gr.Textbox(
464
  label="System prompt",
@@ -508,9 +508,7 @@ with gr.Blocks() as app:
508
  )
509
  with gr.Row():
510
  clear_btn_full = gr.Button("Clear", variant="secondary")
511
- btn_apply_to_sample_dataset = gr.Button(
512
- "Save", variant="primary"
513
- )
514
  with gr.Column(scale=3):
515
  dataframe = _get_dataframe()
516
 
@@ -574,45 +572,37 @@ with gr.Blocks() as app:
574
  fn=generate_system_prompt,
575
  inputs=[dataset_description],
576
  outputs=[system_prompt, labels],
577
- show_progress=True,
578
  ).then(
579
  fn=generate_sample_dataset,
580
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
581
  outputs=[dataframe],
582
- show_progress=True,
583
  )
584
 
585
  btn_apply_to_sample_dataset.click(
586
  fn=validate_input_labels,
587
  inputs=[labels],
588
  outputs=[labels],
589
- show_progress=True,
590
  ).success(
591
  fn=generate_sample_dataset,
592
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
593
  outputs=[dataframe],
594
- show_progress=True,
595
  )
596
 
597
  btn_push_to_hub.click(
598
  fn=validate_argilla_user_workspace_dataset,
599
  inputs=[repo_name],
600
  outputs=[success_message],
601
- show_progress=True,
602
  ).then(
603
  fn=validate_push_to_hub,
604
  inputs=[org_name, repo_name],
605
  outputs=[success_message],
606
- show_progress=True,
607
  ).success(
608
  fn=validate_input_labels,
609
  inputs=[labels],
610
  outputs=[labels],
611
- show_progress=True,
612
  ).success(
613
  fn=hide_success_message,
614
  outputs=[success_message],
615
- show_progress=True,
616
  ).success(
617
  fn=hide_pipeline_code_visibility,
618
  inputs=[],
@@ -633,7 +623,6 @@ with gr.Blocks() as app:
633
  pipeline_code,
634
  ],
635
  outputs=[success_message],
636
- show_progress=True,
637
  ).success(
638
  fn=show_success_message,
639
  inputs=[org_name, repo_name],
 
458
 
459
  gr.HTML("<hr>")
460
  gr.Markdown("## 2. Configure your dataset")
461
+ with gr.Row(equal_height=False):
462
  with gr.Column(scale=2):
463
  system_prompt = gr.Textbox(
464
  label="System prompt",
 
508
  )
509
  with gr.Row():
510
  clear_btn_full = gr.Button("Clear", variant="secondary")
511
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
 
 
512
  with gr.Column(scale=3):
513
  dataframe = _get_dataframe()
514
 
 
572
  fn=generate_system_prompt,
573
  inputs=[dataset_description],
574
  outputs=[system_prompt, labels],
 
575
  ).then(
576
  fn=generate_sample_dataset,
577
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
578
  outputs=[dataframe],
 
579
  )
580
 
581
  btn_apply_to_sample_dataset.click(
582
  fn=validate_input_labels,
583
  inputs=[labels],
584
  outputs=[labels],
 
585
  ).success(
586
  fn=generate_sample_dataset,
587
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
588
  outputs=[dataframe],
 
589
  )
590
 
591
  btn_push_to_hub.click(
592
  fn=validate_argilla_user_workspace_dataset,
593
  inputs=[repo_name],
594
  outputs=[success_message],
 
595
  ).then(
596
  fn=validate_push_to_hub,
597
  inputs=[org_name, repo_name],
598
  outputs=[success_message],
 
599
  ).success(
600
  fn=validate_input_labels,
601
  inputs=[labels],
602
  outputs=[labels],
 
603
  ).success(
604
  fn=hide_success_message,
605
  outputs=[success_message],
 
606
  ).success(
607
  fn=hide_pipeline_code_visibility,
608
  inputs=[],
 
623
  pipeline_code,
624
  ],
625
  outputs=[success_message],
 
626
  ).success(
627
  fn=show_success_message,
628
  inputs=[org_name, repo_name],
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -1,4 +1,10 @@
1
- from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
 
 
 
 
 
 
2
 
3
  from synthetic_dataset_generator.constants import (
4
  MAGPIE_PRE_QUERY_TEMPLATE,
@@ -118,6 +124,18 @@ The prompt you write should follow the same style and structure as the following
118
  User dataset description:
119
  """
120
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  DEFAULT_DATASET_DESCRIPTIONS = [
122
  "rude customer assistant for a phone company",
123
  "assistant that solves math puzzles using python",
@@ -203,6 +221,21 @@ def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
203
  return magpie_generator
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def get_response_generator(
207
  system_prompt: str, num_turns: int, temperature: float, is_sample: bool
208
  ):
@@ -231,36 +264,236 @@ def get_response_generator(
231
  return response_generator
232
 
233
 
234
- def generate_pipeline_code(system_prompt: str, num_turns: int, num_rows: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  input_mappings = _get_output_mappings(num_turns)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
 
 
 
 
 
 
 
 
 
237
  code = f"""
238
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
239
- import os
240
  from distilabel.pipeline import Pipeline
241
- from distilabel.steps import KeepColumns
242
- from distilabel.steps.tasks import MagpieGenerator
243
- from distilabel.llms import {_get_llm_class()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- SYSTEM_PROMPT = "{system_prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  with Pipeline(name="sft") as pipeline:
248
- magpie = MagpieGenerator(
 
 
 
 
 
 
 
 
 
249
  llm={_get_llm_class()}.from_dict(
250
  {_get_llm().dump()}
251
  ),
252
- n_turns={num_turns},
253
- num_rows={num_rows},
254
- batch_size=1,
255
- system_prompt=SYSTEM_PROMPT,
256
- output_mappings={input_mappings},
257
  )
258
- keep_columns = KeepColumns(
259
- columns={list(input_mappings.values())} + ["model_name"],
 
 
 
 
 
 
 
260
  )
261
- magpie.connect(keep_columns)
 
 
 
 
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  if __name__ == "__main__":
264
  distiset = pipeline.run()
 
265
  """
266
  return code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import get_dataset_config_names, get_dataset_split_names
2
+ from distilabel.steps.tasks import (
3
+ ChatGeneration,
4
+ Magpie,
5
+ GenerateSentencePair,
6
+ TextGeneration,
7
+ )
8
 
9
  from synthetic_dataset_generator.constants import (
10
  MAGPIE_PRE_QUERY_TEMPLATE,
 
124
  User dataset description:
125
  """
126
 
127
# Jinja template that renders an existing conversation turn-by-turn and asks
# the model for the next user message; consumed by the "instruction" branch of
# get_follow_up_generator via TextGeneration(columns=["messages"]).
FOLLOW_UP_TEMPLATE = """Conversation:
{% for message in messages %}
{% if message.role == "user" %}
User Question: {{ message.content }}
{% elif message.role == "assistant" %}
Assistant Response: {{ message.content }}
{% endif %}
{% endfor %}

Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
""".rstrip()
138
+
139
  DEFAULT_DATASET_DESCRIPTIONS = [
140
  "rude customer assistant for a phone company",
141
  "assistant that solves math puzzles using python",
 
221
  return magpie_generator
222
 
223
 
224
def get_sentence_pair_generator(temperature: float, is_sample: bool):
    """Build and load a `GenerateSentencePair` task that derives a query
    (with a hard negative) from each seed document.

    Args:
        temperature: sampling temperature for the LLM.
        is_sample: when True, cap generation at 256 new tokens.

    Returns:
        A loaded `GenerateSentencePair` task instance.
    """
    token_limit = 256 if is_sample else MAX_NUM_TOKENS
    llm = _get_llm(
        generation_kwargs={
            "temperature": temperature,
            "max_new_tokens": token_limit,
        }
    )
    task = GenerateSentencePair(
        llm=llm,
        triplet=False,
        action="query",
        hard_negative=True,
    )
    task.load()
    return task
237
+
238
+
239
  def get_response_generator(
240
  system_prompt: str, num_turns: int, temperature: float, is_sample: bool
241
  ):
 
264
  return response_generator
265
 
266
 
267
def get_follow_up_generator(type: str, temperature: float, is_sample: bool):
    """Build and load the task that extends a conversation by one turn.

    Args:
        type: "instruction" yields a TextGeneration task driven by
            FOLLOW_UP_TEMPLATE (next user message); any other value yields a
            ChatGeneration task (next assistant reply).
        temperature: sampling temperature for the LLM.
        is_sample: when True, cap instruction generation at 256 new tokens.

    Returns:
        A loaded TextGeneration or ChatGeneration task.
    """
    if type == "instruction":
        # Follow-up user messages are kept shorter than full responses.
        kwargs = {
            "temperature": temperature,
            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
        }
        generator = TextGeneration(
            llm=_get_llm(generation_kwargs=kwargs),
            template=FOLLOW_UP_TEMPLATE,
            columns=["messages"],
        )
    else:
        kwargs = {
            "temperature": temperature,
            "max_new_tokens": MAX_NUM_TOKENS,
        }
        generator = ChatGeneration(
            llm=_get_llm(generation_kwargs=kwargs),
        )
    generator.load()
    return generator
288
+
289
def generate_pipeline_code_system_prompt(
    system_prompt: str,
    num_turns: int,
    num_rows: int,
):
    """Render a standalone distilabel pipeline script that generates chat data
    from a system prompt alone, using MagpieGenerator.

    Args:
        system_prompt: system prompt embedded verbatim into the script.
        num_turns: number of conversation turns (drives the output mappings).
        num_rows: number of rows the generated pipeline will produce.

    Returns:
        The pipeline script as a Python source string.
    """
    input_mappings = _get_output_mappings(num_turns)
    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import MagpieGenerator
from distilabel.llms import {_get_llm_class()}

SYSTEM_PROMPT = "{system_prompt}"

with Pipeline(name="sft") as pipeline:
    magpie = MagpieGenerator(
        llm={_get_llm_class()}.from_dict(
            {_get_llm().dump()}
        ),
        n_turns={num_turns},
        num_rows={num_rows},
        batch_size=1,
        system_prompt=SYSTEM_PROMPT,
        output_mappings={input_mappings},
    )
    keep_columns = KeepColumns(
        columns={list(input_mappings.values())} + ["model_name"],
    )
    magpie.connect(keep_columns)

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code
325
 
326
def generate_pipeline_code_seed(
    repo_id: str,
    subset: str,
    split: str,
    input_type: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
):
    """Render a standalone distilabel pipeline script that generates chat data
    from seed documents (a Hub dataset or chunked local files).

    Fixes over the original emitter:
    - follow-up steps are defined and wired with the SAME indices
      (the original defined indices 0..num_turns-2 but wired 1..num_turns,
      so the emitted script referenced undefined step names);
    - ``prepare_messages`` / ``keep_columns`` are only wired when they are
      defined (num_turns > 1);
    - the stray ``)`` after ``pipeline.run()`` is removed (SyntaxError);
    - the multi-turn helper segment is a plain string, so Jinja braces are
      emitted single (the original doubled them as if inside an f-string);
    - the emitted ``@step`` helpers import step/StepInput/StepOutput;
    - the wiring expression is indented into the ``with Pipeline`` block.

    Args:
        repo_id: Hub dataset repo id (used when input_type == "dataset-input").
        subset: dataset config name for LoadDataFromHub.
        split: dataset split for LoadDataFromHub.
        input_type: "dataset-input" to load from the Hub, otherwise local files.
        document_column: column holding the seed documents.
        num_turns: number of conversation turns; > 1 adds follow-up steps.
        num_rows: number of examples to load from the seed data.

    Returns:
        The pipeline script as a Python source string.
    """
    code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
from distilabel.models import {_get_llm_class()}
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}{", step, StepInput" if num_turns > 1 else ""}
from distilabel.steps.tasks import GenerateSentencePair, TextGeneration{", ChatGeneration" if num_turns > 1 else ""}
"""
    if num_turns > 1:
        # NOTE(review): StepOutput import path may vary across distilabel
        # versions — confirm `distilabel.steps.typing` against the pinned one.
        code += """from distilabel.steps.typing import StepOutput

FOLLOW_UP_TEMPLATE = '''Conversation:
{% for message in messages %}
{% if message.role == "user" %}
User Question: {{ message.content }}
{% elif message.role == "assistant" %}
Assistant Response: {{ message.content }}
{% endif %}
{% endfor %}

Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
'''.rstrip()


@step(inputs=["prompt", "completion"], outputs=["messages"])
def PrepareMessages(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"] = [
                {"role": "user", "content": item["prompt"]},
                {"role": "assistant", "content": item["completion"]},
            ]
        yield input


@step(inputs=["messages", "generation"], outputs=["messages"])
def FormatMessagesInstruction(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"].append({"role": "user", "content": item["generation"]})
        yield input


@step(inputs=["messages", "generation"], outputs=["messages"])
def FormatMessagesResponse(*inputs: StepInput) -> StepOutput:
    for input in inputs:
        for item in input:
            item["messages"].append({"role": "assistant", "content": item["generation"]})
        yield input
"""

    if input_type == "dataset-input":
        code += f"""
with Pipeline(name="sft") as pipeline:
    load_the_dataset = LoadDataFromHub(
        repo_id='{repo_id}',
        config='{subset}',
        split='{split}',
        num_examples={num_rows},
        batch_size=2,
        output_mappings={{'{document_column}':'anchor'}},
    )
"""
    else:
        code += """
data = process_and_chunk_files(files=[files])

with Pipeline(name="sft") as pipeline:
    load_the_dataset = LoadDataFromDicts(
        data = data
    )
"""
    code += f"""
    instruction_generator = GenerateSentencePair(
        name="instruction_generation",
        triplet=False,
        hard_negative=True,
        action="query",
        llm={_get_llm_class()}.from_dict(
            {_get_llm().dump()}
        ),
        input_batch_size=10,
        output_mappings={{"positive": "prompt"}},
    )

    response_generator = TextGeneration(
        name="response_generation",
        llm={_get_llm_class()}.from_dict(
            {_get_llm().dump()}
        ),
        input_batch_size=10,
        input_mappings={{"instruction": "prompt"}},
        output_mappings={{"generation": "completion"}},
    )
"""

    if num_turns > 1:
        code += """
    prepare_messages = PrepareMessages()
"""
        # One instruction/response pair per extra turn; the identical index
        # range is reused below when wiring, so every referenced step exists.
        for i in range(num_turns - 1):
            code += f"""
    follow_up_instruction_{i} = TextGeneration(
        llm={_get_llm_class()}.from_dict(
            {_get_llm().dump()}
        ),
        template=FOLLOW_UP_TEMPLATE,
        columns=["messages"],
    )
    format_instruction_{i} = FormatMessagesInstruction()
    follow_up_response_{i} = ChatGeneration(
        llm={_get_llm_class()}.from_dict(
            {_get_llm().dump()}
        ),
    )
    format_response_{i} = FormatMessagesResponse()
"""
        code += """
    keep_columns = KeepColumns(columns=["messages"])
"""

    # Wire the steps inside the `with Pipeline` block (4-space indent).
    code += "\n    load_the_dataset >> instruction_generator >> response_generator"
    if num_turns > 1:
        code += " >> prepare_messages"
        for i in range(num_turns - 1):
            code += (
                f" >> follow_up_instruction_{i} >> format_instruction_{i}"
                f" >> follow_up_response_{i} >> format_response_{i}"
            )
        code += " >> keep_columns"

    code += """

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    return code
470
+
471
def generate_pipeline_code(
    repo_id: str,
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
):
    """Dispatch to the matching pipeline-code renderer for the chat task.

    Prompt-only input uses the Magpie system-prompt pipeline; dataset/file
    input uses the seed-data pipeline.

    Args:
        repo_id: Hub dataset repo id (may be None for non-dataset input).
        input_type: "prompt-input", "dataset-input", or file input.
        system_prompt: system prompt for the prompt-only pipeline.
        document_column: seed-document column for the seed pipeline.
        num_turns: number of conversation turns.
        num_rows: number of rows the generated pipeline should produce.

    Returns:
        The pipeline script as a Python source string.
    """
    if input_type == "dataset-input" and repo_id is not None:
        # Default to the first config and split of the Hub dataset.
        subset = get_dataset_config_names(repo_id)[0]
        split = get_dataset_split_names(repo_id, subset)[0]
    else:
        subset = "default"
        split = "train"
    # Fixed: the original compared against "prompt-type", which does not match
    # the app's input-type values ("prompt-input"/"dataset-input" — see rag.py),
    # so the prompt branch was unreachable and the seed pipeline was always
    # emitted. Confirm the chat app's radio values use the same identifiers.
    if input_type == "prompt-input":
        return generate_pipeline_code_system_prompt(
            system_prompt=system_prompt,
            num_turns=num_turns,
            num_rows=num_rows,
        )
    return generate_pipeline_code_seed(
        repo_id=repo_id,
        subset=subset,
        split=split,
        input_type=input_type,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
    )
src/synthetic_dataset_generator/pipelines/rag.py CHANGED
@@ -1,7 +1,3 @@
1
- import os
2
-
3
- from typing import List
4
-
5
  from datasets import get_dataset_config_names, get_dataset_split_names
6
  from distilabel.steps.tasks import (
7
  GenerateSentencePair,
@@ -292,10 +288,7 @@ with Pipeline(name="rag") as pipeline:
292
 
293
  pipeline += """
294
  if __name__ == "__main__":
295
- distiset = pipeline.run(use_cache=False)
296
- print(distiset)
297
- if distiset:
298
- print(distiset["default"]["train"][0])
299
  """
300
 
301
  return base_code + pipeline
 
 
 
 
 
1
  from datasets import get_dataset_config_names, get_dataset_split_names
2
  from distilabel.steps.tasks import (
3
  GenerateSentencePair,
 
288
 
289
  pipeline += """
290
  if __name__ == "__main__":
291
+ distiset = pipeline.run()
 
 
 
292
  """
293
 
294
  return base_code + pipeline