reddgr committed
Commit 87712ac · 1 Parent(s): 52906e2

first commit
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ streamlit
+ streamlit-aggrid
+ duckdb
+ pandas
+ requests
+ torch
+ transformers
+ datasets
+ python-dotenv
+ langdetect
+ scikit-learn
+ matplotlib
+ tqdm
+ numpy
src/__pycache__/env_options.cpython-311.pyc ADDED
Binary file (4.08 kB)
src/__pycache__/lmsys_dataset_handler.cpython-311.pyc ADDED
Binary file (18.4 kB)
src/__pycache__/lmsys_dataset_wrapper.cpython-311.pyc ADDED
Binary file (20.4 kB)
src/__pycache__/text_classification_functions.cpython-311.pyc ADDED
Binary file (26.4 kB)
src/env_options.py ADDED
@@ -0,0 +1,57 @@
+
+ import sys
+ import os
+ import torch
+ import transformers
+
+ def check_env(colab:bool=False, use_dotenv:bool=True, dotenv_path:str=None, colab_secrets:dict=None) -> tuple:
+     # Checking versions and GPU availability:
+     print(f"Python version: {sys.version}")
+     print(f"PyTorch version: {torch.__version__}")
+     print(f"Transformers version: {transformers.__version__}")
+     if torch.cuda.is_available():
+         print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+         print(f"CUDA Version: {torch.version.cuda}")
+         print(f"FlashAttention available: {torch.backends.cuda.flash_sdp_enabled()}")
+     else:
+         print("No CUDA device available")
+
+     if use_dotenv:
+         print("Retrieved token(s) from .env file")
+         from dotenv import load_dotenv
+         load_dotenv(dotenv_path)  # path to your dotenv file
+         hf_token = os.getenv("HF_TOKEN")
+         hf_token_write = os.getenv("HF_TOKEN_WRITE")  # Only used for updating the Reddgr dataset (privileges needed)
+         openai_api_key = os.getenv("OPENAI_API_KEY")
+     elif colab:
+         hf_token = colab_secrets.get('HF_TOKEN')
+         hf_token_write = colab_secrets.get('HF_TOKEN_WRITE')
+         openai_api_key = colab_secrets.get("OPENAI_API_KEY")
+     else:
+         print("Retrieved HuggingFace token(s) from environment variables")
+         hf_token = os.environ.get("HF_TOKEN")
+         hf_token_write = os.environ.get("HF_TOKEN_WRITE")
+         openai_api_key = os.getenv("OPENAI_API_KEY")
+
+     def mask_token(token, unmasked_chars=4):
+         return token[:unmasked_chars] + '*' * (len(token) - unmasked_chars*2) + token[-unmasked_chars:]
+
+     if hf_token is None:
+         print("HF_TOKEN not found in the provided .env file" if use_dotenv else "HF_TOKEN not found in the environment variables")
+     if hf_token_write is None:
+         print("HF_TOKEN_WRITE not found in the provided .env file" if use_dotenv else "HF_TOKEN_WRITE not found in the environment variables")
+     if openai_api_key is None:
+         print("OPENAI_API_KEY not found in the provided .env file" if use_dotenv else "OPENAI_API_KEY not found in the environment variables")
+
+     masked_hf_token = mask_token(hf_token) if hf_token else None
+     masked_hf_token_write = mask_token(hf_token_write) if hf_token_write else None
+     masked_openai_api_key = mask_token(openai_api_key) if openai_api_key else None
+
+     if masked_hf_token:
+         print(f"Using HuggingFace token: {masked_hf_token}")
+     if masked_hf_token_write:
+         print(f"Using HuggingFace write token: {masked_hf_token_write}")
+     if masked_openai_api_key:
+         print(f"Using OpenAI token: {masked_openai_api_key}")
+
+     return hf_token, hf_token_write, openai_api_key
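
For reference, a minimal usage sketch for check_env (not part of the commit): the import assumes the repository root is on sys.path, and the dotenv_path value is a placeholder.

# Hypothetical usage of src/env_options.py
from src.env_options import check_env

hf_token, hf_token_write, openai_api_key = check_env(
    use_dotenv=True,
    dotenv_path=".env",  # placeholder path to a local .env file with HF_TOKEN / OPENAI_API_KEY
)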
src/lmsys_dataset_handler.py ADDED
@@ -0,0 +1,275 @@
+ import os
+ import pandas as pd
+ import textwrap
+ import random
+ from datasets import load_dataset
+ from IPython.display import display
+
+ class LMSYSChat1MHandler:
+     def __init__(self, hf_token, streaming=False, verbose=True):
+         self.hf_token = hf_token
+         self.streaming = streaming
+         self.lmsys_dataset = load_dataset(
+             'lmsys/lmsys-chat-1m',
+             revision="main",
+             token=self.hf_token,
+             streaming=self.streaming
+         )
+         self.verbose = verbose
+         if verbose:
+             print(self.lmsys_dataset)
+         self.df_sample = None
+         self.df_prompts = None
+         self.unwrapped_turns_df = None
+
+         if not self.streaming and verbose:
+             print('Data is cached at:\n')
+             for file_info in self.lmsys_dataset['train'].cache_files:
+                 filename = file_info['filename']
+                 file_size = os.path.getsize(filename)
+                 i = int((len(filename) - 41) / 2)
+                 print(f"Filename: {filename[:i]}*{filename[-41:]}\nSize: {file_size} bytes")
+
+     def extract_df_sample(self, n_samples=None, conversation_ids=None):
+         """
+         Extracts a sample of conversations or specific conversations based on their conversation IDs.
+
+         Parameters:
+         - n_samples (int): Number of random samples to extract. Ignored if `conversation_ids` is provided.
+         - conversation_ids (list): List of conversation IDs to extract. If provided, this takes precedence over `n_samples`.
+
+         Returns:
+         - pd.DataFrame: A DataFrame containing the extracted conversations.
+         """
+         if conversation_ids:
+             # Filter conversations based on the provided conversation IDs
+             df_sample = self.lmsys_dataset['train'].to_pandas()
+             df_sample = df_sample[df_sample['conversation_id'].isin(conversation_ids)]
+             print(f"Retrieved {len(df_sample)} conversations based on specified IDs")
+         else:
+             # Randomly sample conversations if no IDs are provided
+             if not self.streaming:
+                 df_sample = self.lmsys_dataset['train'].to_pandas().sample(n_samples)
+                 print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m")
+             else:
+                 # Take a sample from the streamed dataset
+                 streamed_samples = []
+                 for i, row in enumerate(self.lmsys_dataset['train']):
+                     streamed_samples.append(row)
+                     if i + 1 == n_samples:  # Collect only the desired number of samples
+                         break
+                 # Shuffle and convert the collected samples to a Pandas DataFrame
+                 random.shuffle(streamed_samples)
+                 df_sample = pd.DataFrame(streamed_samples)
+
+         self.df_sample = df_sample
+         if self.verbose and len(df_sample) > 4:
+             display(df_sample.head(2))
+             print('...')
+             display(df_sample.tail(2))
+         return df_sample
+
+     def parquet_sampling(self, n_samples):
+         base_url = "https://huggingface.co/datasets/lmsys/lmsys-chat-1m/resolve/main/data/"
+         data_files = [
+             "train-00000-of-00006-4feeb3f83346a0e9.parquet",
+             "train-00001-of-00006-4030672591c2f478.parquet",
+             "train-00002-of-00006-1779b7cec9462180.parquet",
+             "train-00003-of-00006-2fa862bfed56af1f.parquet",
+             "train-00004-of-00006-18f4bdd50c103e71.parquet",
+             "train-00005-of-00006-fe1acc5d10a9f0e2.parquet"
+         ]
+         sample_file = random.choice(data_files)
+         print(f"Sampling from {sample_file}")
+         data_files = {"train": base_url + sample_file}
+         parquet_sample = load_dataset("parquet", data_files=data_files, split="train")
+         df_sample = parquet_sample.to_pandas().sample(n_samples)
+         print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m/{sample_file}")
+         self.df_sample = df_sample
+         if self.verbose and len(df_sample) > 4:
+             display(df_sample.head(2))
+             print('...')
+             display(df_sample.tail(2))
+         return df_sample
+
+     def add_turns_to_conversations(self):
+         """
+         Adds 'turn' keys to each conversation in the 'conversation' column of the dataframe.
+         """
+         self.df_sample['conversation'] = self.df_sample['conversation'].apply(
+             lambda conv: Conversation(conv).add_turns()
+         )
+         df_with_turns = self.df_sample
+         return df_with_turns
+
+     def unwrap_turns(self):
+         """
+         Creates a dataframe where each row corresponds to a pair of user-assistant messages in a conversation and turn.
+         The 'prompt' column contains the user's message, and the 'response' column contains the assistant's message.
+         Each row includes a 'turn_id' column, which numbers the turns uniquely per conversation.
+         """
+         paired_data = []
+         for _, row in self.df_sample.iterrows():
+             conversation_id = row['conversation_id']
+             row_data = row.to_dict()
+             row_data.pop('conversation')  # Remove the 'conversation' field as it's being unwrapped
+
+             current_prompt = None
+             turn_id = None
+
+             for message in row['conversation']:
+                 if message['role'] == 'user':
+                     current_prompt = message['content']
+                     turn_id = f"{conversation_id}{message['turn']:03}"  # Create turn_id
+                 elif message['role'] == 'assistant' and current_prompt is not None:
+                     # Create a new row with the user-assistant pair
+                     paired_row = {
+                         **row_data,
+                         'turn_n': message['turn'],
+                         'turn_id': turn_id,  # Unique identifier of the conversation turn
+                         'prompt': current_prompt,
+                         'response': message['content'],
+                     }
+                     paired_data.append(paired_row)
+                     current_prompt = None  # Reset after pairing
+
+         unwrapped_turns_df = pd.DataFrame(paired_data)
+         unwrapped_turns_df.rename(columns={"turn": "conversation_turns"}, inplace=True)  # The naming in the original dataset is ambiguous
+         self.unwrapped_turns_df = unwrapped_turns_df
+         return unwrapped_turns_df
+
+     def extract_prompts(self, filter_language=None, min_char_length=20, max_char_length=500, exclusions=None):
+         """
+         Extracts user prompts from the sample dataframe, optionally filtering by language and limiting the character length.
+
+         Parameters:
+         - filter_language (list of str or None): A list of specific languages to filter prompts by. If None, no language
+           filter is applied. Examples of valid values include ['English'], ['English', 'Portuguese'], or
+           ['Spanish', 'French', 'German'].
+         - min_char_length (int): The minimum character length for user prompts to include. Defaults to 20.
+         - max_char_length (int): The maximum character length for user prompts to include. Defaults to 500.
+         - exclusions (str or None): Path to a text file containing phrases. Prompts containing any of these phrases
+           will be excluded from the results. If None, no exclusions are applied.
+
+         Returns:
+         - pd.DataFrame: A DataFrame containing extracted prompts with columns 'prompt' and 'language'.
+         """
+         df_sample = self.df_sample
+         if filter_language:
+             extracted_data = df_sample[df_sample['language'].isin(filter_language)].apply(
+                 lambda row: [
+                     {'content': entry['content'], 'language': row['language']}
+                     for entry in row['conversation']
+                     if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length
+                 ], axis=1
+             ).explode().dropna()
+         else:
+             extracted_data = df_sample.apply(
+                 lambda row: [
+                     {'content': entry['content'], 'language': row['language']}
+                     for entry in row['conversation']
+                     if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length
+                 ], axis=1
+             ).explode().dropna()
+
+         df_prompts = pd.DataFrame(extracted_data.tolist())
+         df_prompts.rename(columns={'content': 'prompt'}, inplace=True)
+
+         orig_length = len(df_prompts)
+         if exclusions:
+             # Excluding prompts with phrases that are repeated often in this dataset
+             with open(exclusions, 'r') as f:
+                 exclusions = [line.strip() for line in f.readlines()]
+             df_prompts = df_prompts[~df_prompts['prompt'].apply(lambda x: any(exclusion in x for exclusion in exclusions))]
+             print(f"Excluded {orig_length - len(df_prompts)} prompts.")
+
+         self.df_prompts = df_prompts
+         if self.verbose and len(df_sample) > 4:
+             display(df_prompts.head(2))
+             print('...')
+             display(df_prompts.tail(2))
+         return df_prompts
+
+     def extract_prompt_sample(self):
+         prompt_sample = self.df_prompts.sample(1)['prompt'].values[0]
+         if self.verbose:
+             wrapped_message = textwrap.fill(prompt_sample, width=120)
+             print(wrapped_message)
+         return prompt_sample
+
+     def search_conversations(self, search_term):
+         """
+         Searches the dataset for a given string and returns a DataFrame with matching records.
+
+         Parameters:
+         - search_term (str): The string to search for in the dataset.
+
+         Returns:
+         - pd.DataFrame: A DataFrame containing conversations where the search term is found.
+         """
+         if self.streaming:
+             raise ValueError("Search is not supported in streaming mode.")
+         df = self.lmsys_dataset['train'].to_pandas()
+         # Filter rows where the search term appears in the 'conversation' column
+         matching_records = df[df['conversation'].apply(
+             lambda conv: any(search_term.lower() in message['content'].lower() for message in conv)
+         )]
+         if self.verbose:
+             print(f"Found {len(matching_records)} matching conversations for search term: '{search_term}'")
+         return matching_records
+
+     def print_language_counts(self, df):
+         language_counts = df['language'].value_counts()
+         print("Language Record Counts:")
+         print(language_counts.to_frame('Count').reset_index().rename(columns={'index': 'Language'}))
+
+
+ class Conversation:
+     def __init__(self, conversation_data):
+         """
+         Initializes the Conversation object with the conversation data.
+
+         Parameters:
+         - conversation_data (list): A list of dictionaries representing a conversation.
+         """
+         self.conversation_data = conversation_data
+
+     def add_turns(self):
+         """
+         Adds a 'turn' key to each dictionary in the conversation,
+         identifying the turn (pair of user and assistant messages).
+
+         Returns:
+         - list: The updated conversation with 'turn' keys added.
+         """
+         turn_counter = 0
+         for message in self.conversation_data:
+             if message['role'] == 'user':
+                 turn_counter += 1
+             message['turn'] = turn_counter
+         return self.conversation_data
+
+     def pretty_print(self, user_prefix, assistant_prefix, width=80):
+         """
+         Prints the conversation with specified prefixes and wrapped text.
+
+         Parameters:
+         - user_prefix (str): Prefix to prepend to user messages.
+         - assistant_prefix (str): Prefix to prepend to assistant messages.
+         - width (int): Maximum characters per line for wrapping.
+         """
+         wrapper = textwrap.TextWrapper(width=width)
+
+         for message in self.conversation_data:
+             if message['role'] == 'user':
+                 prefix = user_prefix
+             elif message['role'] == 'assistant':
+                 prefix = assistant_prefix
+             else:
+                 continue  # Ignore roles other than 'user' and 'assistant'
+
+             # Split on existing newlines, wrap each line, and join back with newlines
+             wrapped_content = "\n".join(
+                 wrapper.fill(line) for line in message['content'].splitlines()
+             )
+             print(f"{prefix} {wrapped_content}\n")
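
For reference, a short usage sketch for LMSYSChat1MHandler (not part of the commit), assuming hf_token is a valid HuggingFace token with access to lmsys/lmsys-chat-1m; the sample size is arbitrary.

# Hypothetical usage of src/lmsys_dataset_handler.py
from src.lmsys_dataset_handler import LMSYSChat1MHandler

handler = LMSYSChat1MHandler(hf_token, streaming=True, verbose=False)
df_sample = handler.extract_df_sample(n_samples=100)               # random conversations
handler.add_turns_to_conversations()                               # tag each message with its turn number
turns_df = handler.unwrap_turns()                                  # one row per prompt/response pair
prompts_df = handler.extract_prompts(filter_language=['English'])  # user prompts only
print(handler.extract_prompt_sample())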
src/lmsys_dataset_wrapper.py ADDED
@@ -0,0 +1,328 @@
+ import os
+ import pandas as pd
+ import textwrap
+ import random
+ import duckdb
+ import requests
+ import json
+ import tempfile
+ from datasets import load_dataset
+ from IPython.display import display
+ from collections import defaultdict
+
+ class DatasetWrapper:
+     def __init__(self, hf_token, dataset_name="lmsys/lmsys-chat-1m", verbose=True,
+                  conversations_index="json/conversations_index.json", cache_size=50, request_timeout=20):
+         self.hf_token = hf_token
+         self.dataset_name = dataset_name
+         self.headers = {"Authorization": f"Bearer {self.hf_token}"}
+         self.timeout = request_timeout
+         self.cache_size = cache_size
+         self.verbose = verbose
+         parquet_list_url = f"https://datasets-server.huggingface.co/parquet?dataset={self.dataset_name}"
+         response = self._safe_get(parquet_list_url)
+         # Extract URLs from the response JSON
+         if response is not None:
+             self.parquet_urls = [file['url'] for file in response.json()['parquet_files']]
+             if self.verbose:
+                 print("\nParquet URLs:")
+                 for url in self.parquet_urls:
+                     print(url)
+                     head_response = self._safe_head(url)
+                     file_size = int(head_response.headers['Content-Length']) if head_response is not None else 'unknown'
+                     print(f"{url.split('/')[-1]}: {file_size} bytes")
+
+         # Loading the index
+         try:
+             with open(conversations_index, "r", encoding="utf-8") as f:
+                 self.conversations_index = json.load(f)
+         except (FileNotFoundError, json.JSONDecodeError):
+             print(f"Conversations index file not found or invalid. Creating a new one at {conversations_index}.")
+             # Ensure directory exists
+             os.makedirs(os.path.dirname(conversations_index), exist_ok=True)
+             self.create_conversations_index(output_index_file=conversations_index)
+             with open(conversations_index, "r", encoding="utf-8") as f:
+                 self.conversations_index = json.load(f)
+
+         # Initialize active conversation and DataFrame
+         # Read from "pkl/cached_chats.pkl" if available:
+         try:
+             self.active_df = pd.read_pickle("pkl/cached_chats.pkl")
+             print(f"Loaded {len(self.active_df)} cached chats")
+             self.active_df = self.active_df.sample(self.cache_size).reset_index(drop=True)
+         except (FileNotFoundError, ValueError):
+             self.active_df = pd.DataFrame()
+             print("No cached chats found")
+         if not self.active_df.empty:
+             try:
+                 self.active_conversation = Conversation(self.active_df.iloc[0])
+             except Exception as e:
+                 print(f"No conversations available: {e}")
+         else:
+             self.active_conversation = None
+
+     def _safe_get(self, url):
+         if self.timeout == 0:
+             print("Timeout is set to 0. Skipping GET request.")
+             return None
+         else:
+             try:
+                 response = requests.get(url, headers=self.headers, timeout=self.timeout)
+                 if response.status_code != 200:
+                     raise ValueError(f"Failed to retrieve {url}. Status code: {response.status_code}")
+                 return response
+             except requests.exceptions.Timeout:
+                 print(f"Timeout occurred for GET {url}. Skipping.")
+                 return None
+
+     def _safe_head(self, url):
+         if self.timeout == 0:
+             print("Timeout is set to 0. Skipping HEAD request.")
+             return None
+         try:
+             response = requests.head(url, allow_redirects=True, headers=self.headers, timeout=self.timeout)
+             return response
+         except requests.exceptions.Timeout:
+             print(f"Timeout occurred for HEAD {url}. Skipping.")
+             return None
+
+     def extract_sample_conversations(self, n_samples):
+         url = random.choice(self.parquet_urls)
+         print(f"Sampling conversations from {url}")
+         # Download file with auth headers using requests
+         r = self._safe_get(url)
+         if r is None:
+             print(f"Timeout occurred for GET {url}. Skipping sample extraction.")
+             return self.active_df
+         # Write the downloaded content into a temporary file
+         with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+             tmp.write(r.content)
+             # tmp.flush()
+             tmp_path = tmp.name
+         try:
+             query_result = duckdb.query(f"SELECT * FROM read_parquet('{tmp_path}') USING SAMPLE {n_samples}").df()
+             self.active_df = query_result
+             try:
+                 self.active_conversation = Conversation(query_result.iloc[0])
+             except Exception as e:
+                 print(f"No conversations available: {e}")
+         finally:
+             # Clean up the temporary file
+             if os.path.exists(tmp_path):
+                 os.unlink(tmp_path)
+
+         return query_result
+
+     def extract_conversations(self, conversation_ids):
+
+         # Create a lookup table for file names -> URLs
+         file_url_map = {url.split("/")[-1]: url for url in self.parquet_urls}
+
+         # Group conversation IDs by file
+         file_to_conversations = defaultdict(list)
+         for convid in conversation_ids:
+             if convid in self.conversations_index:
+                 file_to_conversations[self.conversations_index[convid]].append(convid)
+
+         result_df = pd.DataFrame()
+
+         for file_name, conv_ids in file_to_conversations.items():
+             if file_name not in file_url_map:
+                 print(f"File {file_name} not found in URL list, skipping.")
+                 continue
+
+             file_url = file_url_map[file_name]
+             print(f"Querying file: {file_name} for {len(conv_ids)} conversations")
+
+             try:
+                 r = self._safe_get(file_url)
+                 if r is None:
+                     print(f"Timeout occurred for GET {file_url}. Skipping file {file_name}.")
+                     continue
+
+                 with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+                     tmp.write(r.content)
+                     tmp_path = tmp.name
+                 try:
+                     conv_id_list = "', '".join(conv_ids)
+                     query_str = f"""
+                         SELECT * FROM read_parquet('{tmp_path}')
+                         WHERE conversation_id IN ('{conv_id_list}')
+                     """
+                     df = duckdb.query(query_str).df()
+                 finally:
+                     if os.path.exists(tmp_path):
+                         os.unlink(tmp_path)
+
+                 if not df.empty:
+                     print(f"Found {len(df)} conversations in {file_name}")
+                     result_df = pd.concat([result_df, df], ignore_index=True)
+
+             except Exception as e:
+                 print(f"Error processing {file_name}: {e}")
+
+         self.active_df = result_df
+         try:
+             self.active_conversation = Conversation(self.active_df.iloc[0])
+         except Exception as e:
+             print(f"No conversations available: {e}")
+
+         return result_df
+
+     def literal_text_search(self, filter_str, min_results=1):
+         # If filter_str is empty, return a random sample of conversations
+         if filter_str == "":
+             return self.extract_sample_conversations(50)
+         urls = self.parquet_urls.copy()
+         random.shuffle(urls)
+
+         result_df = pd.DataFrame()
+
+         for url in urls:
+             print(f"Querying file: {url}")
+             r = self._safe_get(url)
+             if r is None:
+                 print(f"Timeout occurred for GET {url}. Skipping file {url}.")
+                 continue
+             with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+                 tmp.write(r.content)
+                 tmp_path = tmp.name
+
+             try:
+                 query_str = f"""
+                     SELECT * FROM read_parquet('{tmp_path}')
+                     WHERE contains(lower(cast(conversation as VARCHAR)), lower(?))
+                 """
+                 df = duckdb.execute(query_str, [filter_str]).df()  # parameterized so quotes in the search string don't break the query
+             finally:
+                 if os.path.exists(tmp_path):
+                     os.unlink(tmp_path)
+
+             print(f"Found {len(df)} result(s) in {url.split('/')[-1]}")
+
+             if len(df) > 0:
+                 result_df = pd.concat([result_df, df], ignore_index=True)
+
+             if len(result_df) >= min_results:
+                 break
+         if len(result_df) == 0:
+             print("No results found. Returning empty DataFrame.")
+             placeholder_row = {'conversation_id': "No result found",
+                                'model': "-",
+                                'conversation': [
+                                    {'content': '-', 'role': 'user'},
+                                    {'content': '-', 'role': 'assistant'}
+                                ],
+                                'turn': "-",
+                                'language': "-",
+                                'openai_moderation': "[{'-': '-', '-': '-'}]",
+                                'redacted': "-",}
+             result_df = pd.DataFrame([placeholder_row])
+             print(result_df)
+         self.active_df = result_df
+         try:
+             self.active_conversation = Conversation(self.active_df.iloc[0])
+         except Exception as e:
+             print(f"No conversations available: {e}")
+         return result_df
+
+     def create_conversations_index(self, output_index_file="json/conversations_index.json"):
+         """
+         Builds an index of conversation IDs from a list of Parquet file URLs.
+         Stores the index as a JSON mapping conversation IDs to their respective file names.
+         """
+         index = {}
+
+         for url in self.parquet_urls:
+             file_name = url.split('/')[-1]  # Extract file name from URL
+             print(f"Indexing file: {file_name}")
+
+             try:
+                 # Download the file temporarily
+                 r = requests.get(url, headers=self.headers)
+                 with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+                     tmp.write(r.content)
+                     # tmp.flush()
+                     tmp_path = tmp.name
+                 try:
+                     query = f"SELECT conversation_id FROM read_parquet('{tmp_path}')"
+                     df = duckdb.query(query).to_df()
+                 finally:
+                     if os.path.exists(tmp_path):
+                         os.unlink(tmp_path)
+
+                 # Map conversation IDs to file name (not the full URL)
+                 for _, row in df.iterrows():
+                     index[row["conversation_id"]] = file_name
+
+             except Exception as e:
+                 print(f"Error indexing {file_name}: {e}")
+
+         # Save index for fast lookup
+         with open(output_index_file, "w", encoding="utf-8") as f:
+             json.dump(index, f, indent=2)
+
+         return output_index_file
+
+
+ class Conversation:
+     def __init__(self, data):
+         """
+         Initialize a conversation object either from conversation data directly or from a DataFrame row.
+
+         Parameters:
+         - data: Can be either a list of conversation messages or a pandas Series/dict containing conversation data
+         """
+         # Handle both direct conversation data and DataFrame row
+         if isinstance(data, (pd.Series, dict)):
+             # Store all metadata separately
+             self.conversation_metadata = {}
+             for key, value in (data.items() if isinstance(data, pd.Series) else data.items()):
+                 if key == 'conversation':
+                     self.conversation_data = value
+                 else:
+                     self.conversation_metadata[key] = value
+         else:
+             # Direct initialization with conversation data
+             self.conversation_data = data
+             self.conversation_metadata = {}
+
+     def add_turns(self):
+         """
+         Adds a 'turn' key to each dictionary in the conversation,
+         identifying the turn (pair of user and assistant messages).
+
+         Returns:
+         - list: The updated conversation with 'turn' keys added.
+         """
+         turn_counter = 0
+         for message in self.conversation_data:
+             if message['role'] == 'user':
+                 turn_counter += 1
+             message['turn'] = turn_counter
+         return self.conversation_data
+
+     def pretty_print(self, user_prefix, assistant_prefix, width=80):
+         """
+         Prints the conversation with specified prefixes and wrapped text.
+
+         Parameters:
+         - user_prefix (str): Prefix to prepend to user messages.
+         - assistant_prefix (str): Prefix to prepend to assistant messages.
+         - width (int): Maximum characters per line for wrapping.
+         """
+         wrapper = textwrap.TextWrapper(width=width)
+
+         for message in self.conversation_data:
+             if message['role'] == 'user':
+                 prefix = user_prefix
+             elif message['role'] == 'assistant':
+                 prefix = assistant_prefix
+             else:
+                 continue  # Ignore roles other than 'user' and 'assistant'
+
+             # Split on existing newlines, wrap each line, and join back with newlines
+             wrapped_content = "\n".join(
+                 wrapper.fill(line) for line in message['content'].splitlines()
+             )
+             print(f"{prefix} {wrapped_content}\n")
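
For reference, a short usage sketch for DatasetWrapper (not part of the commit), again assuming a valid hf_token. Note that if json/conversations_index.json is missing, the first run builds it, which downloads every Parquet shard.

# Hypothetical usage of src/lmsys_dataset_wrapper.py
from src.lmsys_dataset_wrapper import DatasetWrapper

wrapper = DatasetWrapper(hf_token, verbose=False)
sample_df = wrapper.extract_sample_conversations(n_samples=20)                  # random sample from one shard
hits_df = wrapper.literal_text_search("prompt engineering", min_results=1)      # literal search across shards
wrapper.active_conversation.pretty_print(user_prefix="USER:", assistant_prefix="ASSISTANT:")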
src/text_classification_functions.py ADDED
@@ -0,0 +1,405 @@
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
+ from datasets import Dataset
+ from tqdm import tqdm
+ import torch
+ import numpy as np
+ import os
+ from langdetect import detect
+ from sklearn.metrics import accuracy_score, f1_score, log_loss, confusion_matrix, ConfusionMatrixDisplay
+ import matplotlib.pyplot as plt
+ try:
+     import tensorflow as tf  # optional: only needed by TensorflowClassifier
+ except ImportError:
+     tf = None
+
+ class Classifier:
+     def __init__(self, model_path, label_map, verbose = False):
+         self.model_path = model_path
+         self.classifier = pipeline("text-classification", model=model_path, tokenizer=model_path, device=0 if torch.cuda.is_available() else -1)
+         self.label_map = label_map
+         if verbose:
+             self.print_device_information()
+
+     def print_device_information(self):
+         # Check device information
+         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         device_properties = torch.cuda.get_device_properties(0) if device.type == "cuda" else "CPU Device"
+
+         print(f"Using device: {device}")
+         if device.type == "cuda":
+             print(f"Device Name: {device_properties.name}")
+             # print(f"Compute Capability: {device_properties.major}.{device_properties.minor}")
+             print(f"Total Memory: {device_properties.total_memory / 1e9:.2f} GB")
+
+     def tokenize_and_trim(self, text):
+         max_length = self.classifier.tokenizer.model_max_length
+         inputs = self.classifier.tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")  # tensors are only used to decode the truncated text
+         return self.classifier.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)
+
+
+     def classify_dataframe_column(self, df, target_column, feature_suffix):
+
+         tqdm.pandas()
+         df[f'trimmed_{target_column}'] = df[target_column].progress_apply(self.tokenize_and_trim)
+
+         results = []
+         for text in tqdm(df[f'trimmed_{target_column}'].tolist(), desc="Classifying"):
+             result = self.classifier(text)
+             results.append(result[0])
+
+         df[f'pred_label_{feature_suffix}'] = [self.label_map[int(result['label'].split('_')[-1])] for result in results]
+         df[f'prob_{feature_suffix}'] = [result['score'] for result in results]
+         df.drop(columns=[f'trimmed_{target_column}'], inplace=True)
+         return df
+
+     def test_model_predictions(self, df, target_column):
+         """
+         Tests model predictions on a given dataframe column and computes evaluation metrics.
+
+         Args:
+             df (pd.DataFrame): Input dataframe containing the data.
+             target_column (str): The name of the column to classify.
+
+         Requirements:
+         - The dataframe must include a 'label' column for comparison with predictions.
+
+         Returns:
+             dict: A dictionary containing accuracy, F1 score, cross-entropy loss,
+             and the confusion matrix.
+         """
+         # Convert pandas dataframe to Dataset
+         dataset = Dataset.from_pandas(df)
+
+         # Define a processing function for tokenization and classification
+         def process_data(batch):
+             trimmed_text = self.tokenize_and_trim(batch[target_column])
+             result = self.classifier(trimmed_text)
+             score = result[0]['score']
+             label = result[0]['label']
+             return {
+                 'trimmed_text': trimmed_text,
+                 'predicted_prob_0': score if label == 'LABEL_0' else 1 - score,
+                 'predicted_prob_1': 1 - score if label == 'LABEL_0' else score,
+             }
+
+         # Apply processing with map
+         processed_dataset = dataset.map(process_data, batched=False)
+
+         # Convert back to pandas dataframe
+         processed_df = processed_dataset.to_pandas()
+
+         # Extract predicted probabilities and true labels
+         predicted_probs = processed_df[['predicted_prob_0', 'predicted_prob_1']].values
+         true_labels = df['label'].values
+
+         # Calculate metrics
+         accuracy = accuracy_score(true_labels, np.argmax(predicted_probs, axis=1))
+         f1 = f1_score(true_labels, np.argmax(predicted_probs, axis=1), average='weighted')
+         cross_entropy_loss = log_loss(true_labels, predicted_probs)
+
+         # Print metrics
+         print(f"Accuracy: {accuracy:.4f}")
+         print(f"F1 Score: {f1:.4f}")
+         print(f"Cross Entropy Loss: {cross_entropy_loss:.4f}")
+
+         # Confusion matrix
+         cm = confusion_matrix(true_labels, np.argmax(predicted_probs, axis=1))
+         cmap = plt.cm.Blues
+         disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
+         disp.plot(cmap=cmap)
+         plt.show()
+
+         # Return metrics and probabilities for further inspection
+         return {
+             "accuracy": accuracy,
+             "f1_score": f1,
+             "cross_entropy_loss": cross_entropy_loss,
+             "confusion_matrix": cm,
+             "predicted_probs": predicted_probs  # Include reconstructed probabilities
+         }
+
+
+ class LanguageDetector:
+     def __init__(self, dataframe):
+         """
+         Initializes the LanguageDetector with the provided DataFrame.
+         """
+         self.dataframe = dataframe
+
+     def detect_language_dataframe_column(self, target_column):
+         """
+         Detects the language of text in the specified column using langdetect and adds
+         a 'detected_language' column to the DataFrame.
+         """
+         def detect_language(text):
+             try:
+                 return detect(text)
+             except Exception:
+                 return None
+
+         tqdm.pandas()
+         self.dataframe['detected_language'] = self.dataframe[target_column].progress_apply(detect_language)
+
+         return self.dataframe
+
+
+ # Classifier with Tensorflow backend
+ class TensorflowClassifier(Classifier):
+     def __init__(self, model_path, label_map, verbose=False):
+         super().__init__(model_path, label_map, verbose=False)
+         self.is_tensorflow = False
+
+         if tf is not None and self._is_tensorflow_model(model_path):
+             self.model = tf.keras.models.load_model(model_path)
+             self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # Adjust as per training tokenizer
+             self.is_tensorflow = True
+             if verbose:
+                 print("Loaded TensorFlow model.")
+         else:
+             if verbose:
+                 print("Fallback to HuggingFace pipeline.")
+
+     def _is_tensorflow_model(self, model_path):
+         return os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "saved_model.pb"))
+
+     def classify(self, text):
+         if self.is_tensorflow:
+             inputs = self.tokenizer(text, truncation=True, max_length=self.tokenizer.model_max_length, return_tensors="np")
+             logits = self.model.predict([inputs["input_ids"], inputs["attention_mask"]])
+             probabilities = tf.nn.softmax(logits).numpy()
+             label_id = np.argmax(probabilities, axis=-1).item()
+             return {
+                 "label": f"LABEL_{label_id}",
+                 "score": probabilities.max()
+             }
+         else:
+             return self.classifier(text)[0]
+
+     def classify_dataframe_column(self, df, target_column, feature_suffix):
+         tqdm.pandas()
+         df[f'trimmed_{target_column}'] = df[target_column].progress_apply(
+             lambda text: self.tokenizer.decode(
+                 self.tokenizer(text, truncation=True, max_length=self.tokenizer.model_max_length)["input_ids"],
+                 skip_special_tokens=True
+             )
+         )
+
+         if self.is_tensorflow:
+             results = [self.classify(text) for text in df[f'trimmed_{target_column}']]
+         else:
+             results = [self.classifier(text)[0] for text in df[f'trimmed_{target_column}']]
+
+         df[f'pred_label_{feature_suffix}'] = [
+             self.label_map[int(result['label'].split('_')[-1])] for result in results
+         ]
+         df[f'prob_{feature_suffix}'] = [result['score'] for result in results]
+         df.drop(columns=[f'trimmed_{target_column}'], inplace=True)
+         return df
+
+
+ class ZeroShotClassifier(Classifier):
+
+     def __init__(self, model_path, tokenizer_path, candidate_labels):
+         self.model_path = model_path
+         self.candidate_labels = candidate_labels
+         self.classifier = pipeline("zero-shot-classification", model=model_path, tokenizer=tokenizer_path, clean_up_tokenization_spaces=True, device=0 if torch.cuda.is_available() else -1)
+
+     def classify_text(self, text, top_n=None, multi_label=False):
+         """
+         Classify a single text using zero-shot classification with truncated scores.
+
+         :param text: The text to classify
+         :param multi_label: Whether to allow multi-label classification
+         :return: Classification result as a dictionary with scores truncated to 3 decimals
+         """
+         classification_output = self.classifier(text, self.candidate_labels, multi_label=multi_label, clean_up_tokenization_spaces=True)
+         classification_output['scores'] = [round(score, 3) for score in classification_output['scores']]
+         if top_n is not None:
+             classification_output = {
+                 'sequence': classification_output['sequence'],
+                 'labels': classification_output['labels'][:top_n],
+                 'scores': classification_output['scores'][:top_n]
+             }
+         return classification_output
+
+     def classify_dataframe_column(self, df, target_column, feature_suffix, multi_label=False):
+         """
+         Classify the contents of a dataframe column using zero-shot classification.
+
+         :param df: The dataframe to process
+         :param target_column: The column containing text to classify
+         :param feature_suffix: Suffix for the output columns
+         :param multi_label: Whether to allow multi-label classification
+         :return: The dataframe with classification results
+         """
+         tqdm.pandas()
+
+         # Apply the classify_text method to each row
+         results = df[target_column].progress_apply(
+             lambda text: self.classify_text(text, multi_label=multi_label)
+         )
+
+         # Extract and store results
+         df[f'top_class_{feature_suffix}'] = results.apply(lambda res: res['labels'][0])
+         df[f'top_score_{feature_suffix}'] = results.apply(lambda res: res['scores'][0])
+         df[f'full_results_{feature_suffix}'] = results.apply(lambda res: list(zip(res['labels'], res['scores'])))
+
+         return df
+
+     def test_zs_predictions(self, df, target_column='text', true_classes_column='category', plot_conf_matrix=True):
+         """
+         Tests model predictions on a given dataset column using the zero-shot classification pipeline.
+
+         Args:
+             df (pd.DataFrame): Input dataframe containing texts for zero-shot classification.
+             target_column (str): The name of the column containing text to classify.
+             true_classes_column (str): The column containing annotated classes.
+
+         Returns:
+             dict: A dictionary containing accuracy, F1 score, and confusion matrix.
+         """
+         # Progress bar for classification
+         tqdm.pandas(desc=f"Zero-shot classification with {self.model_path}")
+
+         # Function to classify each row
+         def classify_row(row):
+             classification_output = self.classifier(
+                 row[target_column],
+                 self.candidate_labels,
+                 multi_label=False,
+                 clean_up_tokenization_spaces=True,
+             )
+             return classification_output["labels"][0]
+
+         # Apply classification with progress bar
+         df = df.copy()
+         df.loc[:, 'predicted_class'] = df.progress_apply(classify_row, axis=1)
+
+         # Extract true and predicted classes
+         true_classes = df[true_classes_column]
+         predicted_classes = df['predicted_class']
+
+         # Compute metrics
+         accuracy = accuracy_score(true_classes, predicted_classes)
+         f1 = f1_score(true_classes, predicted_classes, average="macro")
+         cm = confusion_matrix(true_classes, predicted_classes, labels=self.candidate_labels)
+         if plot_conf_matrix:
+             disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=self.candidate_labels)
+             fig, ax = plt.subplots(figsize=(4, 4))
+             disp.plot(cmap=plt.cm.Blues, ax=ax, colorbar=False)
+             ax.set_title(f"Zero-shot classification with {self.model_path}", fontsize=10)
+             ax.set_xlabel("Predicted label", fontsize=8)
+             ax.set_ylabel("True label", fontsize=8)
+
+             ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
+             ax.set_yticklabels(ax.get_yticklabels(), fontsize=8)
+
+             fig.text(
+                 0.5, 0.01,
+                 f"Accuracy: {accuracy:.4f} | F1 Score: {f1:.4f}",
+                 ha="center",
+                 fontsize=10
+             )
+             plt.tight_layout(rect=[0, 0.05, 1, 1])  # Adjust bottom margin
+             plt.show()
+
+         return {
+             "accuracy": accuracy,
+             "f1_score": f1,
+             "confusion_matrix": cm,
+             "detailed_results": df.to_dict(),  # Full dataframe with predictions
+         }
+
+     def test_zs_predictions_with_dataset(self, df, target_column='text', true_classes_column='category', plot_conf_matrix=True):
+         dataset = Dataset.from_pandas(df)
+         def classify_text(batch):
+             classification_output = self.classifier(
+                 batch[target_column],
+                 self.candidate_labels,
+                 multi_label=False,
+                 clean_up_tokenization_spaces=True,
+             )
+             return {
+                 "predicted_class": classification_output["labels"][0],
+                 "predicted_scores": classification_output["scores"],
+             }
+
+         # Apply classification to the dataset
+         classified_dataset = dataset.map(classify_text, batched=False)
+         # classified_dataset = dataset.map(classify_text, batched=True, batch_size=16)
+
+         # Extract true and predicted classes
+         true_classes = classified_dataset[true_classes_column]
+         predicted_classes = classified_dataset["predicted_class"]
+
+         # Compute metrics
+         accuracy = accuracy_score(true_classes, predicted_classes)
+         f1 = f1_score(true_classes, predicted_classes, average="macro")
+
+         # Print metrics
+         print(f"Accuracy: {accuracy:.4f}")
+         print(f"F1 Score: {f1:.4f}")
+
+         # Generate confusion matrix:
+         cm = confusion_matrix(true_classes, predicted_classes, labels=self.candidate_labels)
+         if plot_conf_matrix:
+             disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=self.candidate_labels)
+             fig, ax = plt.subplots(figsize=(6, 6))
+             disp.plot(cmap=plt.cm.Blues, ax=ax)
+             plt.xticks(rotation=45, ha="right")
+             plt.show()
+
+         # Return metrics for further inspection
+         return {
+             "accuracy": accuracy,
+             "f1_score": f1,
+             "confusion_matrix": cm,
+             "detailed_results": classified_dataset.to_dict(),
+         }
+
+ class MetricsComparison:
+     def __init__(self, base_classifier, fine_tuned_classifier, base_metrics, fine_tuned_metrics):
+         self.base_classifier = base_classifier
+         self.fine_tuned_classifier = fine_tuned_classifier
+         self.base_metrics = base_metrics
+         self.fine_tuned_metrics = fine_tuned_metrics
+
+     def compare_conf_matrices(self):
+         fig, axes = plt.subplots(1, 2, figsize=(12, 6))
+         # Plot for base_classifier (left)
+         disp1 = ConfusionMatrixDisplay(confusion_matrix=self.base_metrics["confusion_matrix"],
+                                        display_labels=self.base_classifier.candidate_labels)
+         disp1.plot(cmap=plt.cm.Blues, ax=axes[0], colorbar=False)
+         axes[0].set_title(f"Zero-shot classification with {self.base_classifier.model_path}", fontsize=10)
+         axes[0].set_xlabel("Predicted class", fontsize=8)
+         axes[0].set_ylabel("True class", fontsize=8)
+         axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha="right", fontsize=8)
+         axes[0].set_yticklabels(axes[0].get_yticklabels(), fontsize=8)
+
+         fig.text(
+             0.25, 0.01,
+             f"Accuracy: {self.base_metrics['accuracy']:.4f} | F1 Score: {self.base_metrics['f1_score']:.4f}",
+             ha="center",
+             fontsize=10
+         )
+
+         # Plot for zs_classifier (fine-tuned) (right)
+         disp2 = ConfusionMatrixDisplay(confusion_matrix=self.fine_tuned_metrics["confusion_matrix"],
+                                        display_labels=self.fine_tuned_classifier.candidate_labels)
+         disp2.plot(cmap=plt.cm.Blues, ax=axes[1], colorbar=False)
+         axes[1].set_title(f"ZS classification with {self.fine_tuned_classifier.model_path}", fontsize=10)
+         axes[1].set_xlabel("Predicted class", fontsize=8)
+         axes[1].set_ylabel("True class", fontsize=8)
+         axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=45, ha="right", fontsize=8)
+         axes[1].set_yticklabels(axes[1].get_yticklabels(), fontsize=8)
+
+         fig.text(
+             0.75, 0.01,
+             f"Accuracy: {self.fine_tuned_metrics['accuracy']:.4f} | F1 Score: {self.fine_tuned_metrics['f1_score']:.4f}",
+             ha="center",
+             fontsize=10
+         )
+
+         plt.tight_layout(rect=[0, 0.05, 1, 0.95])
+         plt.show()
+
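
For reference, a short usage sketch for the classifier helpers (not part of the commit): the fine-tuned model path is a placeholder for a checkpoint whose pipeline emits LABEL_0/LABEL_1, and the candidate labels are arbitrary examples.

# Hypothetical usage of src/text_classification_functions.py
import pandas as pd
from src.text_classification_functions import Classifier, ZeroShotClassifier

df = pd.DataFrame({"prompt": ["How do I reverse a list in Python?", "Tell me a joke"]})

clf = Classifier(
    model_path="path/to/fine-tuned-binary-classifier",  # placeholder checkpoint emitting LABEL_0 / LABEL_1
    label_map={0: "negative", 1: "positive"},
)
df = clf.classify_dataframe_column(df, target_column="prompt", feature_suffix="sentiment")

zs = ZeroShotClassifier(
    model_path="facebook/bart-large-mnli",    # example zero-shot NLI checkpoint
    tokenizer_path="facebook/bart-large-mnli",
    candidate_labels=["coding", "small talk"],
)
df = zs.classify_dataframe_column(df, target_column="prompt", feature_suffix="topic")
print(df[["prompt", "pred_label_sentiment", "top_class_topic"]])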