import os
import random
import uuid
from typing import Union

import argilla as rg
import gradio as gr
import nltk
import pandas as pd
from datasets import Dataset
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi

from synthetic_dataset_generator.apps.base import (
    combine_datasets,
    hide_success_message,
    load_dataset_from_hub,
    preprocess_input_data,
    push_pipeline_code_to_hub,
    show_success_message,
    test_max_num_rows,
    validate_argilla_user_workspace_dataset,
    validate_push_to_hub,
)
from synthetic_dataset_generator.constants import (
    DEFAULT_BATCH_SIZE,
    MODEL,
    MODEL_COMPLETION,
    SAVE_LOCAL_DIR,
)
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
from synthetic_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.pipelines.rag import (
    DEFAULT_DATASET_DESCRIPTIONS,
    generate_pipeline_code,
    get_chunks_generator,
    get_prompt_generator,
    get_response_generator,
    get_sentence_pair_generator,
)
from synthetic_dataset_generator.utils import (
    column_to_list,
    get_argilla_client,
    get_org_dropdown,
    get_random_repo_name,
    swap_visibility,
)

# Make the required NLTK resources available in a local directory.
os.makedirs("./nltk_data", exist_ok=True)
nltk.data.path.append("./nltk_data")
nltk.download("punkt_tab", download_dir="./nltk_data")
nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")


