|
import os |
|
import random |
|
import uuid |
|
from typing import Union |
|
|
|
import argilla as rg |
|
import gradio as gr |
|
import nltk |
|
import pandas as pd |
|
from datasets import Dataset |
|
from distilabel.distiset import Distiset |
|
from gradio.oauth import OAuthToken |
|
from gradio_huggingfacehub_search import HuggingfaceHubSearch |
|
from huggingface_hub import HfApi |
|
|
|
from synthetic_dataset_generator.apps.base import ( |
|
combine_datasets, |
|
hide_success_message, |
|
load_dataset_from_hub, |
|
preprocess_input_data, |
|
push_pipeline_code_to_hub, |
|
show_success_message, |
|
test_max_num_rows, |
|
validate_argilla_user_workspace_dataset, |
|
validate_push_to_hub, |
|
) |
|
from synthetic_dataset_generator.constants import ( |
|
DEFAULT_BATCH_SIZE, |
|
MODEL, |
|
MODEL_COMPLETION, |
|
SAVE_LOCAL_DIR, |
|
) |
|
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts |
|
from synthetic_dataset_generator.pipelines.embeddings import ( |
|
get_embeddings, |
|
get_sentence_embedding_dimensions, |
|
) |
|
from synthetic_dataset_generator.pipelines.rag import ( |
|
DEFAULT_DATASET_DESCRIPTIONS, |
|
generate_pipeline_code, |
|
get_chunks_generator, |
|
get_prompt_generator, |
|
get_response_generator, |
|
get_sentence_pair_generator, |
|
) |
|
from synthetic_dataset_generator.utils import ( |
|
column_to_list, |
|
get_argilla_client, |
|
get_org_dropdown, |
|
get_random_repo_name, |
|
swap_visibility, |
|
) |
|
|
|
# Keep the NLTK resources in a project-local folder, register that folder on
# NLTK's search path, and fetch the resources this module relies on.
_NLTK_DATA_DIR = "./nltk_data"
os.makedirs(_NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(_NLTK_DATA_DIR)
for _nltk_resource in ("punkt_tab", "averaged_perceptron_tagger_eng"):
    nltk.download(_nltk_resource, download_dir=_NLTK_DATA_DIR)
|
|
|
|
|
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
    """Turn a free-text dataset description into a generated system prompt.

    Args:
        dataset_description: Text sent to the prompt generator as its
            ``instruction``.
        progress: Gradio progress tracker updated at each phase.

    Returns:
        The ``generation`` entry of the first item in the first batch
        yielded by the prompt generator.
    """
    progress(0.1, desc="Initializing")
    prompt_generator = get_prompt_generator()

    progress(0.5, desc="Generating")
    payload = [{"instruction": dataset_description}]
    first_batch = next(prompt_generator.process(payload))
    generated_prompt = first_batch[0]["generation"]

    progress(1.0, desc="Prompt generated")
    return generated_prompt
|
|
|
|
|
def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    """Load the source data either from the Hub or from uploaded files.

    Args:
        repo_id: Hub dataset identifier, used when ``input_type`` is
            ``"dataset-input"``.
        file_paths: Uploaded file paths, used for every other input type.
        input_type: Selects the loading strategy.
        num_rows: Row limit forwarded to the selected loader.
        token: OAuth token forwarded to the Hub loader.
        progress: Gradio progress tracker.

    Returns:
        Whatever the selected loader returns (callers in this module unpack
        it as a two-item result).
    """
    progress(0.1, desc="Loading the source data")
    # Anything that is not a Hub dataset is treated as uploaded files.
    if input_type != "dataset-input":
        return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
    return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
|
|
|
|
|
def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Generate a small preview dataset for the current configuration.

    Args:
        repo_id: Hub dataset to read documents from (dataset input).
        file_paths: Uploaded files to read documents from (file input).
        input_type: ``"dataset-input"``, ``"file-input"`` or ``"prompt-input"``.
        system_prompt: Prompt used to synthesise chunks for prompt input.
        document_column: Column of the loaded data that holds the documents.
        retrieval_reranking: Selected extras; may contain ``"Retrieval"``
            and/or ``"Reranking"``.
        num_rows: Number of source rows to load. The preview itself is always
            generated with 10 rows.
        oauth_token: OAuth token forwarded to the source loader.
        progress: Gradio progress tracker.

    Returns:
        The dataframe produced by ``generate_dataset``.
    """
    # Tolerate a missing selection instead of failing on `in None`.
    selected_extras = retrieval_reranking or []
    retrieval = "Retrieval" in selected_extras
    reranking = "Reranking" in selected_extras

    if input_type == "prompt-input":
        # Nothing to load: the documents are synthesised from the prompt.
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )

    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=10,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return dataframe
|
|
|
|
|
def _run_batched(generator, items, num_rows, batch_size, on_batch_start):
    """Feed the first ``num_rows`` entries of ``items`` to ``generator`` in batches.

    ``on_batch_start`` is called with the number of rows already processed
    before each batch. The per-batch size is computed locally so the caller's
    ``batch_size`` is never shrunk by a short final batch.
    """
    results = []
    n_processed = 0
    while n_processed < num_rows:
        on_batch_start(n_processed)
        current_batch_size = min(batch_size, num_rows - n_processed)
        batch = items[n_processed : n_processed + current_batch_size]
        # The generator is fully consumed; only its first yielded batch is used.
        results.extend(list(generator.process(inputs=batch))[0])
        n_processed += current_batch_size
    return results


def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    retrieval: bool = False,
    reranking: bool = False,
    num_rows: int = 10,
    temperature: float = 0.7,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    """Generate a RAG dataset of context / question / response records.

    Stages, in order: (optional) chunk generation from the system prompt,
    question generation, response generation and (optional) reranking pairs.

    Args:
        input_type: ``"prompt-input"`` synthesises the documents from
            ``system_prompt``; any other value reads them from ``dataframe``.
        dataframe: Source data; ``document_column`` is read from it unless the
            input type is ``"prompt-input"``.
        system_prompt: Prompt that is rewritten and used to generate chunks.
        document_column: Name of the column holding the documents.
        retrieval: Also keep positive/negative retrieval queries.
        reranking: Also generate positive/negative reranking sentences.
        num_rows: Requested number of rows (passed through
            ``test_max_num_rows`` first).
        temperature: Temperature for the generators.
        temperature_completion: Temperature for the response generator;
            falls back to ``temperature`` when falsy.
        is_sample: Forwarded unchanged to every generator factory.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with ``context``, ``question`` and ``response`` columns,
        plus the retrieval / reranking columns when requested.

    Raises:
        gr.Error: If the selected document column yields no documents.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")

    from_prompt = input_type == "prompt-input"
    if from_prompt:
        chunk_generator = get_chunks_generator(
            temperature=temperature, is_sample=is_sample
        )
    else:
        document_data = column_to_list(dataframe, document_column)
        if not document_data and num_rows > 0:
            # random.choices would fail with an opaque IndexError otherwise.
            raise gr.Error(
                f"No documents found in column '{document_column}'. "
                "Load your data and select a document column first."
            )
        if len(document_data) < num_rows:
            # Pad by re-sampling existing documents until there are enough rows.
            document_data += random.choices(
                document_data, k=num_rows - len(document_data)
            )

    retrieval_generator = get_sentence_pair_generator(
        action="query",
        triplet=bool(retrieval),
        temperature=temperature,
        is_sample=is_sample,
    )
    response_generator = get_response_generator(
        temperature=temperature_completion or temperature, is_sample=is_sample
    )
    if reranking:
        reranking_generator = get_sentence_pair_generator(
            action="semantically-similar",
            triplet=True,
            temperature=temperature,
            is_sample=is_sample,
        )

    # Questions and responses always run; chunks and reranking are optional.
    steps = 2 + int(bool(reranking)) + int(from_prompt)
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    def stage_reporter(stage_index, desc):
        """Build a per-batch callback that reports progress inside one stage."""

        def report(n_done):
            progress(
                (stage_index + n_done / num_rows) * step_progress,
                total=total_steps,
                desc=desc,
            )

        return report

    stage = 0

    # Stage: create the documents (chunks) from the system prompt.
    if from_prompt:
        report_chunks = stage_reporter(stage, "Generating chunks")
        rewritten_system_prompts = get_rewritten_prompts(system_prompt, num_rows)
        chunk_results = []
        n_processed = 0
        while n_processed < num_rows:
            report_chunks(n_processed)
            current_batch_size = min(batch_size, num_rows - n_processed)
            inputs = [
                {"task": random.choice(rewritten_system_prompts)}
                for _ in range(current_batch_size)
            ]
            chunk_results.extend(list(chunk_generator.process(inputs=inputs))[0])
            n_processed += current_batch_size
            random.seed(a=random.randint(0, 2**32 - 1))
        document_data = [chunk["generation"] for chunk in chunk_results]
        stage += 1
        progress(stage * step_progress, desc="Generating chunks")

    # Stage: generate a question (and retrieval pair) for every document.
    retrieval_results = _run_batched(
        retrieval_generator,
        [{"anchor": document} for document in document_data],
        num_rows,
        batch_size,
        stage_reporter(stage, "Generating questions"),
    )
    for result in retrieval_results:
        result["context"] = result["anchor"]
        if retrieval:
            result["question"] = result["positive"]
            result["positive_retrieval"] = result.pop("positive")
            result["negative_retrieval"] = result.pop("negative")
        else:
            result["question"] = result.pop("positive")
    stage += 1
    progress(stage * step_progress, desc="Generating questions")

    # Stage: answer every question.
    response_results = _run_batched(
        response_generator,
        retrieval_results,
        num_rows,
        batch_size,
        stage_reporter(stage, "Generating responses"),
    )
    for result in response_results:
        result["response"] = result["generation"]
    stage += 1
    progress(stage * step_progress, desc="Generating responses")

    # Stage: generate the reranking pairs.
    if reranking:
        reranking_results = _run_batched(
            reranking_generator,
            response_results,
            num_rows,
            batch_size,
            stage_reporter(stage, "Generating reranking data"),
        )
        for result in reranking_results:
            result["positive_reranking"] = result.pop("positive")
            result["negative_reranking"] = result.pop("negative")

    progress(
        1,
        total=total_steps,
        desc="Creating dataset",
    )

    # Keep only the columns that belong to the requested configuration.
    source_results = reranking_results if reranking else response_results
    relevant_keys = ["context", "question", "response"]
    if retrieval:
        relevant_keys += ["positive_retrieval", "negative_retrieval"]
    if reranking:
        relevant_keys += ["positive_reranking", "negative_reranking"]

    distiset_results = [
        {key: result[key] for key in relevant_keys if key in result}
        for result in source_results
    ]
    dataframe = pd.DataFrame(distiset_results)

    progress(1.0, desc="Dataset generation completed")
    return dataframe
|
|
|
|
|
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    """Publish the generated dataframe and its pipeline code to the Hub.

    Args:
        dataframe: Generated records to publish.
        org_name: Organisation (or user) that owns the target repository.
        repo_name: Name of the target repository.
        oauth_token: OAuth token; its ``token`` attribute is used for the push.
        private: Whether the repository is pushed as private.
        pipeline_code: Pipeline source code pushed next to the dataset.
        progress: Gradio progress tracker.

    Returns:
        The input ``dataframe``, unchanged.
    """
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)

    progress(0.5, desc="Creating dataset")
    generated = Dataset.from_pandas(dataframe)
    merged = combine_datasets(repo_id, generated, oauth_token)
    distiset = Distiset({"default": merged})

    progress(0.9, desc="Pushing dataset")
    push_options = {
        "repo_id": repo_id,
        "private": private,
        "include_script": False,
        "token": oauth_token.token,
        "create_pr": False,
    }
    distiset.push_to_hub(**push_options)
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)

    progress(1.0, desc="Dataset pushed")
    return dataframe
