sdiazlor's picture
feat: Address edge cases and improve textcat UI
6a8a817
raw
history blame
17 kB
import re
from typing import List, Union
import argilla as rg
import gradio as gr
import pandas as pd
from datasets import Dataset
from huggingface_hub import HfApi
from src.distilabel_dataset_generator.apps.base import (
get_argilla_client,
get_main_ui,
get_pipeline_code_ui,
hide_success_message,
push_pipeline_code_to_hub,
show_success_message_argilla,
show_success_message_hub,
validate_argilla_user_workspace_dataset,
)
from src.distilabel_dataset_generator.apps.base import (
push_dataset_to_hub as push_to_hub_base,
)
from src.distilabel_dataset_generator.pipelines.base import (
DEFAULT_BATCH_SIZE,
)
from src.distilabel_dataset_generator.pipelines.embeddings import (
get_embeddings,
get_sentence_embedding_dimensions,
)
from src.distilabel_dataset_generator.pipelines.textcat import (
DEFAULT_DATASET_DESCRIPTIONS,
DEFAULT_DATASETS,
DEFAULT_SYSTEM_PROMPTS,
PROMPT_CREATION_PROMPT,
generate_pipeline_code,
get_labeller_generator,
get_prompt_generator,
get_textcat_generator,
)
# Task identifier handed to the shared UI builder and to the base Hub push helper.
TASK = "text_classification"
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    private: bool = True,
    org_name: str = None,
    repo_name: str = None,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
    labels: List[str] = None,
    num_labels: int = 1,
):
    """Push the text classification dataset to the Hub through the base helper.

    All arguments are forwarded positionally to the base push function,
    together with ``task=TASK``.

    Returns:
        A deep copy of ``dataframe`` taken before the push is attempted.

    Raises:
        gr.Error: If the base helper raises any exception.
    """
    # Copy up front: this snapshot, not the frame handed to the helper, is returned.
    snapshot = dataframe.copy(deep=True)
    try:
        push_to_hub_base(
            dataframe,
            private,
            org_name,
            repo_name,
            oauth_token,
            progress,
            labels,
            num_labels,
            task=TASK,
        )
    except Exception as err:
        raise gr.Error(f"Error pushing dataset to the Hub: {err}")
    else:
        return snapshot
def _build_argilla_settings(labels: List[str], num_labels: int):
    """Build the Argilla settings: a text field, one label question, length metadata and a vector."""
    if num_labels == 1:
        question = rg.LabelQuestion(
            name="label",
            title="Label",
            description="The label of the text",
            labels=labels,
        )
    else:
        question = rg.MultiLabelQuestion(
            name="labels",
            title="Labels",
            description="The labels of the text",
            labels=labels,
        )
    return rg.Settings(
        fields=[
            rg.TextField(
                name="text",
                description="The text classification data",
                title="Text",
            ),
        ],
        questions=[question],
        metadata=[
            rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
        ],
        vectors=[
            rg.VectorField(
                name="text_embeddings",
                dimensions=get_sentence_embedding_dimensions(),
            )
        ],
        guidelines="Please review the text and provide or correct the label where needed.",
    )


def _build_suggestions(sample, labels: List[str], num_labels: int) -> list:
    """Return the suggestions for one record, or [] when its generated label(s) are unusable."""
    if num_labels == 1:
        value = sample["label"]
        if value in labels:
            return [rg.Suggestion(question_name="label", value=value)]
    elif num_labels > 1:
        value = sample["labels"]
        # A missing (None) or empty label list cannot be suggested; skip it
        # instead of letting the iteration below abort the whole push.
        if (
            value is not None
            and len(value) > 0
            and all(label in labels for label in value)
        ):
            return [rg.Suggestion(question_name="labels", value=value)]
    return []


