"""Gradio app for generating chat (SFT) datasets and pushing them to the Hub/Argilla."""

import ast
import json
import os
import random
import uuid
from typing import Dict, List, Union

import argilla as rg
import gradio as gr
import pandas as pd
from datasets import Dataset
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi

from synthetic_dataset_generator.apps.base import (
    combine_datasets,
    hide_success_message,
    load_dataset_from_hub,
    preprocess_input_data,
    push_pipeline_code_to_hub,
    show_success_message,
    test_max_num_rows,
    validate_argilla_user_workspace_dataset,
    validate_push_to_hub,
)
from synthetic_dataset_generator.constants import (
    BASE_URL,
    DEFAULT_BATCH_SIZE,
    MODEL,
    MODEL_COMPLETION,
    SAVE_LOCAL_DIR,
    SFT_AVAILABLE,
)
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
from synthetic_dataset_generator.pipelines.chat import (
    DEFAULT_DATASET_DESCRIPTIONS,
    generate_pipeline_code,
    get_follow_up_generator,
    get_magpie_generator,
    get_prompt_generator,
    get_response_generator,
    get_sentence_pair_generator,
)
from synthetic_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.utils import (
    column_to_list,
    get_argilla_client,
    get_org_dropdown,
    get_random_repo_name,
    swap_visibility,
)


def _get_dataframe():
    """Return a fresh, read-only ``gr.Dataframe`` with prompt/completion headers."""
    return gr.Dataframe(
        headers=["prompt", "completion"],
        wrap=True,
        interactive=False,
    )


def _completion_temperature(
    temperature: float, temperature_completion: Union[float, None]
) -> float:
    """Return the temperature to use for completions.

    The dedicated completion temperature wins when it is set; otherwise the
    general ``temperature`` is used.
    """
    return temperature if temperature_completion is None else temperature_completion


def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Parse string-encoded ``messages`` cells into lists of message dicts.

    The ``messages`` column is rewritten IN PLACE on the passed dataframe
    (callers in this module rely on that); non-string cells are left as-is.
    Frames without a ``messages`` column are returned unchanged.
    """

    def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
        # JSON first: literal_eval would decode "\ud83d\ude00"-style escapes
        # (as produced by json.dumps) into lone surrogates instead of the
        # real character.
        try:
            return json.loads(messages)
        except ValueError:
            pass
        # Legacy path: Python-repr style text where the dicts may lack
        # separating commas, so one is appended after each role entry.
        return ast.literal_eval(
            messages.replace("'user'}", "'user'},")
            .replace("'system'}", "'system'},")
            .replace("'assistant'}", "'assistant'},")
        )

    if "messages" in dataframe.columns:
        dataframe["messages"] = dataframe["messages"].apply(
            lambda x: convert_to_list_of_dicts(x) if isinstance(x, str) else x
        )
    return dataframe


def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
    """Generate a system prompt from a free-text dataset description.

    Runs the prompt generator on a single instruction and returns the
    ``generation`` field of its first result.
    """
    progress(0.1, desc="Initializing")
    generate_description = get_prompt_generator()
    progress(0.5, desc="Generating")
    result = next(
        generate_description.process(
            [
                {
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    progress(1.0, desc="Prompt generated")
    return result


def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    """Load source data from the Hub or from uploaded files.

    ``input_type == "dataset-input"`` loads ``repo_id`` from the Hub; any
    other value preprocesses ``file_paths``. The loader's result is returned
    unchanged (callers unpack it as a 2-tuple).
    """
    progress(0.1, desc="Loading the source data")
    if input_type == "dataset-input":
        return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
    else:
        return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)


def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Generate a small preview dataset (``is_sample=True``) for the UI."""
    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["prompt", "completion"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating sample dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return dataframe


def generate_dataset_from_prompt(
    system_prompt: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a dataset from a system prompt alone.

    Instructions are produced by the magpie generator, then answered by the
    response generator. Returns ``prompt``/``completion``/``system_prompt``
    columns for ``num_turns == 1`` and a single ``messages`` column otherwise.
    ``num_rows`` is first passed through ``test_max_num_rows``.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="(1/2) Generating instructions")
    magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
    response_generator = get_response_generator(
        system_prompt=system_prompt,
        num_turns=num_turns,
        temperature=_completion_temperature(temperature, temperature_completion),
        is_sample=is_sample,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create prompt rewrites
    prompt_rewrites = get_rewritten_prompts(system_prompt, num_rows)

    # create instructions
    n_processed = 0
    magpie_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating instructions",
        )
        # Size only this batch; keep batch_size intact for the response loop.
        current_batch_size = min(batch_size, num_rows - n_processed)
        rewritten_system_prompt = random.choice(prompt_rewrites)
        inputs = [
            {"system_prompt": rewritten_system_prompt}
            for _ in range(current_batch_size)
        ]
        batch = list(magpie_generator.process(inputs=inputs))
        magpie_results.extend(batch[0])
        n_processed += current_batch_size
        random.seed(a=random.randint(0, 2**32 - 1))
    progress(0.5, desc="(1/2) Generating instructions")

    # generate responses
    if num_turns != 1:
        # Multi-turn: the conversation, prefixed with the system prompt, is
        # what the response generator continues.
        for result in magpie_results:
            result["conversation"].insert(
                0, {"role": "system", "content": system_prompt}
            )
            result["messages"] = result["conversation"]

    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            0.5 + 0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(2/2) Generating responses",
        )
        batch = magpie_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
        random.seed(a=random.randint(0, 2**32 - 1))

    if num_turns == 1:
        for result in response_results:
            result["prompt"] = result["instruction"]
            result["completion"] = result["generation"]
            result["system_prompt"] = system_prompt
    else:
        for result in response_results:
            result["messages"].append(
                {"role": "assistant", "content": result["generation"]}
            )
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create distiset
    relevant_keys = (
        "messages",
        "prompt",
        "completion",
        "model_name",
        "system_prompt",
    )
    distiset_results = [
        {key: result[key] for key in relevant_keys if key in result}
        for result in response_results
    ]

    distiset = Distiset(
        {
            "default": Dataset.from_list(distiset_results),
        }
    )

    # If not pushing to hub generate the dataset directly
    distiset = distiset["default"]
    if num_turns == 1:
        outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
    else:
        outputs = distiset.to_pandas()[["messages"]]
    dataframe = pd.DataFrame(outputs)
    progress(1.0, desc="Dataset generation completed")
    return dataframe