def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
    progress(0.1, desc="Initializing")
    generate_description = get_prompt_generator()
    progress(0.5, desc="Generating")
    result = next(
        generate_description.process(
            [
                {
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    progress(1.0, desc="Prompt generated")
    return result


def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    progress(0.1, desc="Loading the source data")
    if input_type == "dataset-input":
        return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
    else:
        return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)


def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: str,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking
    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=10,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return dataframe


def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    retrieval: bool = False,
    reranking: bool = False,
    num_rows: int = 10,
    temperature: float = 0.7,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    if input_type == "prompt-input":
        chunk_generator = get_chunks_generator(
            temperature=temperature, is_sample=is_sample
        )
    else:
        document_data = column_to_list(dataframe, document_column)
        # repeat documents if there are fewer than the requested number of rows
        if len(document_data) < num_rows:
            document_data += random.choices(
                document_data, k=num_rows - len(document_data)
            )

    retrieval_generator = get_sentence_pair_generator(
        action="query",
        triplet=True if retrieval else False,
        temperature=temperature,
        is_sample=is_sample,
    )
    response_generator = get_response_generator(
        temperature=temperature_completion or temperature, is_sample=is_sample
    )
    if reranking:
        reranking_generator = get_sentence_pair_generator(
            action="semantically-similar",
            triplet=True,
            temperature=temperature,
            is_sample=is_sample,
        )

    # questions + responses, plus optional chunk-generation and reranking steps
    steps = 2 + sum([1 if reranking else 0, 1 if input_type == "prompt-input" else 0])
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    # generate chunks
    if input_type == "prompt-input":
        n_processed = 0
        chunk_results = []
        rewritten_system_prompts = get_rewritten_prompts(system_prompt, num_rows)
        while n_processed < num_rows:
            progress(
                step_progress * n_processed / num_rows,
                total=total_steps,
                desc="Generating chunks",
            )
            remaining_rows = num_rows - n_processed
            batch_size = min(batch_size, remaining_rows)
            inputs = [
                {"task": random.choice(rewritten_system_prompts)}
                for _ in range(batch_size)
            ]
            chunks = list(chunk_generator.process(inputs=inputs))
            chunk_results.extend(chunks[0])
            n_processed += batch_size
            random.seed(a=random.randint(0, 2**32 - 1))
        document_data = [chunk["generation"] for chunk in chunk_results]
        progress(step_progress, desc="Generating chunks")

    # generate questions
    n_processed = 0
    retrieval_results = []
    while n_processed < num_rows:
        progress(
            step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating questions",
        )
        remaining_rows = num_rows - n_processed
        batch_size = min(batch_size, remaining_rows)
        inputs = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + batch_size]
        ]
        questions = list(retrieval_generator.process(inputs=inputs))
        retrieval_results.extend(questions[0])
        n_processed += batch_size
    for result in retrieval_results:
        result["context"] = result["anchor"]
        if retrieval:
            result["question"] = result["positive"]
            result["positive_retrieval"] = result.pop("positive")
            result["negative_retrieval"] = result.pop("negative")
        else:
            result["question"] = result.pop("positive")
    progress(step_progress, desc="Generating questions")

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            step_progress + step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating responses",
        )
        batch = retrieval_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["response"] = result["generation"]
    progress(step_progress, desc="Generating responses")

    # generate reranking
    if reranking:
        n_processed = 0
        reranking_results = []
        while n_processed < num_rows:
            progress(
                step_progress * n_processed / num_rows,
                total=total_steps,
                desc="Generating reranking data",
            )
            batch = response_results[n_processed : n_processed + batch_size]
            batch = list(reranking_generator.process(inputs=batch))
            reranking_results.extend(batch[0])
            n_processed += batch_size
        for result in reranking_results:
            result["positive_reranking"] = result.pop("positive")
            result["negative_reranking"] = result.pop("negative")
    progress(
        1,
        total=total_steps,
desc="Creating dataset", ) # create distiset distiset_results = [] source_results = reranking_results if reranking else response_results base_keys = ["context", "question", "response"] retrieval_keys = ["positive_retrieval", "negative_retrieval"] if retrieval else [] reranking_keys = ["positive_reranking", "negative_reranking"] if reranking else [] relevant_keys = base_keys + retrieval_keys + reranking_keys for result in source_results: record = {key: result.get(key) for key in relevant_keys if key in result} distiset_results.append(record) dataframe = pd.DataFrame(distiset_results) progress(1.0, desc="Dataset generation completed") return dataframe def push_dataset_to_hub( dataframe: pd.DataFrame, org_name: str, repo_name: str, oauth_token: Union[gr.OAuthToken, None], private: bool, pipeline_code: str, progress=gr.Progress(), ): progress(0.0, desc="Validating") repo_id = validate_push_to_hub(org_name, repo_name) progress(0.5, desc="Creating dataset") dataset = Dataset.from_pandas(dataframe) dataset = combine_datasets(repo_id, dataset, oauth_token) distiset = Distiset({"default": dataset}) progress(0.9, desc="Pushing dataset") distiset.push_to_hub( repo_id=repo_id, private=private, include_script=False, token=oauth_token.token, create_pr=False, ) push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token) progress(1.0, desc="Dataset pushed") return dataframe def push_dataset( org_name: str, repo_name: str, private: bool, original_repo_id: str, file_paths: list[str], input_type: str, system_prompt: str, document_column: str, retrieval_reranking: list[str], num_rows: int, temperature: float, temperature_completion: float, pipeline_code: str, oauth_token: Union[gr.OAuthToken, None] = None, progress=gr.Progress(), ) -> pd.DataFrame: retrieval = "Retrieval" in retrieval_reranking reranking = "Reranking" in retrieval_reranking if input_type == "prompt-input": dataframe = pd.DataFrame(columns=["context", "question", "response"]) else: dataframe, _ = load_dataset_file( repo_id=original_repo_id, file_paths=file_paths, input_type=input_type, num_rows=num_rows, token=oauth_token, ) progress(0.5, desc="Generating dataset") dataframe = generate_dataset( input_type=input_type, dataframe=dataframe, system_prompt=system_prompt, document_column=document_column, retrieval=retrieval, reranking=reranking, num_rows=num_rows, temperature=temperature, temperature_completion=temperature_completion, is_sample=True, ) push_dataset_to_hub( dataframe, org_name, repo_name, oauth_token, private, pipeline_code ) dataframe = dataframe[ dataframe.applymap(lambda x: str(x).strip() if pd.notna(x) else x).apply( lambda row: row.notna().all() and (row != "").all(), axis=1 ) ] try: progress(0.1, desc="Setting up user and workspace") hf_user = HfApi().whoami(token=oauth_token.token)["name"] client = get_argilla_client() if client is None: return "" progress(0.5, desc="Creating dataset in Argilla") fields = [ rg.TextField( name="context", title="Context", description="Context for the generation", ), rg.ChatField( name="chat", title="Chat", description="User and assistant conversation based on the context", ), ] for item in ["positive", "negative"]: if retrieval: fields.append( rg.TextField( name=f"{item}_retrieval", title=f"{item.capitalize()} retrieval", description=f"The {item} query for retrieval", ) ) if reranking: fields.append( rg.TextField( name=f"{item}_reranking", title=f"{item.capitalize()} reranking", description=f"The {item} query for reranking", ) ) questions = [ rg.LabelQuestion( name="relevant", 
title="Are the question and response relevant to the given context?", labels=["yes", "no"], ), rg.LabelQuestion( name="is_response_correct", title="Is the response correct?", labels=["yes", "no"], ), ] for item in ["positive", "negative"]: if retrieval: questions.append( rg.LabelQuestion( name=f"is_{item}_retrieval_relevant", title=f"Is the {item} retrieval relevant?", labels=["yes", "no"], required=False, ) ) if reranking: questions.append( rg.LabelQuestion( name=f"is_{item}_reranking_relevant", title=f"Is the {item} reranking relevant?", labels=["yes", "no"], required=False, ) ) metadata = [ rg.IntegerMetadataProperty( name=f"{item}_length", title=f"{item.capitalize()} length" ) for item in ["context", "question", "response"] ] vectors = [ rg.VectorField( name=f"{item}_embeddings", dimensions=get_sentence_embedding_dimensions(), ) for item in ["context", "question", "response"] ] settings = rg.Settings( fields=fields, questions=questions, metadata=metadata, vectors=vectors, guidelines="Please review the conversation and provide an evaluation.", ) dataframe["chat"] = dataframe.apply( lambda row: [ {"role": "user", "content": row["question"]}, {"role": "assistant", "content": row["response"]}, ], axis=1, ) for item in ["context", "question", "response"]: dataframe[f"{item}_length"] = dataframe[item].apply( lambda x: len(x) if x is not None else 0 ) dataframe[f"{item}_embeddings"] = get_embeddings( dataframe[item].apply(lambda x: x if x is not None else "").to_list() ) rg_dataset = client.datasets(name=repo_name, workspace=hf_user) if rg_dataset is None: rg_dataset = rg.Dataset( name=repo_name, workspace=hf_user, settings=settings, client=client, ) rg_dataset = rg_dataset.create() progress(0.7, desc="Pushing dataset to Argilla") hf_dataset = Dataset.from_pandas(dataframe) rg_dataset.records.log(records=hf_dataset) progress(1.0, desc="Dataset pushed to Argilla") except Exception as e: raise gr.Error(f"Error pushing dataset to Argilla: {e}") return "" def save_local( repo_id: str, file_paths: list[str], input_type: str, system_prompt: str, document_column: str, retrieval_reranking: list[str], num_rows: int, temperature: float, repo_name: str, temperature_completion: float, ) -> pd.DataFrame: retrieval = "Retrieval" in retrieval_reranking reranking = "Reranking" in retrieval_reranking if input_type == "prompt-input": dataframe = pd.DataFrame(columns=["context", "question", "response"]) else: dataframe, _ = load_dataset_file( repo_id=repo_id, file_paths=file_paths, input_type=input_type, num_rows=num_rows, ) dataframe = generate_dataset( input_type=input_type, dataframe=dataframe, system_prompt=system_prompt, document_column=document_column, retrieval=retrieval, reranking=reranking, num_rows=num_rows, temperature=temperature, temperature_completion=temperature_completion, ) local_dataset = Dataset.from_pandas(dataframe) output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv") output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json") local_dataset.to_csv(output_csv, index=False) local_dataset.to_json(output_json, index=False) return output_csv, output_json def show_system_prompt_visibility(): return {system_prompt: gr.Textbox(visible=True)} def hide_system_prompt_visibility(): return {system_prompt: gr.Textbox(visible=False)} def show_document_column_visibility(): return {document_column: gr.Dropdown(visible=True)} def hide_document_column_visibility(): return { document_column: gr.Dropdown( choices=["Load your data first in step 1."], value="Load your data first in step 1.", 
            visible=False,
        )
    }


