# src/lmsys_dataset_handler.py
import os
import pandas as pd
import textwrap
import random
from datasets import load_dataset
from IPython.display import display


class LMSYSChat1MHandler:
    """Wrapper for loading, sampling, and reshaping the lmsys/lmsys-chat-1m dataset."""

    def __init__(self, hf_token, streaming=False, verbose=True):
self.hf_token = hf_token
self.streaming = streaming
self.lmsys_dataset = load_dataset(
'lmsys/lmsys-chat-1m',
revision="main",
token=self.hf_token,
streaming=self.streaming
)
self.verbose = verbose
if verbose:
print(self.lmsys_dataset)
self.df_sample = None
self.df_prompts = None
self.unwrapped_turns_df = None
if not self.streaming and verbose:
print('Data is cached at:\n')
for file_info in self.lmsys_dataset['train'].cache_files:
filename = file_info['filename']
file_size = os.path.getsize(filename)
                # Abbreviate the cache path: keep a short prefix, then the
                # trailing 41 characters (hash directory plus file name)
                i = int((len(filename) - 41) / 2)
                print(f"Filename: {filename[:i]}*{filename[-41:]}\nSize: {file_size} bytes")

    def extract_df_sample(self, n_samples=None, conversation_ids=None):
"""
Extracts a sample of conversations or specific conversations based on their conversation IDs.
Parameters:
- n_samples (int): Number of random samples to extract. Ignored if `conversation_ids` is provided.
- conversation_ids (list): List of conversation IDs to extract. If provided, this takes precedence over `n_samples`.
Returns:
- pd.DataFrame: A DataFrame containing the extracted conversations.
"""
if conversation_ids:
            # Filter conversations based on the provided conversation IDs
            # (loads the full train split into memory, so this path requires streaming=False)
            df_sample = self.lmsys_dataset['train'].to_pandas()
df_sample = df_sample[df_sample['conversation_id'].isin(conversation_ids)]
print(f"Retrieved {len(df_sample)} conversations based on specified IDs")
else:
# Randomly sample conversations if no IDs are provided
if not self.streaming:
df_sample = self.lmsys_dataset['train'].to_pandas().sample(n_samples)
print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m")
else:
                # Take the first n_samples rows from the streamed dataset.
                # Note: this is not a uniform random sample; streaming mode
                # simply reads from the head of the stream.
                streamed_samples = []
                for i, row in enumerate(self.lmsys_dataset['train']):
                    streamed_samples.append(row)
                    if i + 1 == n_samples:  # Stop once the desired number of rows is collected
                        break
                # Shuffle the collected rows and convert them to a pandas DataFrame
                random.shuffle(streamed_samples)
df_sample = pd.DataFrame(streamed_samples)
self.df_sample = df_sample
if self.verbose and len(df_sample) > 4:
display(df_sample.head(2))
print('...')
display(df_sample.tail(2))
return df_sample

    def parquet_sampling(self, n_samples):
        """
        Samples conversations from a single, randomly chosen parquet shard of the
        dataset, avoiding a download of the full dataset.
        """
base_url = "https://huggingface.co/datasets/lmsys/lmsys-chat-1m/resolve/main/data/"
data_files = [
"train-00000-of-00006-4feeb3f83346a0e9.parquet",
"train-00001-of-00006-4030672591c2f478.parquet",
"train-00002-of-00006-1779b7cec9462180.parquet",
"train-00003-of-00006-2fa862bfed56af1f.parquet",
"train-00004-of-00006-18f4bdd50c103e71.parquet",
"train-00005-of-00006-fe1acc5d10a9f0e2.parquet"
]
sample_file = random.choice(data_files)
print(f"Sampling from {sample_file}")
data_files = {"train": base_url + sample_file}
parquet_sample = load_dataset("parquet", data_files=data_files, split="train")
df_sample = parquet_sample.to_pandas().sample(n_samples)
print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m/{sample_file}")
self.df_sample = df_sample
if self.verbose and len(df_sample) > 4:
display(df_sample.head(2))
print('...')
display(df_sample.tail(2))
return df_sample

    def add_turns_to_conversations(self):
        """
        Adds a 'turn' key to every message in the 'conversation' column of the
        sample dataframe, modifying self.df_sample in place and returning it.
        unwrap_turns() relies on these 'turn' keys.
        """
self.df_sample['conversation'] = self.df_sample['conversation'].apply(
lambda conv: Conversation(conv).add_turns()
)
df_with_turns = self.df_sample
return df_with_turns

    def unwrap_turns(self):
        """
        Creates a dataframe in which each row is one user-assistant message pair.
        The 'prompt' column holds the user's message and the 'response' column the
        assistant's reply. Each row also gets a 'turn_id' that uniquely identifies
        the turn within its conversation. Requires add_turns_to_conversations()
        to have been called first, so that messages carry 'turn' keys.
        """
paired_data = []
for _, row in self.df_sample.iterrows():
conversation_id = row['conversation_id']
row_data = row.to_dict()
row_data.pop('conversation') # Remove the 'conversation' field as it's being unwrapped
current_prompt = None
turn_id = None
for message in row['conversation']:
if message['role'] == 'user':
current_prompt = message['content']
turn_id = f"{conversation_id}{message['turn']:03}" # Create turn_id
elif message['role'] == 'assistant' and current_prompt is not None:
# Create a new row with the user-assistant pair
                    paired_row = {
                        **row_data,
                        'turn_id': turn_id,
                        'turn_n': message['turn'],
                        'prompt': current_prompt,
                        'response': message['content'],
                    }