def generate_dataset_from_seed(
    dataframe: pd.DataFrame,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a dataset grounded in the documents of ``document_column``.

    Each document yields an instruction, then a response, then (for
    ``num_turns > 1``) ``num_turns - 1`` follow-up exchanges. Returns
    ``prompt``/``completion`` columns for a single turn, otherwise a
    ``messages`` column holding JSON-encoded conversations.

    Raises:
        gr.Error: if the selected column yields no documents.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    document_data = column_to_list(dataframe, document_column)
    if len(document_data) == 0:
        # random.choices below would otherwise fail with an opaque IndexError.
        raise gr.Error("The selected document column does not contain any data.")
    if len(document_data) < num_rows:
        # Pad by resampling so there is one document per requested row.
        document_data += random.choices(document_data, k=num_rows - len(document_data))
    instruction_generator = get_sentence_pair_generator(
        temperature=temperature, is_sample=is_sample
    )
    completion_temperature = _completion_temperature(
        temperature, temperature_completion
    )
    response_generator = get_response_generator(
        system_prompt=None,
        num_turns=1,
        temperature=completion_temperature,
        is_sample=is_sample,
    )
    follow_up_generator_instruction = get_follow_up_generator(
        type="instruction", temperature=temperature, is_sample=is_sample
    )
    follow_up_generator_response = get_follow_up_generator(
        type="response",
        temperature=completion_temperature,
        is_sample=is_sample,
    )
    steps = 2 * num_turns
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    # create instructions
    n_processed = 0
    instruction_results = []
    while n_processed < num_rows:
        progress(
            step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating instructions",
        )
        # Size only this batch; keep batch_size intact for the later loops.
        current_batch_size = min(batch_size, num_rows - n_processed)
        batch = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + current_batch_size]
        ]
        questions = list(instruction_generator.process(inputs=batch))
        instruction_results.extend(questions[0])
        n_processed += current_batch_size
    for result in instruction_results:
        result["instruction"] = result["positive"]
        result["prompt"] = result.pop("positive")

    progress(step_progress, desc="Generating instructions")

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            step_progress + step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating responses",
        )
        batch = instruction_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["completion"] = result.pop("generation")

    # generate follow-ups
    final_conversations = []
    if num_turns > 1:
        n_processed = 0
        follow_up_start = 2 * step_progress
        while n_processed < num_rows:
            progress(
                # Follow-ups fill whatever remains after the first two phases.
                follow_up_start + (1 - follow_up_start) * n_processed / num_rows,
                total=total_steps,
                desc="Generating follow-ups",
            )
            batch = response_results[n_processed : n_processed + batch_size]
            conversations_batch = [
                {
                    "messages": [
                        {"role": "user", "content": result["prompt"]},
                        {"role": "assistant", "content": result["completion"]},
                    ]
                }
                for result in batch
            ]

            for _ in range(num_turns - 1):
                follow_up_instructions = list(
                    follow_up_generator_instruction.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(
                    conversations_batch, follow_up_instructions[0]
                ):
                    conv["messages"].append(
                        {"role": "user", "content": follow_up["generation"]}
                    )

                follow_up_responses = list(
                    follow_up_generator_response.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
                    conv["messages"].append(
                        {"role": "assistant", "content": follow_up["generation"]}
                    )

            final_conversations.extend(
                [{"messages": conv["messages"]} for conv in conversations_batch]
            )
            n_processed += batch_size

    # create distiset
    if num_turns == 1:
        distiset_results = [
            {key: result[key] for key in ("prompt", "completion") if key in result}
            for result in response_results
        ]
        dataframe = pd.DataFrame(distiset_results)
    else:
        distiset_results = final_conversations
        dataframe = pd.DataFrame(distiset_results)
        # Conversations are stored as JSON text; convert_dataframe_messages
        # parses them back before pushing.
        dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))

    progress(1.0, desc="Dataset generation completed")
    return dataframe