|
|
|
|
|
def _build_argilla_settings(retrieval: bool, reranking: bool):
    """Assemble the Argilla settings (fields, questions, metadata, vectors)."""
    fields = [
        rg.TextField(
            name="context",
            title="Context",
            description="Context for the generation",
        ),
        rg.ChatField(
            name="chat",
            title="Chat",
            description="User and assistant conversation based on the context",
        ),
    ]
    questions = [
        rg.LabelQuestion(
            name="relevant",
            title="Are the question and response relevant to the given context?",
            labels=["yes", "no"],
        ),
        rg.LabelQuestion(
            name="is_response_correct",
            title="Is the response correct?",
            labels=["yes", "no"],
        ),
    ]
    # One extra field and one optional question per positive/negative sample
    # for each enabled task.
    enabled_tasks = [
        task
        for task, enabled in (("retrieval", retrieval), ("reranking", reranking))
        if enabled
    ]
    for item in ["positive", "negative"]:
        for task in enabled_tasks:
            fields.append(
                rg.TextField(
                    name=f"{item}_{task}",
                    title=f"{item.capitalize()} {task}",
                    description=f"The {item} query for {task}",
                )
            )
            questions.append(
                rg.LabelQuestion(
                    name=f"is_{item}_{task}_relevant",
                    title=f"Is the {item} {task} relevant?",
                    labels=["yes", "no"],
                    required=False,
                )
            )

    text_columns = ["context", "question", "response"]
    metadata = [
        rg.IntegerMetadataProperty(
            name=f"{item}_length", title=f"{item.capitalize()} length"
        )
        for item in text_columns
    ]
    # Look the dimension up once; it is the same for all three vectors.
    embedding_dimensions = get_sentence_embedding_dimensions()
    vectors = [
        rg.VectorField(name=f"{item}_embeddings", dimensions=embedding_dimensions)
        for item in text_columns
    ]
    return rg.Settings(
        fields=fields,
        questions=questions,
        metadata=metadata,
        vectors=vectors,
        guidelines="Please review the conversation and provide an evaluation.",
    )