def push_dataset_to_argilla(
    dataframe: pd.DataFrame,
    dataset_name: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
    num_labels: int = 1,
    labels: List[str] = None,
) -> pd.DataFrame:
    """Log the dataset to Argilla under the signed-in user's workspace.

    The dataset is created with freshly built settings when the lookup by
    name and workspace returns ``None``; otherwise records are logged to the
    dataset that was found. ``text_length`` and ``text_embeddings`` columns
    are added to ``dataframe`` in place before the records are built. A record
    receives a suggestion only when its generated label(s) all appear in
    ``labels`` (compared after lower-casing and stripping ``labels``).

    Args:
        dataframe: Frame with a ``text`` column and a ``label`` (single-label)
            or ``labels`` (multi-label) column.
        dataset_name: Name of the Argilla dataset.
        oauth_token: Token used to resolve the Hugging Face user name, which
            is used as the Argilla workspace.
        progress: Progress callback.
        num_labels: ``1`` selects a single-label question, anything else a
            multi-label question.
        labels: Allowed label names.

    Returns:
        A deep copy of ``dataframe`` taken before the extra columns are added.

    Raises:
        gr.Error: Wrapping any exception raised while pushing.
    """
    original_dataframe = dataframe.copy(deep=True)
    try:
        progress(0.1, desc="Setting up user and workspace")
        client = get_argilla_client()
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        labels = [label.lower().strip() for label in labels]
        settings = _build_argilla_settings(labels, num_labels)

        dataframe["text_length"] = dataframe["text"].apply(len)
        dataframe["text_embeddings"] = get_embeddings(dataframe["text"])

        progress(0.5, desc="Creating dataset")
        rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=dataset_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        records = [
            rg.Record(
                fields={"text": sample["text"]},
                metadata={"text_length": sample["text_length"]},
                vectors={"text_embeddings": sample["text_embeddings"]},
                suggestions=_build_suggestions(sample, labels, num_labels),
            )
            for sample in hf_dataset
        ]
        rg_dataset.records.log(records=records)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return original_dataframe
def generate_system_prompt(dataset_description, progress=gr.Progress()):
    """Return a system prompt for ``dataset_description``.

    A description found in ``DEFAULT_DATASET_DESCRIPTIONS`` is answered from
    ``DEFAULT_SYSTEM_PROMPTS`` at the same index (when that index exists).
    Any other description is sent to the prompt generator and the
    ``generation`` of its first result is returned.
    """
    progress(0.0, desc="Generating text classification task")

    # Shortcut: default descriptions map to their stored prompt.
    if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
        position = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
        if position < len(DEFAULT_SYSTEM_PROMPTS):
            return DEFAULT_SYSTEM_PROMPTS[position]

    progress(0.3, desc="Initializing text generation")
    prompt_generator = get_prompt_generator()

    progress(0.7, desc="Generating text classification task")
    request = {
        "system_prompt": PROMPT_CREATION_PROMPT,
        "instruction": dataset_description,
    }
    first_batch = next(prompt_generator.process([request]))
    generated_prompt = first_batch[0]["generation"]

    progress(1.0, desc="Text classification task generated")
    return generated_prompt