def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Dispatch generation by input type.

    ``"prompt-input"`` generates from ``system_prompt`` (``dataframe`` and
    ``document_column`` are ignored); every other input type generates from
    the documents in ``dataframe[document_column]``.
    """
    if input_type == "prompt-input":
        dataframe = generate_dataset_from_prompt(
            system_prompt=system_prompt,
            num_turns=num_turns,
            num_rows=num_rows,
            temperature=temperature,
            temperature_completion=temperature_completion,
            is_sample=is_sample,
        )
    else:
        dataframe = generate_dataset_from_seed(
            dataframe=dataframe,
            document_column=document_column,
            num_turns=num_turns,
            num_rows=num_rows,
            temperature=temperature,
            temperature_completion=temperature_completion,
            is_sample=is_sample,
        )
    return dataframe


def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    """Push ``dataframe`` and the pipeline code to the Hugging Face Hub.

    Returns a deep copy of the dataframe taken BEFORE message conversion.
    Note that the passed ``dataframe`` itself has its ``messages`` column
    parsed in place by ``convert_dataframe_messages``.
    """
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)
    progress(0.3, desc="Converting")
    original_dataframe = dataframe.copy(deep=True)
    dataframe = convert_dataframe_messages(dataframe)
    progress(0.7, desc="Creating dataset")
    dataset = Dataset.from_pandas(dataframe)
    dataset = combine_datasets(repo_id, dataset, oauth_token)
    progress(0.9, desc="Pushing dataset")
    distiset = Distiset({"default": dataset})
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
    return original_dataframe


def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    pipeline_code: str = "",
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate the full dataset, push it to the Hub, then log it to Argilla.

    The Argilla step is skipped when no Argilla client is available.
    Always returns an empty string (used to clear the success message).

    Raises:
        gr.Error: if anything in the Argilla step fails.
    """
    if input_type == "prompt-input":
        # Placeholder only: prompt-based generation does not read this frame.
        dataframe = pd.DataFrame(columns=["prompt", "completion"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    # Side effect relied on below: this parses dataframe["messages"] in place,
    # so the Argilla code can index the messages as dicts.
    push_dataset_to_hub(
        dataframe=dataframe,
        org_name=org_name,
        repo_name=repo_name,
        oauth_token=oauth_token,
        private=private,
        pipeline_code=pipeline_code,
    )
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        progress(0.5, desc="Creating dataset in Argilla")
        if "messages" in dataframe.columns:
            settings = rg.Settings(
                fields=[
                    rg.ChatField(
                        name="messages",
                        description="The messages in the conversation",
                        title="Messages",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="user_message_length", title="User Message Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="assistant_message_length",
                        title="Assistant Message Length",
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="messages_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and provide a score for the assistant's response.",
            )

            dataframe["user_message_length"] = dataframe["messages"].apply(
                lambda x: sum(len(y["content"]) for y in x if y["role"] == "user")
            )
            dataframe["assistant_message_length"] = dataframe["messages"].apply(
                lambda x: sum(
                    len(y["content"]) for y in x if y["role"] == "assistant"
                )
            )
            dataframe["messages_embeddings"] = get_embeddings(
                dataframe["messages"].apply(
                    lambda x: " ".join([y["content"] for y in x])
                )
            )
        else:
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name="system_prompt",
                        title="System Prompt",
                        description="The system prompt used for the conversation",
                        required=False,
                    ),
                    rg.TextField(
                        name="prompt",
                        title="Prompt",
                        description="The prompt used for the conversation",
                    ),
                    rg.TextField(
                        name="completion",
                        title="Completion",
                        description="The completion from the assistant",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="prompt_length", title="Prompt Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="completion_length", title="Completion Length"
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="prompt_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and correct the prompt and completion where needed.",
            )
            dataframe["prompt_length"] = dataframe["prompt"].apply(len)
            dataframe["completion_length"] = dataframe["completion"].apply(len)
            dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])

        # Reuse an existing Argilla dataset with this name; create it otherwise.
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()
        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""


def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> tuple[str, str]:
    """Generate the full dataset and write it under ``SAVE_LOCAL_DIR``.

    Writes ``<repo_name>.csv`` and ``<repo_name>.json`` and returns both
    paths as ``(csv_path, json_path)``.
    """
    if input_type == "prompt-input":
        # Placeholder only: prompt-based generation does not read this frame.
        dataframe = pd.DataFrame(columns=["prompt", "completion"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # Make sure the target directory exists before writing into it.
    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json


# The helpers below return Gradio updates keyed by components that are
# created at module level inside the ``gr.Blocks`` context further down.


def show_system_prompt_visibility():
    """Make the system prompt textbox visible."""
    return {system_prompt: gr.Textbox(visible=True)}


def hide_system_prompt_visibility():
    """Hide the system prompt textbox."""
    return {system_prompt: gr.Textbox(visible=False)}


def show_document_column_visibility():
    """Make the document column dropdown visible."""
    return {document_column: gr.Dropdown(visible=True)}


def hide_document_column_visibility():
    """Hide the document column dropdown and reset it to its placeholder."""
    return {
        document_column: gr.Dropdown(
            choices=["Load your data first in step 1."],
            value="Load your data first in step 1.",
            visible=False,
        )
    }


def show_pipeline_code_visibility():
    """Make the pipeline code accordion visible."""
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    """Hide the pipeline code accordion."""
    return {pipeline_code_ui: gr.Accordion(visible=False)}


def show_temperature_completion():
    """Show the completion temperature slider when a separate completion model is set.

    Returns ``None`` (no update) when ``MODEL`` equals ``MODEL_COMPLETION``.
    """
    if MODEL != MODEL_COMPLETION:
        return {temperature_completion: gr.Slider(value=0.9, visible=True)}


def show_save_local_button():
    """Make the "Save locally" button visible."""
    return {btn_save_local: gr.Button(visible=True)}


def hide_save_local_button():
    """Hide the "Save locally" button."""
    return {btn_save_local: gr.Button(visible=False)}


def show_save_local():
    """Show the CSV/JSON file outputs."""
    # NOTE(review): the result of gr.update is discarded, so this call has no
    # effect on its own; confirm whether a min_height change was intended.
    gr.update(success_message, min_height=0)
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        success_message: success_message,
    }


def hide_save_local():
    """Hide the CSV/JSON file outputs."""
    # NOTE(review): the result of gr.update is discarded, so this call has no
    # effect on its own; confirm whether a min_height change was intended.
    gr.update(success_message, min_height=100)
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        success_message: success_message,
    }


######################
# Gradio UI
######################


with gr.Blocks() as app:
    with gr.Column() as main_ui:
        if not SFT_AVAILABLE:
            gr.Markdown(
                value="\n".join(
                    [
                        "## Supervised Fine-Tuning not available",
                        "",
                        f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) prequery template, which is not implemented for the {MODEL} with {BASE_URL}.",
                        "Use Llama3 or Qwen2 models with Hugging Face Inference Endpoints.",
                    ]
                )
            )
        else:
            gr.Markdown("## 1. Select your input")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    # Hidden state: tracks which input tab is active.
                    input_type = gr.Dropdown(
                        label="Input type",
                        choices=["prompt-input", "dataset-input", "file-input"],
                        value="prompt-input",
                        multiselect=False,
                        visible=False,
                    )
                    with gr.Tab("Generate from prompt") as tab_prompt_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                dataset_description = gr.Textbox(
                                    label="Dataset description",
                                    placeholder="Give a precise description of your desired dataset.",
                                )
                                with gr.Row():
                                    clear_prompt_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_prompt_btn = gr.Button(
                                        "Create", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                examples = gr.Examples(
                                    examples=DEFAULT_DATASET_DESCRIPTIONS,
                                    inputs=[dataset_description],
                                    cache_examples=False,
                                    label="Examples",
                                )
                    with gr.Tab("Load from Hub") as tab_dataset_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                search_in = HuggingfaceHubSearch(
                                    label="Search",
                                    placeholder="Search for a dataset",
                                    search_type="dataset",
                                    sumbit_on_select=True,
                                )
                                with gr.Row():
                                    clear_dataset_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_dataset_btn = gr.Button(
                                        "Load", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                examples = gr.Examples(
                                    examples=[
                                        "charris/wikipedia_sample",
                                        "plaguss/argilla_sdk_docs_raw_unstructured",
                                        "BeIR/hotpotqa-generated-queries",
                                    ],
                                    label="Example datasets",
                                    fn=lambda x: x,
                                    inputs=[search_in],
                                    run_on_click=True,
                                )
                                search_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )
                    with gr.Tab("Load your file") as tab_file_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                file_in = gr.File(
                                    label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                    file_count="multiple",
                                    file_types=[".md", ".txt", ".docx", ".pdf"],
                                )
                                with gr.Row():
                                    clear_file_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_file_btn = gr.Button("Load", variant="primary")
                            with gr.Column(scale=3):
                                file_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )

            # NOTE(review): separator markup reconstructed as a horizontal
            # rule — confirm against the original file.
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 2. Configure your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="You are a helpful assistant.",
                    )
                    document_column = gr.Dropdown(
                        label="Document Column",
                        info="Select the document column to generate the chat data",
                        choices=["Load your data first in step 1."],
                        value="Load your data first in step 1.",
                        interactive=False,
                        multiselect=False,
                        allow_custom_value=False,
                        visible=False,
                    )
                    num_turns = gr.Number(
                        value=1,
                        label="Number of turns in the conversation",
                        minimum=1,
                        maximum=4,
                        step=1,
                        interactive=True,
                        info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
                    )
                    with gr.Row():
                        clear_btn_full = gr.Button(
                            "Clear",
                            variant="secondary",
                        )
                        btn_apply_to_sample_dataset = gr.Button(
                            "Save", variant="primary"
                        )
                with gr.Column(scale=3):
                    dataframe = _get_dataframe()

            # NOTE(review): separator markup reconstructed as a horizontal
            # rule — confirm against the original file.
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 3. Generate your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    org_name = get_org_dropdown()
                    repo_name = gr.Textbox(
                        label="Repo name",
                        placeholder="dataset_name",
                        value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                        interactive=True,
                    )
                    num_rows = gr.Number(
                        label="Number of rows",
                        value=10,
                        interactive=True,
                        scale=1,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                    )
                    temperature_completion = gr.Slider(
                        label="Temperature for completion",
                        minimum=0.1,
                        maximum=1.5,
                        value=None,
                        step=0.1,
                        interactive=True,
                        visible=False,
                    )
                    private = gr.Checkbox(
                        label="Private dataset",
                        value=False,
                        interactive=True,
                        scale=1,
                    )
                    btn_push_to_hub = gr.Button(
                        "Push to Hub", variant="primary", scale=2
                    )
                    btn_save_local = gr.Button(
                        "Save locally", variant="primary", scale=2, visible=False
                    )
                with gr.Column(scale=3):
                    csv_file = gr.File(
                        label="CSV",
                        elem_classes="datasets",
                        visible=False,
                    )
                    json_file = gr.File(
                        label="JSON",
                        elem_classes="datasets",
                        visible=False,
                    )
                    success_message = gr.Markdown(
                        visible=False,
                        min_height=0,  # don't remove this otherwise progress is not visible
                    )
                    with gr.Accordion(
                        "Customize your pipeline with distilabel",
                        open=False,
                        visible=False,
                    ) as pipeline_code_ui:
                        code = generate_pipeline_code(
                            repo_id=search_in.value,
                            input_type=input_type.value,
                            system_prompt=system_prompt.value,
                            document_column=document_column.value,
                            num_turns=num_turns.value,
                            num_rows=num_rows.value,
                        )
                        pipeline_code = gr.Code(
                            value=code,
                            language="python",
                            label="Distilabel Pipeline Code",
                        )

    # The components wired below exist only when SFT is available; guarding
    # avoids a NameError at import time when it is not.
    if SFT_AVAILABLE:
        tab_prompt_input.select(
            fn=lambda: "prompt-input",
            inputs=[],
            outputs=[input_type],
        ).then(
            fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]
        ).then(
            fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
        )

        tab_dataset_input.select(
            fn=lambda: "dataset-input",
            inputs=[],
            outputs=[input_type],
        ).then(
            fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]
        ).then(
            fn=show_document_column_visibility, inputs=[], outputs=[document_column]
        )

        tab_file_input.select(
            fn=lambda: "file-input",
            inputs=[],
            outputs=[input_type],
        ).then(
            fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]
        ).then(
            fn=show_document_column_visibility, inputs=[], outputs=[document_column]
        )

        # A new search empties the preview while keeping its columns.
        search_in.submit(
            fn=lambda df: pd.DataFrame(columns=df.columns),
            inputs=[dataframe],
            outputs=[dataframe],
        )

        load_prompt_btn.click(
            fn=generate_system_prompt,
            inputs=[dataset_description],
            outputs=[system_prompt],
        ).success(
            fn=generate_sample_dataset,
            inputs=[
                search_in,
                file_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
            ],
            outputs=dataframe,
        )

        gr.on(
            triggers=[load_dataset_btn.click, load_file_btn.click],
            fn=load_dataset_file,
            inputs=[search_in, file_in, input_type],
            outputs=[dataframe, document_column],
        )

        btn_apply_to_sample_dataset.click(
            fn=generate_sample_dataset,
            inputs=[
                search_in,
                file_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
            ],
            outputs=dataframe,
        )

        btn_push_to_hub.click(
            fn=validate_argilla_user_workspace_dataset,
            inputs=[repo_name],
            outputs=[success_message],
        ).then(
            fn=validate_push_to_hub,
            inputs=[org_name, repo_name],
            outputs=[success_message],
        ).success(
            fn=hide_save_local,
            outputs=[csv_file, json_file, success_message],
        ).success(
            fn=hide_success_message,
            outputs=[success_message],
        ).success(
            fn=hide_pipeline_code_visibility,
            inputs=[],
            outputs=[pipeline_code_ui],
        ).success(
            fn=push_dataset,
            inputs=[
                org_name,
                repo_name,
                private,
                search_in,
                file_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
                temperature,
                temperature_completion,
                pipeline_code,
            ],
            outputs=[success_message],
        ).success(
            fn=show_success_message,
            inputs=[org_name, repo_name],
            outputs=[success_message],
        ).success(
            fn=generate_pipeline_code,
            inputs=[
                search_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
            ],
            outputs=[pipeline_code],
        ).success(
            fn=show_pipeline_code_visibility,
            inputs=[],
            outputs=[pipeline_code_ui],
        )

        btn_save_local.click(
            fn=hide_success_message,
            outputs=[success_message],
        ).success(
            fn=hide_pipeline_code_visibility,
            inputs=[],
            outputs=[pipeline_code_ui],
        ).success(
            fn=show_save_local,
            inputs=[],
            outputs=[csv_file, json_file, success_message],
        ).success(
            save_local,
            inputs=[
                search_in,
                file_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
                temperature,
                repo_name,
                temperature_completion,
            ],
            outputs=[csv_file, json_file],
        ).success(
            fn=generate_pipeline_code,
            inputs=[
                search_in,
                input_type,
                system_prompt,
                document_column,
                num_turns,
                num_rows,
            ],
            outputs=[pipeline_code],
        ).success(
            fn=show_pipeline_code_visibility,
            inputs=[],
            outputs=[pipeline_code_ui],
        )

        clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
        clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
        clear_prompt_btn_part.click(
            fn=lambda: "", inputs=[], outputs=[dataset_description]
        )
        clear_btn_full.click(
            # num_turns is a gr.Number: reset it to its initial value of 1.
            fn=lambda df: ("", "", 1, _get_dataframe()),
            inputs=[dataframe],
            outputs=[system_prompt, document_column, num_turns, dataframe],
        )

    app.load(fn=swap_visibility, outputs=main_ui)
    if SFT_AVAILABLE:
        app.load(fn=get_org_dropdown, outputs=[org_name])
        app.load(fn=get_random_repo_name, outputs=[repo_name])
        app.load(fn=show_temperature_completion, outputs=[temperature_completion])
        if SAVE_LOCAL_DIR is not None:
            app.load(fn=show_save_local_button, outputs=btn_save_local)