def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    temperature_completion: float,
    pipeline_code: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate the full dataset, push it to the Hub and log it to Argilla.

    Args:
        org_name: Organisation (or user) that owns the Hub repository.
        repo_name: Hub repository name; also used as the Argilla dataset name.
        private: Whether the Hub repository is private.
        original_repo_id: Hub dataset to read documents from (dataset input).
        file_paths: Uploaded files to read documents from (file input).
        input_type: ``"dataset-input"``, ``"file-input"`` or ``"prompt-input"``.
        system_prompt: Prompt used to synthesise chunks for prompt input.
        document_column: Column of the loaded data that holds the documents.
        retrieval_reranking: Selected extras; may contain ``"Retrieval"``
            and/or ``"Reranking"``.
        num_rows: Number of rows to load and generate.
        temperature: Temperature for the generators.
        temperature_completion: Temperature for the response generator.
        pipeline_code: Pipeline source code pushed next to the dataset.
        oauth_token: OAuth token used for the Hub and to resolve the user.
        progress: Gradio progress tracker.

    Returns:
        An empty string (the value written to the success message output).

    Raises:
        gr.Error: If anything fails while preparing or pushing to Argilla.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    # NOTE(review): is_sample=True is passed for the full run here while
    # save_local leaves it at its default -- confirm which one is intended.
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
        is_sample=True,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, oauth_token, private, pipeline_code
    )

    # Argilla only receives rows in which every cell is present and non-blank.
    # DataFrame.applymap is deprecated in favour of DataFrame.map; use
    # whichever the installed pandas provides.
    cell_map = dataframe.map if hasattr(dataframe, "map") else dataframe.applymap
    stripped = cell_map(lambda x: str(x).strip() if pd.notna(x) else x)
    complete_rows = stripped.notna().all(axis=1) & stripped.ne("").all(axis=1)
    # Copy so the column assignments below do not write into a slice.
    dataframe = dataframe[complete_rows].copy()

    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""

        progress(0.5, desc="Creating dataset in Argilla")
        settings = _build_argilla_settings(retrieval, reranking)

        # Built from plain lists so that an empty frame is handled as well.
        dataframe["chat"] = [
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": response},
            ]
            for question, response in zip(
                dataframe["question"], dataframe["response"]
            )
        ]

        for item in ["context", "question", "response"]:
            dataframe[f"{item}_length"] = dataframe[item].apply(
                lambda x: len(x) if x is not None else 0
            )
            dataframe[f"{item}_embeddings"] = get_embeddings(
                dataframe[item].apply(lambda x: x if x is not None else "").to_list()
            )

        # Reuse the dataset when the lookup returns one, otherwise create it.
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
|
|
|
|
|
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: float,
) -> tuple[str, str]:
    """Generate the dataset and write it to ``SAVE_LOCAL_DIR`` as CSV and JSON.

    Args:
        repo_id: Hub dataset to read documents from (dataset input).
        file_paths: Uploaded files to read documents from (file input).
        input_type: ``"dataset-input"``, ``"file-input"`` or ``"prompt-input"``.
        system_prompt: Prompt used to synthesise chunks for prompt input.
        document_column: Column of the loaded data that holds the documents.
        retrieval_reranking: Selected extras; may contain ``"Retrieval"``
            and/or ``"Reranking"``.
        num_rows: Number of rows to load and generate.
        temperature: Temperature for the generators.
        repo_name: Base name (without extension) of the output files.
        temperature_completion: Temperature for the response generator.

    Returns:
        ``(csv_path, json_path)`` of the files that were written.
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)

    # Make sure the target directory exists before writing into it.
    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json
|
|
|
|
|
def show_system_prompt_visibility():
    """Return the update that makes the system prompt textbox visible."""
    revealed_textbox = gr.Textbox(visible=True)
    return {system_prompt: revealed_textbox}
|
|
|
|
|
def hide_system_prompt_visibility():
    """Return the update that hides the system prompt textbox."""
    hidden_textbox = gr.Textbox(visible=False)
    return {system_prompt: hidden_textbox}
|
|
|
|
|
def show_document_column_visibility():
    """Return the update that makes the document column dropdown visible."""
    revealed_dropdown = gr.Dropdown(visible=True)
    return {document_column: revealed_dropdown}
|
|
|
|
|
def hide_document_column_visibility():
    """Return the update that hides the document column dropdown.

    The dropdown is also reset to its single placeholder choice.
    """
    placeholder = "Load your data first in step 1."
    hidden_dropdown = gr.Dropdown(
        choices=[placeholder],
        value=placeholder,
        visible=False,
    )
    return {document_column: hidden_dropdown}
|
|
|
|
|
def show_pipeline_code_visibility():
    """Return the update that makes the pipeline code accordion visible."""
    revealed_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: revealed_accordion}
|
|
|
|
|
def hide_pipeline_code_visibility():
    """Return the update that hides the pipeline code accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