def show_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    return {pipeline_code_ui: gr.Accordion(visible=False)}


def show_temperature_completion():
    if MODEL != MODEL_COMPLETION:
        return {temperature_completion: gr.Slider(value=0.9, visible=True)}


def show_save_local_button():
    return {btn_save_local: gr.Button(visible=True)}


def hide_save_local_button():
    return {btn_save_local: gr.Button(visible=False)}


def show_save_local():
    gr.update(success_message, min_height=0)
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        success_message: success_message,
    }


def hide_save_local():
    gr.update(success_message, min_height=100)
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        success_message: success_message,
    }


######################
# Gradio UI
######################
with gr.Blocks() as app:
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Select your input")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                input_type = gr.Dropdown(
                    label="Input type",
                    choices=["dataset-input", "file-input", "prompt-input"],
                    value="dataset-input",
                    multiselect=False,
                    visible=False,
                )
                with gr.Tab("Load from Hub") as tab_dataset_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            search_in = HuggingfaceHubSearch(
                                label="Search",
                                placeholder="Search for a dataset",
                                search_type="dataset",
                                sumbit_on_select=True,
                            )
                            with gr.Row():
                                clear_dataset_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_dataset_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            examples = gr.Examples(
                                examples=[
                                    "charris/wikipedia_sample",
                                    "plaguss/argilla_sdk_docs_raw_unstructured",
                                    "BeIR/hotpotqa-generated-queries",
                                ],
                                label="Example datasets",
                                fn=lambda x: x,
                                inputs=[search_in],
                                run_on_click=True,
                            )
                            search_out = gr.HTML(
                                label="Dataset preview", visible=False
                            )
                with gr.Tab("Load your file") as tab_file_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            file_in = gr.File(
                                label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                file_count="multiple",
                                file_types=[".md", ".txt", ".docx", ".pdf"],
                            )
                            with gr.Row():
                                clear_file_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_file_btn = gr.Button("Load", variant="primary")
                        with gr.Column(scale=3):
                            file_out = gr.HTML(label="Dataset preview", visible=False)
                with gr.Tab("Generate from prompt") as tab_prompt_input:
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=2):
                            dataset_description = gr.Textbox(
                                label="Dataset description",
                                placeholder="Give a precise description of your desired dataset.",
                            )
                            with gr.Row():
                                clear_prompt_btn_part = gr.Button(
                                    "Clear", variant="secondary"
                                )
                                load_prompt_btn = gr.Button("Create", variant="primary")
                        with gr.Column(scale=3):
                            examples = gr.Examples(
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
                                inputs=[dataset_description],
                                cache_examples=False,
                                label="Examples",
                            )

        gr.HTML(value="