def generate_dataset(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str] = None,
    num_labels: int = 1,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a text classification dataset as a DataFrame.

    Texts are produced in batches of at most ``DEFAULT_BATCH_SIZE`` by the
    text generator. Unless ``is_sample`` is set, the generated rows are then
    passed through the labeller in batches of the same size.

    Args:
        system_prompt: Used as the ``task`` of every generation input and
            forwarded to the labeller factory.
        difficulty: Forwarded to the text generator factory.
        clarity: Forwarded to the text generator factory.
        labels: Forwarded to the labeller factory.
        num_labels: Forwarded to the labeller factory; when it equals 1 the
            ``labels`` column of the result is renamed to ``label``.
        num_rows: Number of texts to generate.
        is_sample: Forwarded to both factories; when true the labelling
            phase is skipped and the generator output is used directly.
        progress: Progress callback.

    Returns:
        A DataFrame holding ``text`` plus the label column, for every key
        that is present in the underlying results.
    """
    # range() below rejects floats, and a numeric input field may deliver one.
    num_rows = int(num_rows)

    progress(0.0, desc="(1/2) Generating text classification data")
    textcat_generator = get_textcat_generator(
        difficulty=difficulty, clarity=clarity, is_sample=is_sample
    )
    labeller_generator = get_labeller_generator(
        system_prompt=system_prompt,
        labels=labels,
        num_labels=num_labels,
        is_sample=is_sample,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create text classification data
    n_processed = 0
    textcat_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating text classification data",
        )
        # Size only this batch; `batch_size` itself must stay untouched so the
        # labelling phase below still runs with full-size batches.
        current_batch_size = min(batch_size, num_rows - n_processed)
        inputs = [{"task": system_prompt} for _ in range(current_batch_size)]
        batch = list(textcat_generator.process(inputs=inputs))
        textcat_results.extend(batch[0])
        n_processed += current_batch_size
    for result in textcat_results:
        result["text"] = result["input_text"]

    # label text classification data
    progress(0.5, desc="(1/2) Generating text classification data")
    if not is_sample:
        n_processed = 0
        labeller_results = []
        while n_processed < num_rows:
            progress(
                0.5 + 0.5 * n_processed / num_rows,
                total=total_steps,
                desc="(1/2) Labeling text classification data",
            )
            batch = textcat_results[n_processed : n_processed + batch_size]
            # Named so it does not shadow the `labels` parameter.
            labelled_batch = list(labeller_generator.process(inputs=batch))
            labeller_results.extend(labelled_batch[0])
            n_processed += batch_size
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create final dataset
    source_results = textcat_results if is_sample else labeller_results
    wanted_keys = ["text", "label" if is_sample else "labels"]
    distiset_results = [
        {key: result[key] for key in wanted_keys if key in result}
        for result in source_results
    ]
    dataframe = pd.DataFrame(distiset_results)
    if num_labels == 1:
        dataframe = dataframe.rename(columns={"labels": "label"})
    progress(1.0, desc="Dataset generation completed")
    return dataframe
def update_suggested_labels(system_prompt):
    """Suggest labels by collecting single-quoted words from the system prompt.

    Returns:
        An update that sets both the choices and the value of the labels
        dropdown to the labels found (duplicates removed, first-seen order).
        When none are found, a warning is shown and a no-op update is
        returned so labels already entered by the user are kept.
    """
    found = re.findall(r"'(\b[\w-]+\b)'", system_prompt or "")
    # dict.fromkeys drops repeated labels while preserving order.
    new_labels = list(dict.fromkeys(found))
    if not new_labels:
        gr.Warning(
            "No labels found in the system prompt. Please add labels manually."
        )
        # Returning the warning call's result (None) would clear the dropdown.
        return gr.update()
    return gr.update(choices=new_labels, value=new_labels)
def validate_input_labels(labels):
    """Return ``labels`` unchanged when at least two are selected.

    Raises:
        gr.Error: When ``labels`` is empty/``None`` or holds fewer than two
            entries; the message reports how many were selected.
    """
    selected_count = len(labels) if labels else 0
    if selected_count >= 2:
        return labels
    raise gr.Error(
        f"Please select at least 2 labels to classify your text. You selected {selected_count}."
    )
# Build the shared layout through the base UI builder and unpack every component it
# returns. The names bound here are the ones wired to callbacks in the trigger
# section further down this file.
(
app,
main_ui,
custom_input_ui,
dataset_description,
examples,
btn_generate_system_prompt,
system_prompt,
sample_dataset,
btn_generate_sample_dataset,
dataset_name,
add_to_existing_dataset,
btn_generate_full_dataset_argilla,
btn_generate_and_push_to_argilla,
btn_push_to_argilla,
org_name,
repo_name,
private,
btn_generate_full_dataset,
btn_generate_and_push_to_hub,
btn_push_to_hub,
final_dataset,
success_message,
) = get_main_ui(
default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
default_datasets=DEFAULT_DATASETS,
fn_generate_system_prompt=generate_system_prompt,
fn_generate_dataset=generate_dataset,
task=TASK,
)
# Task-specific inputs for text classification, declared inside the containers
# returned by get_main_ui.
with app:
with main_ui:
with custom_input_ui:
# Difficulty selector; its value is passed to generate_dataset and
# generate_pipeline_code as `difficulty`.
difficulty = gr.Dropdown(
choices=[
("High School", "high school"),
("College", "college"),
("PhD", "PhD"),
("Mixed", "mixed"),
],
value="mixed",
label="Difficulty",
info="The difficulty of the text to be generated.",
)
# Clarity selector; its value is passed to generate_dataset and
# generate_pipeline_code as `clarity`.
clarity = gr.Dropdown(
choices=[
("Clear", "CLEAR"),
(
"Understandable",
"understandable with some effort",
),
("Ambiguous", "ambiguous"),
("Mixed", "mixed"),
],
value="mixed",
label="Clarity",
info="The clarity of the text to be generated.",
)
with gr.Column():
# Multi-select that starts with no choices and accepts custom values.
labels = gr.Dropdown(
choices=[],
allow_custom_value=True,
interactive=True,
label="Labels",
multiselect=True,
info="Add the labels to classify the text.",
)
with gr.Blocks():
# Clicking this runs update_suggested_labels (see the trigger section).
btn_suggested_labels = gr.Button(
value="Add suggested labels",
size="sm",
)
# Passed to the callbacks as `num_labels`; 1 means single-label downstream.
num_labels = gr.Number(
label="Number of labels",
value=1,
minimum=1,
maximum=10,
info="The number of labels to classify the text.",
)
# Passed to generate_dataset as `num_rows`.
num_rows = gr.Number(
label="Number of rows",
value=10,
minimum=1,
maximum=500,
info="More rows will take longer to generate.",
)
# Initial pipeline code preview, rendered from the components' current
# `.value` attributes at build time.
pipeline_code = get_pipeline_code_ui(
generate_pipeline_code(
system_prompt.value,
difficulty=difficulty.value,
clarity=clarity.value,
labels=labels.value,
num_labels=num_labels.value,
num_rows=num_rows.value,
)
)
# define app triggers
# Fill the labels dropdown from the labels found in the current system prompt.
btn_suggested_labels.click(
fn=update_suggested_labels,
inputs=[system_prompt],
outputs=labels,
)
# Both "generate full dataset" buttons share one chain: hide the previous success
# message, validate the selected labels, and only when validation succeeds
# generate the dataset into final_dataset.
gr.on(
triggers=[
btn_generate_full_dataset.click,
btn_generate_full_dataset_argilla.click,
],
fn=hide_success_message,
outputs=[success_message],
).then(
fn=validate_input_labels,
inputs=[labels],
outputs=[labels],
).success(
fn=generate_dataset,
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
outputs=[final_dataset],
show_progress=True,
)
# Generate-and-push to Argilla: every step runs only if the previous one succeeded.
# The label check mirrors the generate-only buttons, so an invalid label selection
# is rejected before any generation work starts.
btn_generate_and_push_to_argilla.click(
    fn=validate_argilla_user_workspace_dataset,
    inputs=[dataset_name, final_dataset, add_to_existing_dataset],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=validate_input_labels,
    inputs=[labels],
    outputs=[labels],
).success(
    fn=hide_success_message,
    outputs=[success_message],
).success(
    fn=generate_dataset,
    inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=push_dataset_to_argilla,
    inputs=[final_dataset, dataset_name, num_labels, labels],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=show_success_message_argilla,
    inputs=[],
    outputs=[success_message],
)
# Generate-and-push to the Hub. After the success message is hidden, each step is
# chained with `.success` so a failed validation, generation or push stops the
# chain instead of pushing a stale dataset or reporting success.
btn_generate_and_push_to_hub.click(
    fn=hide_success_message,
    outputs=[success_message],
).then(
    fn=validate_input_labels,
    inputs=[labels],
    outputs=[labels],
).success(
    fn=generate_dataset,
    inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=push_dataset_to_hub,
    inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=push_pipeline_code_to_hub,
    inputs=[pipeline_code, org_name, repo_name],
    outputs=[],
    show_progress=True,
).success(
    fn=show_success_message_hub,
    inputs=[org_name, repo_name],
    outputs=[success_message],
)
# Push the already generated dataset to the Hub. Steps are chained with `.success`
# (as in the Argilla push chain) so the pipeline code is only pushed, and the
# success message only shown, when the dataset push itself succeeded.
btn_push_to_hub.click(
    fn=hide_success_message,
    outputs=[success_message],
).success(
    fn=push_dataset_to_hub,
    inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
    outputs=[final_dataset],
    show_progress=True,
).success(
    fn=push_pipeline_code_to_hub,
    inputs=[pipeline_code, org_name, repo_name],
    outputs=[],
    show_progress=True,
).success(
    fn=show_success_message_hub,
    inputs=[org_name, repo_name],
    outputs=[success_message],
)
# Push the already generated dataset to Argilla: hide the previous success message,
# validate the Argilla target, push, then show the Argilla success message. Each
# step is chained with `.success`, so it runs only after the previous one succeeded.
btn_push_to_argilla.click(
fn=hide_success_message,
outputs=[success_message],
).success(
fn=validate_argilla_user_workspace_dataset,
inputs=[dataset_name, final_dataset, add_to_existing_dataset],
outputs=[final_dataset],
show_progress=True,
).success(
fn=push_dataset_to_argilla,
inputs=[final_dataset, dataset_name, num_labels, labels],
outputs=[final_dataset],
show_progress=True,
).success(
fn=show_success_message_argilla,
inputs=[],
outputs=[success_message],
)
# Re-render the pipeline code preview whenever one of these inputs changes.
# Handlers are registered in the same order as before, one per component.
for _watched_component in (system_prompt, difficulty, clarity, labels, num_labels):
    _watched_component.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    )