sdiazlor's picture
try to fix the requirement
3a74a0a
raw
history blame
31.5 kB
import random
import uuid
from tqdm import tqdm
from typing import Union
import argilla as rg
import gradio as gr
import nltk
import pandas as pd
from datasets import (
Dataset,
get_dataset_config_names,
get_dataset_split_names,
load_dataset,
)
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.auto import partition
from synthetic_dataset_generator.apps.base import (
combine_datasets,
get_iframe,
hide_success_message,
push_pipeline_code_to_hub,
show_success_message,
test_max_num_rows,
validate_argilla_user_workspace_dataset,
validate_push_to_hub,
)
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
from synthetic_dataset_generator.pipelines.embeddings import (
get_embeddings,
get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.pipelines.rag import (
DEFAULT_DATASET_DESCRIPTIONS,
get_chunks_generator,
get_prompt_generator,
generate_pipeline_code,
get_sentence_pair_generator,
get_response_generator,
)
from synthetic_dataset_generator.utils import (
column_to_list,
get_argilla_client,
get_org_dropdown,
get_random_repo_name,
swap_visibility,
)
# Download the NLTK resources `punkt_tab` and `averaged_perceptron_tagger_eng`
# as an import-time side effect of this module.
# NOTE(review): presumably needed by the `unstructured` partition/chunking
# calls used for uploaded files — confirm before removing.
nltk.download("punkt_tab")
nltk.download("averaged_perceptron_tagger_eng")
def _get_valid_columns(dataframe: pd.DataFrame):
    """Return the names of the columns that can be used as document columns.

    A column qualifies when the value in its first row is a ``str``; only the
    first row is inspected.

    Args:
        dataframe: The loaded source data.

    Returns:
        A list of column labels, in the dataframe's column order. An empty
        dataframe (no rows or no columns) yields an empty list instead of
        raising ``IndexError`` on the first-row lookup.
    """
    if dataframe.empty:
        return []
    return [
        col
        for col in dataframe.columns
        if isinstance(dataframe[col].iloc[0], str)
    ]
def _load_dataset_from_hub(
    repo_id: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream the first rows of a Hub dataset into a dataframe.

    The first config returned by ``get_dataset_config_names`` and the first
    split returned by ``get_dataset_split_names`` are used.

    Args:
        repo_id: Hub dataset repository id.
        num_rows: Maximum number of rows to read from the stream.
        token: Value forwarded as ``token`` to the ``datasets`` helpers.
        progress: Gradio progress tracker (tracks the ``tqdm`` loop).

    Returns:
        A tuple ``(dataframe, dropdown)`` where ``dropdown`` is a
        ``gr.Dropdown`` listing the string-valued columns; it is
        non-interactive when no such column exists.

    Raises:
        gr.Error: If ``repo_id`` is empty.
    """
    if not repo_id:
        raise gr.Error("Hub repo id is required")
    subsets = get_dataset_config_names(repo_id, token=token)
    splits = get_dataset_split_names(repo_id, subsets[0], token=token)
    ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
    rows = []
    for row in tqdm(ds, desc="Loading the dataset", total=num_rows):
        rows.append(row)
        # Stop once exactly `num_rows` rows were collected; comparing the loop
        # index against `num_rows` after appending would read one row too many.
        if len(rows) >= num_rows:
            break
    ds = Dataset.from_list(rows)
    dataframe = ds.to_pandas()
    doc_valid_columns = _get_valid_columns(dataframe)
    col_doc = doc_valid_columns[0] if doc_valid_columns else ""
    return (
        dataframe,
        gr.Dropdown(
            choices=doc_valid_columns,
            label="Documents column",
            value=col_doc,
            interactive=col_doc != "",
            multiselect=False,
        ),
    )
def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm=True)):
    """Partition uploaded files into text chunks and return them as a dataframe.

    Files are processed in order with ``partition`` and ``chunk_by_title``;
    processing stops after the file that brings the running chunk count to at
    least ``num_rows`` (the chunks of that file are kept in full).

    Args:
        file_paths: Paths of the files to process.
        num_rows: Chunk count after which no further files are read.
        progress: Gradio progress tracker (tracks the ``tqdm`` loop).

    Returns:
        A tuple ``(dataframe, dropdown)``; the dataframe has the columns
        ``filename`` and ``chunks`` and the dropdown offers ``chunks`` as the
        only documents column.

    Raises:
        gr.Error: If no file path was provided.
    """
    # Fail with a user-facing message instead of a TypeError on `len(None)`.
    if not file_paths:
        raise gr.Error("At least one file is required")
    data = {}
    total_chunks = 0
    for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
        partitioned_file = partition(filename=file_path)
        chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
        data[file_path] = chunks
        total_chunks += len(chunks)
        if total_chunks >= num_rows:
            break
    dataframe = pd.DataFrame.from_records(
        [(path, chunk) for path, file_chunks in data.items() for chunk in file_chunks],
        columns=["filename", "chunks"],
    )
    col_doc = "chunks"
    return (
        dataframe,
        gr.Dropdown(
            choices=["chunks"],
            label="Documents column",
            value=col_doc,
            interactive=True,
            multiselect=False,
        ),
    )
def generate_system_prompt(dataset_description, progress=gr.Progress()):
    """Turn a dataset description into a generated system prompt.

    The description is sent as the single ``instruction`` input to the prompt
    generator and the ``generation`` field of the first output is returned.
    """
    progress(0.1, desc="Initializing")
    prompt_task = get_prompt_generator()
    progress(0.5, desc="Generating")
    request = [{"instruction": dataset_description}]
    first_batch = next(prompt_task.process(request))
    generated_prompt = first_batch[0]["generation"]
    progress(1.0, desc="Prompt generated")
    return generated_prompt
def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    """Load the source data from the Hub or from uploaded files.

    ``input_type == "dataset-input"`` reads from the Hub repository; any other
    value processes ``file_paths``. Returns whatever the selected loader
    returns: a ``(dataframe, dropdown)`` tuple.
    """
    progress(0.1, desc="Loading the source data")
    if input_type != "dataset-input":
        return _preprocess_input_data(file_paths, num_rows)
    return _load_dataset_from_hub(repo_id, num_rows, token)
def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    retrieval: bool = False,
    reranking: bool = False,
    num_rows: int = 10,
    temperature: float = 0.7,
    is_sample: bool = False,
    progress=gr.Progress(),
):
    """Generate a RAG dataset of contexts, questions and responses.

    Stages, in order: (1) for ``"prompt-input"`` only, generate the context
    chunks from rewritten system prompts; (2) generate one question per
    context; (3) generate a response per question; (4) optionally generate
    reranking pairs.

    Args:
        input_type: ``"prompt-input"`` to synthesize the contexts; any other
            value reads them from ``dataframe[document_column]``.
        dataframe: Source data (ignored for ``"prompt-input"``).
        system_prompt: Prompt that is rewritten into chunk-generation tasks.
        document_column: Column of ``dataframe`` holding the documents.
        retrieval: Request triplets from the question generator and keep the
            ``positive_retrieval`` / ``negative_retrieval`` columns.
        reranking: Run the reranking stage and keep the
            ``positive_reranking`` / ``negative_reranking`` columns.
        num_rows: Requested row count (passed through ``test_max_num_rows``).
        temperature: Forwarded to every generator.
        is_sample: Forwarded to every generator.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with ``context``, ``question``, ``response`` and the
        optional retrieval/reranking columns.

    Raises:
        gr.Error: If documents must be sampled to reach ``num_rows`` but the
            selected column contains none.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    from_prompt = input_type == "prompt-input"
    if from_prompt:
        chunk_generator = get_chunks_generator(
            temperature=temperature, is_sample=is_sample
        )
    else:
        document_data = column_to_list(dataframe, document_column)
        if len(document_data) < num_rows:
            # `random.choices` raises a bare IndexError on an empty sequence.
            if len(document_data) == 0:
                raise gr.Error("No documents found in the selected column")
            # Pad with randomly re-sampled documents to reach `num_rows`.
            document_data += random.choices(
                document_data, k=num_rows - len(document_data)
            )

    retrieval_generator = get_sentence_pair_generator(
        action="query",
        triplet=True if retrieval else False,
        temperature=temperature,
        is_sample=is_sample,
    )
    response_generator = get_response_generator(
        temperature=temperature, is_sample=is_sample
    )
    if reranking:
        reranking_generator = get_sentence_pair_generator(
            action="semantically-similar",
            triplet=True,
            temperature=temperature,
            is_sample=is_sample,
        )

    # The chunk stage only exists for prompt input; the original compared
    # against "prompt-type", which never matched, so it was never counted.
    steps = 2 + sum([1 if reranking else 0, 1 if from_prompt else 0])
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    # Kept constant: each loop derives its own (possibly smaller) last batch so
    # a short final batch does not shrink the batches of the following stages.
    batch_size = DEFAULT_BATCH_SIZE
    stage = 0  # index of the stage currently running, used to offset progress

    def report(n_done: int, desc: str) -> None:
        # Progress advances monotonically across stages.
        progress(
            step_progress * (stage + n_done / num_rows),
            total=total_steps,
            desc=desc,
        )

    # generate chunks
    if from_prompt:
        n_processed = 0
        chunk_results = []
        rewritten_system_prompts = get_rewritten_prompts(system_prompt, num_rows)
        while n_processed < num_rows:
            report(n_processed, "Generating chunks")
            current_batch = min(batch_size, num_rows - n_processed)
            inputs = [
                {"task": random.choice(rewritten_system_prompts)}
                for _ in range(current_batch)
            ]
            chunks = list(chunk_generator.process(inputs=inputs))
            chunk_results.extend(chunks[0])
            n_processed += current_batch
            # NOTE(review): re-seeds the global RNG after every batch,
            # presumably to vary the sampled tasks between batches — confirm.
            random.seed(a=random.randint(0, 2**32 - 1))
        document_data = [chunk["generation"] for chunk in chunk_results]
        report(num_rows, "Generating chunks")
        stage += 1

    # generate questions
    n_processed = 0
    retrieval_results = []
    while n_processed < num_rows:
        report(n_processed, "Generating questions")
        current_batch = min(batch_size, num_rows - n_processed)
        inputs = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + current_batch]
        ]
        questions = list(retrieval_generator.process(inputs=inputs))
        retrieval_results.extend(questions[0])
        n_processed += current_batch
    for result in retrieval_results:
        result["context"] = result["anchor"]
        if retrieval:
            result["question"] = result["positive"]
            result["positive_retrieval"] = result.pop("positive")
            result["negative_retrieval"] = result.pop("negative")
        else:
            result["question"] = result.pop("positive")
    report(num_rows, "Generating questions")
    stage += 1

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        report(n_processed, "Generating responses")
        batch = retrieval_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["response"] = result["generation"]
    report(num_rows, "Generating responses")
    stage += 1

    # generate reranking
    if reranking:
        n_processed = 0
        reranking_results = []
        while n_processed < num_rows:
            report(n_processed, "Generating reranking data")
            batch = response_results[n_processed : n_processed + batch_size]
            batch = list(reranking_generator.process(inputs=batch))
            reranking_results.extend(batch[0])
            n_processed += batch_size
        for result in reranking_results:
            result["positive_reranking"] = result.pop("positive")
            result["negative_reranking"] = result.pop("negative")

    progress(
        1,
        total=total_steps,
        desc="Creating dataset",
    )

    # create distiset
    distiset_results = []
    source_results = reranking_results if reranking else response_results
    base_keys = ["context", "question", "response"]
    retrieval_keys = ["positive_retrieval", "negative_retrieval"] if retrieval else []
    reranking_keys = ["positive_reranking", "negative_reranking"] if reranking else []
    relevant_keys = base_keys + retrieval_keys + reranking_keys
    for result in source_results:
        # Keep only the output columns; keys missing from a result are skipped.
        record = {key: result.get(key) for key in relevant_keys if key in result}
        distiset_results.append(record)

    dataframe = pd.DataFrame(distiset_results)
    progress(1.0, desc="Dataset generation completed")
    return dataframe