|
|
|
|
|
def show_temperature_completion():
    """Reveal the completion temperature slider when two models are configured.

    Returns ``None`` when ``MODEL`` and ``MODEL_COMPLETION`` are equal;
    otherwise returns an update that shows the slider with a value of 0.9.
    """
    if MODEL == MODEL_COMPLETION:
        return None
    return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
|
|
|
|
def show_save_local_button():
    """Return the update that makes the "Save locally" button visible."""
    revealed_button = gr.Button(visible=True)
    return {btn_save_local: revealed_button}
|
|
|
|
|
def hide_save_local_button():
    """Return the update that hides the "Save locally" button."""
    hidden_button = gr.Button(visible=False)
    return {btn_save_local: hidden_button}
|
|
|
|
|
def show_save_local():
    """Show the CSV/JSON download slots and collapse the success message area.

    Returns:
        A mapping of output components to their updates: both file components
        become visible and the success message gets ``min_height=0``.
    """
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        # gr.update only takes effect when its result is returned for the
        # component; calling it and discarding the result changes nothing.
        success_message: gr.update(min_height=0),
    }
|
|
|
|
|
def hide_save_local():
    """Hide the CSV/JSON download slots and reserve room for the success message.

    Returns:
        A mapping of output components to their updates: both file components
        become hidden and the success message gets ``min_height=100``.
    """
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        # gr.update only takes effect when its result is returned for the
        # component; calling it and discarding the result changes nothing.
        success_message: gr.update(min_height=100),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI for the RAG dataset generator: three numbered sections (select
# input, configure task, generate dataset) followed by the event wiring.
with gr.Blocks() as app:
    with gr.Column() as main_ui:
        # ----- Section 1: choose where the documents come from -----
        gr.Markdown("## 1. Select your input")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                # Hidden state holder; its value is set by the tab-select
                # handlers further down.
                input_type = gr.Dropdown(
                    label="Input type",
                    choices=["dataset-input", "file-input", "prompt-input"],
                    value="dataset-input",
                    multiselect=False,
                    visible=False,
                )
                with gr.Tab("Load from Hub") as tab_dataset_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            # NOTE(review): the keyword is spelled
                            # `sumbit_on_select`; presumably this matches the
                            # component's own parameter name -- confirm before
                            # "fixing" it.
                            search_in = HuggingfaceHubSearch(
                                label="Search",
                                placeholder="Search for a dataset",
                                search_type="dataset",
                                sumbit_on_select=True,
                            )
                            with gr.Row():
                                clear_dataset_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_dataset_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            examples = gr.Examples(
                                examples=[
                                    "charris/wikipedia_sample",
                                    "plaguss/argilla_sdk_docs_raw_unstructured",
                                    "BeIR/hotpotqa-generated-queries",
                                ],
                                label="Example datasets",
                                fn=lambda x: x,
                                inputs=[search_in],
                                run_on_click=True,
                            )
                            search_out = gr.HTML(label="Dataset preview", visible=False)
                with gr.Tab("Load your file") as tab_file_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            file_in = gr.File(
                                label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                file_count="multiple",
                                file_types=[".md", ".txt", ".docx", ".pdf"],
                            )
                            with gr.Row():
                                clear_file_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_file_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            file_out = gr.HTML(label="Dataset preview", visible=False)
                with gr.Tab("Generate from prompt") as tab_prompt_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            dataset_description = gr.Textbox(
                                label="Dataset description",
                                placeholder="Give a precise description of your desired dataset.",
                            )
                            with gr.Row():
                                clear_prompt_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_prompt_btn = gr.Button("Create", variant="primary")
                        with gr.Column(scale=3):
                            # Rebinds `examples`; the earlier Examples object
                            # is not referenced anywhere else in this block.
                            examples = gr.Examples(
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
                                inputs=[dataset_description],
                                cache_examples=False,
                                label="Examples",
                            )

        # ----- Section 2: task configuration and sample preview -----
        gr.HTML(value="<hr>")
        gr.Markdown(value="## 2. Configure your task")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                # Hidden by default; shown by the "Generate from prompt" tab
                # handler below.
                system_prompt = gr.Textbox(
                    label="System prompt",
                    placeholder="You are a helpful assistant.",
                    visible=False,
                )
                # Starts with a placeholder choice; it is an output of the
                # load handler wired further down.
                document_column = gr.Dropdown(
                    label="Document Column",
                    info="Select the document column to generate the RAG dataset",
                    choices=["Load your data first in step 1."],
                    value="Load your data first in step 1.",
                    interactive=False,
                    multiselect=False,
                    allow_custom_value=False,
                )
                retrieval_reranking = gr.CheckboxGroup(
                    choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
                    type="value",
                    label="Data for RAG",
                    info="Indicate the additional data you want to generate for RAG.",
                )
                with gr.Row():
                    clear_btn_full = gr.Button("Clear", variant="secondary")
                    btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
            with gr.Column(scale=3):
                dataframe = gr.Dataframe(
                    headers=["context", "question", "response"],
                    wrap=True,
                    interactive=False,
                )

        # ----- Section 3: generation settings and outputs -----
        gr.HTML(value="<hr>")
        gr.Markdown(value="## 3. Generate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                )
                # Hidden by default; show_temperature_completion (wired on app
                # load) reveals it when MODEL and MODEL_COMPLETION differ.
                temperature_completion = gr.Slider(
                    label="Temperature for completion",
                    minimum=0.1,
                    maximum=1.5,
                    value=None,
                    step=0.1,
                    interactive=True,
                    visible=False,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
                # Hidden by default; revealed on app load when SAVE_LOCAL_DIR
                # is configured (see the end of this block).
                btn_save_local = gr.Button(
                    "Save locally", variant="primary", scale=2, visible=False
                )
            with gr.Column(scale=3):
                csv_file = gr.File(
                    label="CSV",
                    elem_classes="datasets",
                    visible=False,
                )
                json_file = gr.File(
                    label="JSON",
                    elem_classes="datasets",
                    visible=False,
                )
                success_message = gr.Markdown(
                    visible=False,
                    min_height=0,
                )
                with gr.Accordion(
                    "Customize your pipeline with distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    # Initial code is rendered from the components' default
                    # values; it is regenerated after each push / local save.
                    code = generate_pipeline_code(
                        repo_id=search_in.value,
                        input_type=input_type.value,
                        system_prompt=system_prompt.value,
                        document_column=document_column.value,
                        retrieval_reranking=retrieval_reranking.value,
                        num_rows=num_rows.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    # ----- Event wiring -----
    # Selecting a tab records the input type and toggles which configuration
    # widget (system prompt vs. document column) is visible.
    tab_dataset_input.select(
        fn=lambda: "dataset-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_file_input.select(
        fn=lambda: "file-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_prompt_input.select(
        fn=lambda: "prompt-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
    )

    # Submitting a new search empties the preview table but keeps its columns.
    search_in.submit(
        fn=lambda df: pd.DataFrame(columns=df.columns),
        inputs=[dataframe],
        outputs=[dataframe],
    )

    # Both "Load" buttons share one handler; num_rows is not passed, so
    # load_dataset_file uses its default.
    gr.on(
        triggers=[load_dataset_btn.click, load_file_btn.click],
        fn=load_dataset_file,
        inputs=[search_in, file_in, input_type],
        outputs=[dataframe, document_column],
    )

    # "Create": generate a system prompt first, then a sample dataset from it.
    # NOTE(review): `oauth_token` is not listed in the inputs of
    # generate_sample_dataset; presumably Gradio injects it from the type
    # hint -- verify.
    load_prompt_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt],
    ).success(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=dataframe,
    )

    # "Save": regenerate the sample dataset with the current configuration.
    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=dataframe,
    )

    # "Push to Hub": validate, reset the output area, generate and push, then
    # show the success message and the refreshed pipeline code.
    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=hide_save_local,
        outputs=[csv_file, json_file, success_message],
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
            temperature,
            temperature_completion,
            pipeline_code,
        ],
        outputs=[success_message],
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    # "Save locally": reset the output area, show the file slots, generate the
    # dataset into CSV/JSON files, then show the refreshed pipeline code.
    btn_save_local.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=show_save_local,
        inputs=[],
        outputs=[csv_file, json_file, success_message],
    ).success(
        save_local,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
            temperature,
            repo_name,
            temperature_completion,
        ],
        outputs=[csv_file, json_file],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            retrieval_reranking,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    # "Clear" buttons: one per input tab, plus one for the task configuration
    # (which also empties the preview table while keeping its columns).
    clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
    clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
    clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
    clear_btn_full.click(
        fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
        inputs=[dataframe],
        outputs=[document_column, retrieval_reranking, dataframe],
    )

    # Handlers that run on every page load.
    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
    # The "Save locally" button is only offered when a local directory is set.
    if SAVE_LOCAL_DIR is not None:
        app.load(fn=show_save_local_button, outputs=btn_save_local)
|
|