paired_data.append(paired_row)
current_prompt = None # Reset after pairing
unwrapped_turns_df = pd.DataFrame(paired_data)
        # The original 'turn' column holds the total turn count per conversation, so rename it to avoid ambiguity
        unwrapped_turns_df.rename(columns={"turn": "conversation_turns"}, inplace=True)
self.unwrapped_turns_df = unwrapped_turns_df
return unwrapped_turns_df

    def extract_prompts(self, filter_language=None, min_char_length=20, max_char_length=500, exclusions=None):
"""
Extracts user prompts from the sample dataframe, optionally filtering by language and limiting the character length.
Parameters:
- filter_language (list of str or None): A list of specific languages to filter prompts by. If None, no language
filter is applied. Examples of valid values include ['English'], ['English', 'Portuguese'], or
['Spanish', 'French', 'German'].
- min_char_length (int): The minimum character length for user prompts to include. Defaults to 20.
- max_char_length (int): The maximum character length for user prompts to include. Defaults to 500.
- exclusions (str or None): Path to a text file containing phrases. Prompts containing any of these phrases
will be excluded from the results. If None, no exclusions are applied.
Returns:
- pd.DataFrame: A DataFrame containing extracted prompts with columns 'prompt' and 'language'.
"""
        df_sample = self.df_sample
        if filter_language:
            df_sample = df_sample[df_sample['language'].isin(filter_language)]
        extracted_data = df_sample.apply(
            lambda row: [
                {'content': entry['content'], 'language': row['language']}
                for entry in row['conversation']
                if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length
            ], axis=1
        ).explode().dropna()
df_prompts = pd.DataFrame(extracted_data.tolist())
df_prompts.rename(columns={'content': 'prompt'}, inplace=True)
orig_length = len(df_prompts)
        if exclusions:
            # Exclude prompts containing phrases that recur frequently in this dataset
            with open(exclusions, 'r') as f:
                exclusion_phrases = [line.strip() for line in f.readlines()]
            df_prompts = df_prompts[~df_prompts['prompt'].apply(lambda x: any(phrase in x for phrase in exclusion_phrases))]
            print(f"Excluded {orig_length - len(df_prompts)} prompts.")
self.df_prompts = df_prompts
        if self.verbose and len(df_prompts) > 4:
display(df_prompts.head(2))
print('...')
display(df_prompts.tail(2))
return df_prompts

    def extract_prompt_sample(self):
        """
        Returns a single randomly chosen prompt from the extracted prompts,
        printing it word-wrapped when verbose.
        """
prompt_sample = self.df_prompts.sample(1)['prompt'].values[0]
if self.verbose:
wrapped_message = textwrap.fill(prompt_sample, width=120)
print(wrapped_message)
return prompt_sample

    def search_conversations(self, search_term):
"""
Searches the dataset for a given string and returns a DataFrame with matching records.
Parameters:
- search_term (str): The string to search for in the dataset.
Returns:
- pd.DataFrame: A DataFrame containing conversations where the search term is found.
"""
if self.streaming:
raise ValueError("Search is not supported in streaming mode.")
df = self.lmsys_dataset['train'].to_pandas()
# Filter rows where the search term appears in the 'conversation' column
matching_records = df[df['conversation'].apply(
lambda conv: any(search_term.lower() in message['content'].lower() for message in conv)
)]
if self.verbose:
print(f"Found {len(matching_records)} matching conversations for search term: '{search_term}'")
return matching_records

    def print_language_counts(self, df):
        """
        Prints the number of records per language in the given dataframe.
        """
        language_counts = df['language'].value_counts()
        print("Language Record Counts:")
        # rename_axis/reset_index(name=...) works across pandas versions, unlike
        # renaming the 'index' column, which newer pandas no longer produces here
        print(language_counts.rename_axis('Language').reset_index(name='Count'))


class Conversation:
def __init__(self, conversation_data):
"""
Initializes the Conversation object with the conversation data.
Parameters:
- conversation_data (list): A list of dictionaries representing a conversation.
"""
self.conversation_data = conversation_data

    def add_turns(self):
"""
Adds a 'turn' key to each dictionary in the conversation,
identifying the turn (pair of user and assistant messages).
Returns:
- list: The updated conversation with 'turn' keys added.
"""
turn_counter = 0
for message in self.conversation_data:
if message['role'] == 'user':
turn_counter += 1
message['turn'] = turn_counter
return self.conversation_data

    def pretty_print(self, user_prefix, assistant_prefix, width=80):
"""
Prints the conversation with specified prefixes and wrapped text.
Parameters:
- user_prefix (str): Prefix to prepend to user messages.
- assistant_prefix (str): Prefix to prepend to assistant messages.
- width (int): Maximum characters per line for wrapping.
"""
wrapper = textwrap.TextWrapper(width=width)
for message in self.conversation_data:
if message['role'] == 'user':
prefix = user_prefix
elif message['role'] == 'assistant':
prefix = assistant_prefix
else:
continue # Ignore roles other than 'user' and 'assistant'
# Split on existing newlines, wrap each line, and join back with newlines
wrapped_content = "\n".join(
wrapper.fill(line) for line in message['content'].splitlines()
)
print(f"{prefix} {wrapped_content}\n")