def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: str,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Build the 10-row preview dataset shown in the UI.

    For ``"prompt-input"`` an empty frame with the three base columns is used
    as the source; otherwise the source data is loaded with ``num_rows``. The
    preview itself is always generated with ``num_rows=10`` and
    ``is_sample=True``.
    """
    wants_retrieval = "Retrieval" in retrieval_reranking
    wants_reranking = "Reranking" in retrieval_reranking

    if input_type != "prompt-input":
        source_df, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    else:
        source_df = pd.DataFrame(columns=["context", "question", "response"])

    progress(0.5, desc="Generating dataset")
    sample_df = generate_dataset(
        input_type=input_type,
        dataframe=source_df,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=wants_retrieval,
        reranking=wants_reranking,
        num_rows=10,
        is_sample=True,
    )
    return sample_df
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    """Push the generated dataframe and its pipeline code to the Hub.

    The dataframe is converted to a ``Dataset``, passed through
    ``combine_datasets`` for the target repo, wrapped in a ``Distiset`` under
    the ``"default"`` key and pushed. The input dataframe is returned
    unchanged.
    """
    progress(0.0, desc="Validating")
    target_repo_id = validate_push_to_hub(org_name, repo_name)

    progress(0.5, desc="Creating dataset")
    combined = combine_datasets(
        target_repo_id, Dataset.from_pandas(dataframe), oauth_token
    )
    distiset = Distiset({"default": combined})

    progress(0.9, desc="Pushing dataset")
    push_options = {
        "repo_id": target_repo_id,
        "private": private,
        "include_script": False,
        "token": oauth_token.token,
        "create_pr": False,
    }
    distiset.push_to_hub(**push_options)
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)

    progress(1.0, desc="Dataset pushed")
    return dataframe
def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    retrieval_reranking: list[str],
    num_rows: int,
    temperature: float,
    pipeline_code: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate the dataset, push it to the Hub and log it to Argilla.

    The generated dataframe is pushed to the Hub as-is. Rows with a missing or
    blank value are then dropped before the remaining rows are logged to an
    Argilla dataset named ``repo_name`` in the workspace of the authenticated
    Hub user. The Argilla step is skipped when no Argilla client is available.

    Args:
        org_name: Hub organization/user receiving the dataset.
        repo_name: Name of the Hub repo and of the Argilla dataset.
        private: Whether the Hub repo is private.
        original_repo_id: Hub dataset used as source for ``"dataset-input"``.
        file_paths: Uploaded files used as source for file input.
        input_type: Source selector; ``"prompt-input"`` needs no source data.
        system_prompt: Prompt used when generating contexts from a prompt.
        document_column: Column of the source data holding the documents.
        retrieval_reranking: Selected extras; may contain ``"Retrieval"``
            and/or ``"Reranking"``.
        num_rows: Number of rows to load and generate.
        temperature: Generation temperature.
        pipeline_code: Pipeline code pushed alongside the dataset.
        oauth_token: OAuth token of the logged-in user.
        progress: Gradio progress tracker.

    Returns:
        An empty string.

    Raises:
        gr.Error: If anything fails while pushing to Argilla (the original
            exception is chained as the cause).
    """
    retrieval = "Retrieval" in retrieval_reranking
    reranking = "Reranking" in retrieval_reranking

    if input_type == "prompt-input":
        dataframe = pd.DataFrame(columns=["context", "question", "response"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    # NOTE(review): `is_sample=True` is also what the preview path uses; it
    # looks unintended for the full dataset — confirm before changing.
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        retrieval=retrieval,
        reranking=reranking,
        num_rows=num_rows,
        temperature=temperature,
        is_sample=True,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, oauth_token, private, pipeline_code
    )

    # Keep only rows where every value is present and non-blank. Element-wise
    # mapping goes through `Series.map` per column, which avoids the
    # deprecated `DataFrame.applymap`.
    stripped = dataframe.apply(
        lambda column: column.map(lambda x: str(x).strip() if pd.notna(x) else x)
    )
    complete_rows = stripped.apply(
        lambda row: row.notna().all() and (row != "").all(), axis=1
    )
    # Copy so the column assignments below write to an independent frame and
    # not to a slice of the generated dataframe.
    dataframe = dataframe[complete_rows].copy()

    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        progress(0.5, desc="Creating dataset in Argilla")

        fields = [
            rg.TextField(
                name="context",
                title="Context",
                description="Context for the generation",
            ),
            rg.ChatField(
                name="chat",
                title="Chat",
                description="User and assistant conversation based on the context",
            ),
        ]
        for item in ["positive", "negative"]:
            if retrieval:
                fields.append(
                    rg.TextField(
                        name=f"{item}_retrieval",
                        title=f"{item.capitalize()} retrieval",
                        description=f"The {item} query for retrieval",
                    )
                )
            if reranking:
                fields.append(
                    rg.TextField(
                        name=f"{item}_reranking",
                        title=f"{item.capitalize()} reranking",
                        description=f"The {item} query for reranking",
                    )
                )

        questions = [
            rg.LabelQuestion(
                name="relevant",
                title="Are the question and response relevant to the given context?",
                labels=["yes", "no"],
            ),
            rg.LabelQuestion(
                name="is_response_correct",
                title="Is the response correct?",
                labels=["yes", "no"],
            ),
        ]
        for item in ["positive", "negative"]:
            if retrieval:
                questions.append(
                    rg.LabelQuestion(
                        name=f"is_{item}_retrieval_relevant",
                        title=f"Is the {item} retrieval relevant?",
                        labels=["yes", "no"],
                        required=False,
                    )
                )
            if reranking:
                questions.append(
                    rg.LabelQuestion(
                        name=f"is_{item}_reranking_relevant",
                        title=f"Is the {item} reranking relevant?",
                        labels=["yes", "no"],
                        required=False,
                    )
                )

        metadata = [
            rg.IntegerMetadataProperty(
                name=f"{item}_length", title=f"{item.capitalize()} length"
            )
            for item in ["context", "question", "response"]
        ]

        # Looked up once instead of once per vector field.
        embedding_dimensions = get_sentence_embedding_dimensions()
        vectors = [
            rg.VectorField(
                name=f"{item}_embeddings",
                dimensions=embedding_dimensions,
            )
            for item in ["context", "question", "response"]
        ]

        settings = rg.Settings(
            fields=fields,
            questions=questions,
            metadata=metadata,
            vectors=vectors,
            guidelines="Please review the conversation and provide an evaluation.",
        )

        # One user/assistant exchange per row for the "chat" field.
        dataframe["chat"] = dataframe.apply(
            lambda row: [
                {"role": "user", "content": row["question"]},
                {"role": "assistant", "content": row["response"]},
            ],
            axis=1,
        )
        for item in ["context", "question", "response"]:
            dataframe[f"{item}_length"] = dataframe[item].apply(
                lambda x: len(x) if x is not None else 0
            )
            dataframe[f"{item}_embeddings"] = get_embeddings(
                dataframe[item].apply(lambda x: x if x is not None else "").to_list()
            )

        # Reuse the Argilla dataset when the lookup returns one, else create it.
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
def show_system_prompt_visibility():
    """Return an update mapping `system_prompt` to a visible textbox."""
    visible_textbox = gr.Textbox(visible=True)
    return {system_prompt: visible_textbox}
def hide_system_prompt_visibility():
    """Return an update mapping `system_prompt` to a hidden textbox."""
    hidden_textbox = gr.Textbox(visible=False)
    return {system_prompt: hidden_textbox}
def show_document_column_visibility():
    """Return an update mapping `document_column` to a visible dropdown."""
    visible_dropdown = gr.Dropdown(visible=True)
    return {document_column: visible_dropdown}
def hide_document_column_visibility():
    """Return an update hiding `document_column` and resetting it to the placeholder."""
    placeholder = "Load your data first in step 1."
    hidden_dropdown = gr.Dropdown(
        choices=[placeholder],
        value=placeholder,
        visible=False,
    )
    return {document_column: hidden_dropdown}
def show_pipeline_code_visibility():
    """Return an update mapping `pipeline_code_ui` to a visible accordion."""
    visible_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: visible_accordion}
def hide_pipeline_code_visibility():
    """Return an update mapping `pipeline_code_ui` to a hidden accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
######################
# Gradio UI
######################
# Declare the Gradio app. This part builds the layout (three numbered
# sections); the event handlers are attached after the layout.
with gr.Blocks() as app:
with gr.Column() as main_ui:
# --- Section 1: choose where the source documents come from ---
gr.Markdown("## 1. Select your input")
with gr.Row(equal_height=False):
with gr.Column(scale=2):
# Hidden dropdown that stores the active input type; it is written by the
# tab `select` handlers and read by the load/generate/push handlers.
input_type = gr.Dropdown(
label="Input type",
choices=["dataset-input", "file-input", "prompt-input"],
value="dataset-input",
multiselect=False,
visible=False,
)
# Tab 1: pick an existing dataset from the Hub.
with gr.Tab("Load from Hub") as tab_dataset_input:
with gr.Row(equal_height=False):
with gr.Column(scale=2):
# NOTE(review): `sumbit_on_select` is spelled this way in the call;
# presumably it matches the component's parameter name — confirm before renaming.
search_in = HuggingfaceHubSearch(
label="Search",
placeholder="Search for a dataset",
search_type="dataset",
sumbit_on_select=True,
)
with gr.Row():
clear_dataset_btn_part = gr.Button(
"Clear", variant="secondary"
)
load_dataset_btn = gr.Button("Load", variant="primary")
with gr.Column(scale=3):
# Clicking an example copies its repo id into the search box.
examples = gr.Examples(
examples=[
"charris/wikipedia_sample",
"plaguss/argilla_sdk_docs_raw_unstructured",
"BeIR/hotpotqa-generated-queries",
],
label="Example datasets",
fn=lambda x: x,
inputs=[search_in],
run_on_click=True,
)
search_out = gr.HTML(label="Dataset preview", visible=False)
# Tab 2: upload local files to be partitioned into chunks.
with gr.Tab("Load your file") as tab_file_input:
with gr.Row(equal_height=False):
with gr.Column(scale=2):
file_in = gr.File(
label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
file_count="multiple",
file_types=[".md", ".txt", ".docx", ".pdf"],
)
with gr.Row():
clear_file_btn_part = gr.Button(
"Clear", variant="secondary"
)
load_file_btn = gr.Button("Load", variant="primary")
with gr.Column(scale=3):
file_out = gr.HTML(label="Dataset preview", visible=False)
# Tab 3: describe the dataset and let the app generate the contexts.
with gr.Tab("Generate from prompt") as tab_prompt_input:
with gr.Row(equal_height=False):
with gr.Column(scale=2):
dataset_description = gr.Textbox(
label="Dataset description",
placeholder="Give a precise description of your desired dataset.",
)
with gr.Row():
clear_prompt_btn_part = gr.Button(
"Clear", variant="secondary"
)
load_prompt_btn = gr.Button("Create", variant="primary")
with gr.Column(scale=3):
# Rebinds `examples`; the Hub examples object above is not referenced again.
examples = gr.Examples(
examples=DEFAULT_DATASET_DESCRIPTIONS,
inputs=[dataset_description],
cache_examples=False,
label="Examples",
)
# --- Section 2: task configuration and sample preview ---
gr.HTML(value="<hr>")
gr.Markdown(value="## 2. Configure your task")
with gr.Row(equal_height=True):
with gr.Row(equal_height=False):
with gr.Column(scale=2):
# Hidden by default; shown only while the prompt tab is selected.
system_prompt = gr.Textbox(
label="System prompt",
placeholder="You are a helpful assistant.",
visible=False,
)
# Starts as a non-interactive placeholder; the load handlers replace it
# with a dropdown of the loaded data's document columns.
document_column = gr.Dropdown(
label="Document Column",
info="Select the document column to generate the RAG dataset",
choices=["Load your data first in step 1."],
value="Load your data first in step 1.",
interactive=False,
multiselect=False,
allow_custom_value=False,
)
# The selected values ("Retrieval" / "Reranking") are tested by name in
# generate_sample_dataset and push_dataset.
retrieval_reranking = gr.CheckboxGroup(
choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
type="value",
label="Data for RAG",
info="Indicate the additional data you want to generate for RAG.",
)
with gr.Row():
clear_btn_full = gr.Button("Clear", variant="secondary")
btn_apply_to_sample_dataset = gr.Button(
"Save", variant="primary"
)
with gr.Column(scale=3):
# Preview table for loaded data and generated samples.
dataframe = gr.Dataframe(
headers=["context", "question", "response"],
wrap=True,
interactive=False,
)
# --- Section 3: generation settings and push to the Hub ---
gr.HTML(value="<hr>")
gr.Markdown(value="## 3. Generate your dataset")
with gr.Row(equal_height=False):
with gr.Column(scale=2):
org_name = get_org_dropdown()
# Default repo name gets a random 8-character suffix from a UUID.
repo_name = gr.Textbox(
label="Repo name",
placeholder="dataset_name",
value=f"my-distiset-{str(uuid.uuid4())[:8]}",
interactive=True,
)
num_rows = gr.Number(
label="Number of rows",
value=10,
interactive=True,
scale=1,
)
temperature = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=1,
value=0.7,
step=0.1,
interactive=True,
)
private = gr.Checkbox(
label="Private dataset",
value=False,
interactive=True,
scale=1,
)
btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
with gr.Column(scale=3):
success_message = gr.Markdown(
visible=True,
min_height=100, # don't remove this otherwise progress is not visible
)
# Hidden accordion holding the generated pipeline code; it is revealed at
# the end of the push chain.
with gr.Accordion(
"Customize your pipeline with distilabel",
open=False,
visible=False,
) as pipeline_code_ui:
# The initial code is rendered once at build time from the components'
# default `.value`s; the push chain regenerates it from live values.
code = generate_pipeline_code(
repo_id=search_in.value,
file_paths=file_in.value,
input_type=input_type.value,
system_prompt=system_prompt.value,
document_column=document_column.value,
retrieval_reranking=retrieval_reranking.value,
num_rows=num_rows.value,
)
pipeline_code = gr.Code(
value=code,
language="python",
label="Distilabel Pipeline Code",
)
# --- Event wiring ---
# Selecting a tab stores the matching input type in the hidden `input_type`
# dropdown, then toggles the system prompt / document column widgets.
tab_dataset_input.select(
fn=lambda: "dataset-input",
inputs=[],
outputs=[input_type],
).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
fn=show_document_column_visibility, inputs=[], outputs=[document_column]
)
tab_file_input.select(
fn=lambda: "file-input",
inputs=[],
outputs=[input_type],
).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
fn=show_document_column_visibility, inputs=[], outputs=[document_column]
)
tab_prompt_input.select(
fn=lambda: "prompt-input",
inputs=[],
outputs=[input_type],
).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
)
# Submitting a search renders the dataset preview, then empties the preview
# table while keeping its columns.
search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
fn=lambda df: pd.DataFrame(columns=df.columns),
inputs=[dataframe],
outputs=[dataframe],
)
# Both "Load" buttons call load_dataset_file; `num_rows` is not among the
# inputs, so the function's default of 10 applies.
# NOTE(review): `token` is not listed either; presumably Gradio injects it
# from the OAuthToken annotation — confirm.
load_dataset_btn.click(
fn=load_dataset_file,
inputs=[search_in, file_in, input_type],
outputs=[
dataframe,
document_column,
],
)
load_file_btn.click(
fn=load_dataset_file,
inputs=[search_in, file_in, input_type],
outputs=[
dataframe,
document_column,
],
)
# "Create": generate the system prompt and, only if that succeeds, generate
# the sample dataset from it.
load_prompt_btn.click(
fn=generate_system_prompt,
inputs=[dataset_description],
outputs=[system_prompt],
show_progress=True,
).success(
fn=generate_sample_dataset,
inputs=[
search_in,
file_in,
input_type,
system_prompt,
document_column,
retrieval_reranking,
num_rows,
],
outputs=dataframe,
)
# "Save": regenerate the sample dataset with the current configuration.
btn_apply_to_sample_dataset.click(
fn=generate_sample_dataset,
inputs=[
search_in,
file_in,
input_type,
system_prompt,
document_column,
retrieval_reranking,
num_rows,
],
outputs=dataframe,
)
# "Push to Hub" chain: validate the Argilla and Hub targets, hide the previous
# success message and pipeline code, generate and push the dataset, then show
# the success message and the regenerated pipeline code. Each `.success` step
# runs only when the previous step did not raise.
btn_push_to_hub.click(
fn=validate_argilla_user_workspace_dataset,
inputs=[repo_name],
outputs=[success_message],
show_progress=True,
).then(
fn=validate_push_to_hub,
inputs=[org_name, repo_name],
outputs=[success_message],
show_progress=True,
).success(
fn=hide_success_message,
outputs=[success_message],
show_progress=True,
).success(
fn=hide_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
).success(
fn=push_dataset,
inputs=[
org_name,
repo_name,
private,
search_in,
file_in,
input_type,
system_prompt,
document_column,
retrieval_reranking,
num_rows,
temperature,
pipeline_code,
],
outputs=[success_message],
show_progress=True,
).success(
fn=show_success_message,
inputs=[org_name, repo_name],
outputs=[success_message],
).success(
fn=generate_pipeline_code,
inputs=[
search_in,
file_in,
input_type,
system_prompt,
document_column,
retrieval_reranking,
num_rows,
],
outputs=[pipeline_code],
).success(
fn=show_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
)
# Per-tab "Clear" buttons reset only that tab's input widget.
clear_dataset_btn_part.click(fn=lambda : "", inputs=[], outputs=[search_in])
clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
clear_prompt_btn_part.click(
fn=lambda : "", inputs=[], outputs=[dataset_description]
)
# Section-2 "Clear": reset the document column and the RAG options, and empty
# the preview table while keeping its columns.
clear_btn_full.click(
fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
inputs=[dataframe],
outputs=[document_column, retrieval_reranking, dataframe],
)
# On page load: update the main UI via swap_visibility and refresh the
# organization dropdown and the default repo name.
app.load(fn=swap_visibility, outputs=main_ui)
app.load(fn=get_org_dropdown, outputs=[org_name])
app.load(fn=get_random_repo_name, outputs=[repo_name])