alisrbdni's picture
Update app.py
37dd487 verified
raw
history blame
47.3 kB
# # %%writefile app.py
# import streamlit as st
# import matplotlib.pyplot as plt
# import torch
# from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
# from datasets import load_dataset, Dataset
# from evaluate import load as load_metric
# from torch.utils.data import DataLoader
# import pandas as pd
# import random
# from collections import OrderedDict
# import flwr as fl
# from logging import INFO, DEBUG
# from flwr.common.logger import log
# import logging
# import streamlit
# # If you're curious of all the loggers
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
# def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
# raw_datasets = load_dataset(dataset_name)
# raw_datasets = raw_datasets.shuffle(seed=42)
# del raw_datasets["unsupervised"]
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# def tokenize_function(examples):
# return tokenizer(examples["text"], truncation=True)
# tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# tokenized_datasets = tokenized_datasets.remove_columns("text")
# tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
# train_datasets = []
# test_datasets = []
# for _ in range(num_clients):
# train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
# test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
# train_datasets.append(train_dataset)
# test_datasets.append(test_dataset)
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# return train_datasets, test_datasets, data_collator, raw_datasets
# def train(net, trainloader, epochs):
# optimizer = AdamW(net.parameters(), lr=5e-5)
# net.train()
# for _ in range(epochs):
# for batch in trainloader:
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
# outputs = net(**batch)
# loss = outputs.loss
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
# # class SaveModelStrategy(fl.server.strategy.FedAvg):
# # def aggregate_fit(
# # self,
# # server_round: int,
# # results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
# # failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
# # ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
# # """Aggregate model weights using weighted average and store checkpoint"""
# # # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
# # aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
# # if aggregated_parameters is not None:
# # print(f"Saving round {server_round} aggregated_parameters...")
# # # Convert `Parameters` to `List[np.ndarray]`
# # aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)
# # # Convert `List[np.ndarray]` to PyTorch`state_dict`
# # params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
# # state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
# # net.load_state_dict(state_dict, strict=True)
# # # Save the model
# # torch.save(net.state_dict(), f"model_round_{server_round}.pth")
# # return aggregated_parameters, aggregated_metrics
# def test(net, testloader):
# metric = load_metric("accuracy")
# net.eval()
# loss = 0
# for batch in testloader:
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
# with torch.no_grad():
# outputs = net(**batch)
# logits = outputs.logits
# loss += outputs.loss.item()
# predictions = torch.argmax(logits, dim=-1)
# metric.add_batch(predictions=predictions, references=batch["labels"])
# loss /= len(testloader)
# accuracy = metric.compute()["accuracy"]
# return loss, accuracy
# class CustomClient(fl.client.NumPyClient):
# def __init__(self, net, trainloader, testloader, client_id):
# self.net = net
# self.trainloader = trainloader
# self.testloader = testloader
# self.client_id = client_id
# self.losses = []
# self.accuracies = []
# def get_parameters(self, config):
# return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
# def set_parameters(self, parameters):
# params_dict = zip(self.net.state_dict().keys(), parameters)
# state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
# self.net.load_state_dict(state_dict, strict=True)
# def fit(self, parameters, config):
# log(INFO, f"Client {self.client_id} is starting fit()")
# self.set_parameters(parameters)
# train(self.net, self.trainloader, epochs=1)
# loss, accuracy = test(self.net, self.testloader)
# self.losses.append(loss)
# self.accuracies.append(accuracy)
# log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
# return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
# def evaluate(self, parameters, config):
# log(INFO, f"Client {self.client_id} is starting evaluate()")
# self.set_parameters(parameters)
# loss, accuracy = test(self.net, self.testloader)
# log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
# return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
# def plot_metrics(self, round_num, plot_placeholder):
# if self.losses and self.accuracies:
# plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
# plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
# plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
# fig, ax1 = plt.subplots()
# color = 'tab:red'
# ax1.set_xlabel('Round')
# ax1.set_ylabel('Loss', color=color)
# ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
# ax1.tick_params(axis='y', labelcolor=color)
# ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
# color = 'tab:blue'
# ax2.set_ylabel('Accuracy', color=color)
# ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
# ax2.tick_params(axis='y', labelcolor=color)
# fig.tight_layout()
# plot_placeholder.pyplot(fig)
# import matplotlib.pyplot as plt
# import re
# def read_log_file(log_path='./log.txt'):
# with open(log_path, 'r') as file:
# log_lines = file.readlines()
# return log_lines
# def parse_log(log_lines):
# rounds = []
# clients = {}
# memory_usage = []
# round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
# client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
# memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
# current_round = None
# for line in log_lines:
# round_match = round_pattern.search(line)
# client_match = client_pattern.search(line)
# memory_match = memory_pattern.search(line)
# if round_match:
# current_round = int(round_match.group(1))
# rounds.append(current_round)
# elif client_match:
# client_id = int(client_match.group(1))
# log_level = client_match.group(2)
# message = client_match.group(3)
# if client_id not in clients:
# clients[client_id] = {'rounds': [], 'messages': []}
# clients[client_id]['rounds'].append(current_round)
# clients[client_id]['messages'].append((log_level, message))
# elif memory_match:
# memory_usage.append(float(memory_match.group(1)))
# return rounds, clients, memory_usage
# def plot_metrics(rounds, clients, memory_usage):
# st.write("## Metrics Overview")
# st.write("### Memory Usage")
# plt.figure()
# plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
# plt.xlabel('Step')
# plt.ylabel('Memory Usage (GB)')
# plt.legend()
# st.pyplot(plt)
# for client_id, data in clients.items():
# st.write(f"### Client {client_id} Metrics")
# info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
# debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
# st.write("#### INFO Messages")
# for msg in info_messages:
# st.write(msg)
# st.write("#### DEBUG Messages")
# for msg in debug_messages:
# st.write(msg)
# # Placeholder for actual loss and accuracy values, assuming they're included in the messages
# losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
# accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
# if losses:
# plt.figure()
# plt.plot(data['rounds'], losses, label='Loss')
# plt.xlabel('Round')
# plt.ylabel('Loss')
# plt.legend()
# st.pyplot(plt)
# if accuracies:
# plt.figure()
# plt.plot(data['rounds'], accuracies, label='Accuracy')
# plt.xlabel('Round')
# plt.ylabel('Accuracy')
# plt.legend()
# st.pyplot(plt)
# def read_log_file2():
# with open("./log.txt", "r") as file:
# return file.read()
# def main():
# st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
# dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
# model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
# NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
# NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
# train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
# trainloaders = []
# testloaders = []
# clients = []
# for i in range(NUM_CLIENTS):
# st.write(f"### Client {i+1} Datasets")
# train_df = pd.DataFrame(train_datasets[i])
# test_df = pd.DataFrame(test_datasets[i])
# st.write("#### Train Dataset (Words)")
# st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
# st.write("#### Train Dataset (Tokens)")
# edited_train_df = st.data_editor(train_df, key=f"train_{i}")
# st.write("#### Test Dataset (Words)")
# st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
# st.write("#### Test Dataset (Tokens)")
# edited_test_df = st.data_editor(test_df, key=f"test_{i}")
# edited_train_dataset = Dataset.from_pandas(edited_train_df)
# edited_test_dataset = Dataset.from_pandas(edited_test_df)
# trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
# testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
# trainloaders.append(trainloader)
# testloaders.append(testloader)
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
# client = CustomClient(net, trainloader, testloader, client_id=i+1)
# clients.append(client)
# if st.button("Start Training"):
# def client_fn(cid):
# return clients[int(cid)]
# def weighted_average(metrics):
# accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
# losses = [num_examples * m["loss"] for num_examples, m in metrics]
# examples = [num_examples for num_examples, _ in metrics]
# return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
# strategy = fl.server.strategy.FedAvg(
# fraction_fit=1.0,
# fraction_evaluate=1.0,
# evaluate_metrics_aggregation_fn=weighted_average,
# )
# for round_num in range(NUM_ROUNDS):
# st.write(f"### Round {round_num + 1} ✅")
# st.markdown(print(st.logger._loggers))
# st.markdown(read_log_file2())
# logs = read_log_file2()
# import re
# import plotly.graph_objects as go
# import streamlit as st
# import pandas as pd
# # Log data
# log_data = logs
# # Extract relevant data
# accuracy_pattern = re.compile(r"'accuracy': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
# loss_pattern = re.compile(r"'loss': \((\d+),([\d.]+)\)\((\d+), ([\d.]+)\)")
# accuracy_matches = accuracy_pattern.findall(log_data)
# loss_matches = loss_pattern.findall(log_data)
# rounds = [int(match[0]) for match in accuracy_matches]
# accuracies = [float(match[1]) for match in accuracy_matches]
# losses = [float(match[1]) for match in loss_matches]
# # Create accuracy plot
# accuracy_fig = go.Figure()
# accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
# accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
# # Create loss plot
# loss_fig = go.Figure()
# loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
# loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
# # Display plots in Streamlit
# st.plotly_chart(accuracy_fig)
# st.plotly_chart(loss_fig)
# # Display data table
# data = {
# 'Round': rounds,
# 'Accuracy': accuracies,
# 'Loss': losses
# }
# df = pd.DataFrame(data)
# st.write("## Training Metrics")
# st.table(df)
# plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
# fl.simulation.start_simulation(
# client_fn=client_fn,
# num_clients=NUM_CLIENTS,
# config=fl.server.ServerConfig(num_rounds=1),
# strategy=strategy,
# client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
# ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
# )
# for i, client in enumerate(clients):
# client.plot_metrics(round_num + 1, plot_placeholders[i])
# st.write(" ")
# st.success("Training completed successfully!")
# # Display final metrics
# st.write("## Final Client Metrics")
# for client in clients:
# st.write(f"### Client {client.client_id}")
# if client.losses and client.accuracies:
# st.write(f"Final Loss: {client.losses[-1]:.4f}")
# st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
# client.plot_metrics(NUM_ROUNDS, st.empty())
# else:
# st.write("No metrics available.")
# st.write(" ")
# # Display log.txt content
# st.write("## Training Log")
# # st.text(read_log_file())
# st.write("## Training Log Analysis")
# log_lines = read_log_file()
# rounds, clients, memory_usage = parse_log(log_lines)
# plot_metrics(rounds, clients, memory_usage)
# else:
# st.write("Click the 'Start Training' button to start the training process.")
# if __name__ == "__main__":
# main()
# ##############NEW
# import streamlit as st
# import matplotlib.pyplot as plt
# import torch
# from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
# from datasets import load_dataset, Dataset
# from evaluate import load as load_metric
# from torch.utils.data import DataLoader
# import pandas as pd
# import random
# from collections import OrderedDict
# import flwr as fl
# from logging import INFO, DEBUG
# from flwr.common.logger import log
# import logging
# import re
# import plotly.graph_objects as go
# # If you're curious of all the loggers
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
# def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
# raw_datasets = load_dataset(dataset_name)
# raw_datasets = raw_datasets.shuffle(seed=42)
# del raw_datasets["unsupervised"]
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# def tokenize_function(examples):
# return tokenizer(examples["text"], truncation=True)
# tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# tokenized_datasets = tokenized_datasets.remove_columns("text")
# tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
# train_datasets = []
# test_datasets = []
# for _ in range(num_clients):
# train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
# test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
# train_datasets.append(train_dataset)
# test_datasets.append(test_dataset)
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# return train_datasets, test_datasets, data_collator, raw_datasets
# def train(net, trainloader, epochs):
# optimizer = AdamW(net.parameters(), lr=5e-5)
# net.train()
# for _ in range(epochs):
# for batch in trainloader:
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
# outputs = net(**batch)
# loss = outputs.loss
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
# def test(net, testloader):
# metric = load_metric("accuracy")
# net.eval()
# loss = 0
# for batch in testloader:
# batch = {k: v.to(DEVICE) for k, v in batch.items()}
# with torch.no_grad():
# outputs = net(**batch)
# logits = outputs.logits
# loss += outputs.loss.item()
# predictions = torch.argmax(logits, dim=-1)
# metric.add_batch(predictions=predictions, references=batch["labels"])
# loss /= len(testloader)
# accuracy = metric.compute()["accuracy"]
# return loss, accuracy
# class CustomClient(fl.client.NumPyClient):
# def __init__(self, net, trainloader, testloader, client_id):
# self.net = net
# self.trainloader = trainloader
# self.testloader = testloader
# self.client_id = client_id
# self.losses = []
# self.accuracies = []
# def get_parameters(self, config):
# return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
# def set_parameters(self, parameters):
# params_dict = zip(self.net.state_dict().keys(), parameters)
# state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
# self.net.load_state_dict(state_dict, strict=True)
# def fit(self, parameters, config):
# log(INFO, f"Client {self.client_id} is starting fit()")
# self.set_parameters(parameters)
# train(self.net, self.trainloader, epochs=1)
# loss, accuracy = test(self.net, self.testloader)
# self.losses.append(loss)
# self.accuracies.append(accuracy)
# log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
# return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}
# def evaluate(self, parameters, config):
# log(INFO, f"Client {self.client_id} is starting evaluate()")
# self.set_parameters(parameters)
# loss, accuracy = test(self.net, self.testloader)
# log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
# return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
# def plot_metrics(self, round_num, plot_placeholder):
# if self.losses and self.accuracies:
# plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
# plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
# plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
# fig, ax1 = plt.subplots()
# color = 'tab:red'
# ax1.set_xlabel('Round')
# ax1.set_ylabel('Loss', color=color)
# ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
# ax1.tick_params(axis='y', labelcolor=color)
# ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
# color = 'tab:blue'
# ax2.set_ylabel('Accuracy', color=color)
# ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
# ax2.tick_params(axis='y', labelcolor=color)
# fig.tight_layout()
# plot_placeholder.pyplot(fig)
# def read_log_file(log_path='./log.txt'):
# with open(log_path, 'r') as file:
# log_lines = file.readlines()
# return log_lines
# def parse_log(log_lines):
# rounds = []
# clients = {}
# memory_usage = []
# round_pattern = re.compile(r'ROUND(\d+)ROUND (\d+)')
# client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
# memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
# current_round = None
# for line in log_lines:
# round_match = round_pattern.search(line)
# client_match = client_pattern.search(line)
# memory_match = memory_pattern.search(line)
# if round_match:
# current_round = int(round_match.group(1))
# rounds.append(current_round)
# elif client_match:
# client_id = int(client_match.group(1))
# log_level = client_match.group(2)
# message = client_match.group(3)
# if client_id not in clients:
# clients[client_id] = {'rounds': [], 'messages': []}
# clients[client_id]['rounds'].append(current_round)
# clients[client_id]['messages'].append((log_level, message))
# elif memory_match:
# memory_usage.append(float(memory_match.group(1)))
# return rounds, clients, memory_usage
# def plot_metrics(rounds, clients, memory_usage):
# st.write("## Metrics Overview")
# st.write("### Memory Usage")
# plt.figure()
# plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
# plt.xlabel('Step')
# plt.ylabel('Memory Usage (GB)')
# plt.legend()
# st.pyplot(plt)
# for client_id, data in clients.items():
# st.write(f"### Client {client_id} Metrics")
# info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
# debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
# st.write("#### INFO Messages")
# for msg in info_messages:
# st.write(msg)
# st.write("#### DEBUG Messages")
# for msg in debug_messages:
# st.write(msg)
# # Placeholder for actual loss and accuracy values, assuming they're included in the messages
# losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
# accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
# if losses:
# plt.figure()
# plt.plot(data['rounds'], losses, label='Loss')
# plt.xlabel('Round')
# plt.ylabel('Loss')
# plt.legend()
# st.pyplot(plt)
# if accuracies:
# plt.figure()
# plt.plot(data['rounds'], accuracies, label='Accuracy')
# plt.xlabel('Round')
# plt.ylabel('Accuracy')
# plt.legend()
# st.pyplot(plt)
# def read_log_file2():
# with open("./log.txt", "r") as file:
# return file.read()
# def main():
# st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
# dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
# model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
# NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
# NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
# train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
# trainloaders = []
# testloaders = []
# clients = []
# for i in range(NUM_CLIENTS):
# st.write(f"### Client {i+1} Datasets")
# train_df = pd.DataFrame(train_datasets[i])
# test_df = pd.DataFrame(test_datasets[i])
# st.write("#### Train Dataset (Words)")
# st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
# st.write("#### Train Dataset (Tokens)")
# edited_train_df = st.data_editor(train_df, key=f"train_{i}")
# st.write("#### Test Dataset (Words)")
# st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
# st.write("#### Test Dataset (Tokens)")
# edited_test_df = st.data_editor(test_df, key=f"test_{i}")
# edited_train_dataset = Dataset.from_pandas(edited_train_df)
# edited_test_dataset = Dataset.from_pandas(edited_test_df)
# trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
# testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
# trainloaders.append(trainloader)
# testloaders.append(testloader)
# net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
# client = CustomClient(net, trainloader, testloader, client_id=i+1)
# clients.append(client)
# if st.button("Start Training"):
# def client_fn(cid):
# return clients[int(cid)]
# def weighted_average(metrics):
# accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
# losses = [num_examples * m["loss"] for num_examples, m in metrics]
# examples = [num_examples for num_examples, _ in metrics]
# return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}
# strategy = fl.server.strategy.FedAvg(
# fraction_fit=1.0,
# fraction_evaluate=1.0,
# evaluate_metrics_aggregation_fn=weighted_average,
# )
# for round_num in range(NUM_ROUNDS):
# st.write(f"### Round {round_num + 1} ✅")
# logs = read_log_file2()
# st.markdown(logs)
# # Extract relevant data
# accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
# loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
# accuracy_matches = accuracy_pattern.findall(logs)
# loss_matches = loss_pattern.findall(logs)
# rounds = [int(match[0]) for match in accuracy_matches]
# accuracies = [float(match[1]) for match in accuracy_matches]
# losses = [float(match[1]) for match in loss_matches]
# # Create accuracy plot
# accuracy_fig = go.Figure()
# accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
# accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
# # Create loss plot
# loss_fig = go.Figure()
# loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
# loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
# # Display plots in Streamlit
# st.plotly_chart(accuracy_fig)
# st.plotly_chart(loss_fig)
# # Display data table
# data = {
# 'Round': rounds,
# 'Accuracy': accuracies,
# 'Loss': losses
# }
# df = pd.DataFrame(data)
# st.write("## Training Metrics")
# st.table(df)
# plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
# fl.simulation.start_simulation(
# client_fn=client_fn,
# num_clients=NUM_CLIENTS,
# config=fl.server.ServerConfig(num_rounds=1),
# strategy=strategy,
# client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
# ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
# )
# for i, client in enumerate(clients):
# client.plot_metrics(round_num + 1, plot_placeholders[i])
# st.write(" ")
# st.success("Training completed successfully!")
# # Display final metrics
# st.write("## Final Client Metrics")
# for client in clients:
# st.write(f"### Client {client.client_id}")
# if client.losses and client.accuracies:
# st.write(f"Final Loss: {client.losses[-1]:.4f}")
# st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
# client.plot_metrics(NUM_ROUNDS, st.empty())
# else:
# st.write("No metrics available.")
# st.write(" ")
# # Display log.txt content
# st.write("## Training Log")
# st.write(read_log_file2())
# st.write("## Training Log Analysis")
# log_lines = read_log_file()
# rounds, clients, memory_usage = parse_log(log_lines)
# plot_metrics(rounds, clients, memory_usage)
# else:
# st.write("Click the 'Start Training' button to start the training process.")
# if __name__ == "__main__":
# main()
# #################
import streamlit as st
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
from datasets import load_dataset, Dataset
from evaluate import load as load_metric
from torch.utils.data import DataLoader
import pandas as pd
import random
from collections import OrderedDict
import flwr as fl
from logging import INFO, DEBUG
from flwr.common.logger import log
import logging
import re
import plotly.graph_objects as go
# Use a CUDA device when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Route Flower's log records into ./log.txt, the file read_log_file/read_log_file2 read back.
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2,
              tokenizer_name="bert-base-uncased", text_column=None):
    """Load a Hugging Face dataset, tokenize it and draw one random sample per client.

    Args:
        dataset_name: dataset id passed to ``datasets.load_dataset``.
        train_size: rows drawn from the ``train`` split for each client
            (capped at the split's length).
        test_size: rows drawn from the ``test`` split for each client
            (capped at the split's length).
        num_clients: number of independent per-client samples to draw.
        tokenizer_name: checkpoint used to build the tokenizer; defaults to the
            previously hard-coded ``"bert-base-uncased"``.
        text_column: name of the raw text column. ``None`` picks ``"text"``
            if present, otherwise ``"content"``.

    Returns:
        ``(train_datasets, test_datasets, data_collator, raw_datasets)``:
        two lists of ``num_clients`` tokenized datasets (with a ``labels``
        column), a padding collator for the tokenizer, and the shuffled raw
        ``DatasetDict`` (without any ``unsupervised`` split).

    Raises:
        ValueError: if ``text_column`` is ``None`` and no known text column exists.
    """
    raw_datasets = load_dataset(dataset_name)
    raw_datasets = raw_datasets.shuffle(seed=42)
    # Only some datasets ship an unlabeled "unsupervised" split; deleting it
    # unconditionally raised KeyError for the other datasets.
    if "unsupervised" in raw_datasets:
        del raw_datasets["unsupervised"]

    column_names = raw_datasets["train"].column_names
    if text_column is None:
        text_column = next((c for c in ("text", "content") if c in column_names), None)
        if text_column is None:
            raise ValueError(
                f"Cannot find a text column in {dataset_name!r}; columns are {column_names}"
            )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def tokenize_function(examples):
        return tokenizer(examples[text_column], truncation=True)

    # Drop every raw column except the label so only token tensors and labels
    # reach DataCollatorWithPadding (leftover string columns cannot be collated).
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=[c for c in column_names if c != "label"],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    train_split = tokenized_datasets["train"]
    test_split = tokenized_datasets["test"]
    # random.sample raises ValueError when asked for more rows than exist.
    n_train = min(train_size, len(train_split))
    n_test = min(test_size, len(test_split))

    train_datasets = []
    test_datasets = []
    for _ in range(num_clients):
        train_datasets.append(train_split.select(random.sample(range(len(train_split)), n_train)))
        test_datasets.append(test_split.select(random.sample(range(len(test_split)), n_test)))
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    return train_datasets, test_datasets, data_collator, raw_datasets
def train(net, trainloader, epochs, *, lr=5e-5, device=None):
    """Train ``net`` in place on every batch of ``trainloader`` for ``epochs`` passes.

    Args:
        net: model whose forward pass takes the batch dict as keyword arguments
            and returns an object exposing a scalar ``.loss`` tensor.
        trainloader: iterable of dict batches mapping names to tensors.
        epochs: number of full passes over ``trainloader``.
        lr: AdamW learning rate (keyword-only; defaults to the value this app
            always used).
        device: device each batch is moved to; ``None`` means the module-level
            ``DEVICE``.

    Returns:
        None. ``net``'s parameters are updated in place.
    """
    if device is None:
        device = DEVICE
    # torch.optim.AdamW replaces transformers.AdamW (deprecated, dropped in newer
    # transformers releases). eps/weight_decay are pinned to the old
    # transformers defaults so the update rule stays the same.
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
def test(net, testloader, *, device=None):
    """Evaluate ``net`` on ``testloader`` and return ``(loss, accuracy)``.

    Args:
        net: model whose forward pass takes the batch dict as keyword arguments
            and returns an object with ``.loss`` and ``.logits``.
        testloader: iterable of dict batches that include a ``"labels"`` tensor.
        device: device each batch is moved to; ``None`` means the module-level
            ``DEVICE``.

    Returns:
        ``(loss, accuracy)`` as Python floats. ``loss`` is averaged per example;
        ``accuracy`` is the fraction of arg-max predictions equal to the labels.

    Raises:
        ValueError: if ``testloader`` yields no examples, so no average exists.
    """
    if device is None:
        device = DEVICE
    net.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            labels = batch["labels"]
            batch_size = labels.size(0)
            # outputs.loss is assumed to be a per-batch mean; re-weight by batch
            # size so a short final batch does not skew the overall average.
            total_loss += outputs.loss.item() * batch_size
            predictions = torch.argmax(outputs.logits, dim=-1)
            # Counting matches directly avoids re-loading the "accuracy" metric
            # through `evaluate` on every call.
            correct += int((predictions == labels).sum().item())
            seen += batch_size
    if seen == 0:
        raise ValueError("testloader produced no examples to evaluate")
    return total_loss / seen, correct / seen
class CustomClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one local model and its train/test loaders.

    Each ``fit`` call appends the post-training loss and accuracy to
    ``self.losses`` / ``self.accuracies``, which ``plot_metrics`` renders.
    """

    def __init__(self, net, trainloader, testloader, client_id):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id
        # Metric history appended by fit(); plotted against x = 1..len(history).
        self.losses = []
        self.accuracies = []

    def get_parameters(self, config):
        """Return the model's state_dict values as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in state_dict order) into the model."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # torch.tensor keeps each array's dtype; torch.Tensor would coerce any
        # integer/boolean buffers to float32 before load_state_dict copies them.
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train one epoch from ``parameters``, record test metrics, return new weights."""
        log(INFO, f"Client {self.client_id} is starting fit()")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local test set without training."""
        log(INFO, f"Client {self.client_id} is starting evaluate()")
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}

    def plot_metrics(self, round_num, plot_placeholder):
        """Render the latest metrics and the loss/accuracy history into ``plot_placeholder``.

        Args:
            round_num: round number shown in the heading.
            plot_placeholder: Streamlit element (e.g. from ``st.empty()``).

        Does nothing until ``fit`` has recorded at least one round.
        """
        if not (self.losses and self.accuracies):
            return
        # An st.empty() slot holds a single element, so successive write()/pyplot()
        # calls on it would replace each other; a container keeps all of them.
        panel = plot_placeholder.container()
        panel.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
        panel.write(f"Loss: {self.losses[-1]:.4f}")
        panel.write(f"Accuracy: {self.accuracies[-1]:.4f}")
        fig, ax1 = plt.subplots()
        color = 'tab:red'
        ax1.set_xlabel('Round')
        ax1.set_ylabel('Loss', color=color)
        ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
        color = 'tab:blue'
        ax2.set_ylabel('Accuracy', color=color)
        ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        fig.tight_layout()
        panel.pyplot(fig)
        # pyplot keeps every figure alive until closed; free it after rendering.
        plt.close(fig)
def read_log_file(log_path='./log.txt'):
    """Return every line of the file at ``log_path`` as a list, line endings kept."""
    with open(log_path, 'r') as log_file:
        return list(log_file)
def parse_log(log_lines):
    """Split raw log lines into round markers, per-client messages and memory readings.

    Each line is classified by the first rule that matches, in this priority:
    a ``ROUND <n>`` marker, then a ``Client <id> | INFO|DEBUG | <msg>`` entry,
    then a ``memory used=<x>GB`` reading. Client entries are tagged with the
    most recent round seen (``None`` before the first marker).

    Returns:
        ``(rounds, clients, memory_usage)``: round numbers in order; a dict
        ``client_id -> {'rounds': [...], 'messages': [(level, msg), ...]}``;
        memory readings in GB as floats.
    """
    round_re = re.compile(r'ROUND (\d+)')
    client_re = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
    memory_re = re.compile(r'memory used=(\d+\.\d+)GB')

    seen_rounds = []
    per_client = {}
    memory_gb = []
    active_round = None

    for text in log_lines:
        hit = round_re.search(text)
        if hit:
            active_round = int(hit.group(1))
            seen_rounds.append(active_round)
            continue
        hit = client_re.search(text)
        if hit:
            cid, level, body = int(hit.group(1)), hit.group(2), hit.group(3)
            entry = per_client.setdefault(cid, {'rounds': [], 'messages': []})
            entry['rounds'].append(active_round)
            entry['messages'].append((level, body))
            continue
        hit = memory_re.search(text)
        if hit:
            memory_gb.append(float(hit.group(1)))

    return seen_rounds, per_client, memory_gb
def _log_metric_points(entries, key):
    """Extract ``(round, value)`` points for ``<key>=<number>`` from log messages.

    Args:
        entries: Iterable of ``(round, message)`` pairs.
        key: Metric name to look for, e.g. ``"loss"``.

    Returns:
        ``(round, float_value)`` tuples in input order. Two kinds of message are
        skipped instead of raising: messages without a numeric ``<key>=`` value
        (e.g. ``loss=nan``), and messages whose round is ``None`` (logged before
        any round marker, so they have no x position).
    """
    value_re = re.compile(re.escape(key) + r'=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')
    points = []
    for round_num, message in entries:
        match = value_re.search(message)
        if match is not None and round_num is not None:
            points.append((round_num, float(match.group(1))))
    return points


def _render_line_chart(xs, ys, label, xlabel, ylabel):
    """Draw one labelled line in a new figure, show it in Streamlit, then close it."""
    fig, ax = plt.subplots()
    ax.plot(xs, ys, label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure open until it is closed explicitly.
    plt.close(fig)


def plot_metrics(rounds, clients, memory_usage):
    """Render the log analysis produced by ``parse_log`` in Streamlit.

    Args:
        rounds: Round numbers seen in the log. Accepted for interface
            compatibility; not used here.
        clients: Mapping of client id -> ``{'rounds': [...], 'messages': [...]}``,
            where ``messages`` holds ``(level, text)`` tuples aligned
            index-by-index with ``rounds``.
        memory_usage: Memory readings in GB, plotted against their position.
    """
    st.write("## Metrics Overview")
    st.write("### Memory Usage")
    _render_line_chart(range(len(memory_usage)), memory_usage,
                       'Memory Usage (GB)', 'Step', 'Memory Usage (GB)')

    for client_id, data in clients.items():
        st.write(f"### Client {client_id} Metrics")
        info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
        # Keep each DEBUG message paired with the round it was logged in. The
        # plotted x (round) and y (value) series then always have equal length;
        # plotting against all of data['rounds'] made matplotlib raise.
        debug_entries = [
            (round_num, msg)
            for round_num, (level, msg) in zip(data['rounds'], data['messages'])
            if level == 'DEBUG'
        ]

        st.write("#### INFO Messages")
        for msg in info_messages:
            st.write(msg)
        st.write("#### DEBUG Messages")
        for _, msg in debug_entries:
            st.write(msg)

        # Loss/accuracy values are assumed to be logged as "loss=<x>" /
        # "accuracy=<x>" inside DEBUG messages.
        for key, label in (('loss', 'Loss'), ('accuracy', 'Accuracy')):
            points = _log_metric_points(debug_entries, key)
            if points:
                xs, ys = zip(*points)
                _render_line_chart(xs, ys, label, 'Round', label)
def read_log_file2(log_path="./log.txt"):
    """Return the entire contents of a log file as one string.

    Args:
        log_path: Path of the file to read. Defaults to ``./log.txt``, the path
            this function previously always read (and the default used by
            ``read_log_file``).

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    with open(log_path, "r") as file:
        return file.read()
def _num_labels(raw_datasets, default=2):
    """Return how many classes the training split's ``label`` column has.

    Args:
        raw_datasets: The dataset dict returned by ``load_data``, indexed by
            split name.
        default: Value returned when the ``label`` feature has no
            ``num_classes`` attribute (missing, or not a class-label feature).
            This matches the previously hard-coded value.

    Returns:
        The number of output classes for the sequence-classification head.
    """
    label_feature = raw_datasets["train"].features.get("label")
    return getattr(label_feature, "num_classes", default)


def main():
    """Streamlit entry point for the federated fine-tuning demo.

    The page does the following:
    - Shows the filtered contents of ``./log.txt`` with a download button.
    - Lets the user pick a dataset, model, client count and round count.
    - Previews each client's raw and tokenized data (the token tables are
      editable) and builds one model and ``CustomClient`` per client.
    - When "Start Training" is pressed, runs one single-round Flower simulation
      per requested round, then shows per-client metrics and a log analysis.
    """
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
    logs = read_log_file2()
    # Keep only log lines that mention the quantities this page reports on.
    pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
    filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
    # st.markdown expects a string: given the list, it rendered the Python repr.
    # Markdown would also merge consecutive lines; st.text shows them verbatim.
    st.text("\n".join(filtered_logs))
    # Provide a download button for the filtered logs.
    st.download_button(
        label="Download Logs",
        data="\n".join(filtered_logs),
        # Suggested download name, not a filesystem path.
        file_name="log.txt",
        mime="text/plain",
    )

    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    # NOTE(review): "facebook/hubert-base-ls960" is a speech (audio) checkpoint.
    # It presumably cannot consume the tokenized text batches built below.
    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])
    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS)
    # Size the classification head from the data. ag_news has 4 classes, so a
    # hard-coded 2 left label ids 2 and 3 out of range for the loss.
    num_labels = _num_labels(raw_datasets)

    clients = []
    for i in range(NUM_CLIENTS):
        st.write(f"### Client {i+1} Datasets")
        train_df = pd.DataFrame(train_datasets[i])
        test_df = pd.DataFrame(test_datasets[i])

        st.write("#### Train Dataset (Words)")
        st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
        st.write("#### Train Dataset (Tokens)")
        edited_train_df = st.data_editor(train_df, key=f"train_{i}")

        st.write("#### Test Dataset (Words)")
        st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
        st.write("#### Test Dataset (Tokens)")
        edited_test_df = st.data_editor(test_df, key=f"test_{i}")

        # Build the loaders from the (possibly user-edited) token tables.
        edited_train_dataset = Dataset.from_pandas(edited_train_df)
        edited_test_dataset = Dataset.from_pandas(edited_test_df)
        trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
        testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)

        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(DEVICE)
        clients.append(CustomClient(net, trainloader, testloader, client_id=i+1))

    if st.button("Start Training"):
        def client_fn(cid):
            # cid may arrive as a string; index into the prebuilt client list.
            return clients[int(cid)]

        def weighted_average(metrics):
            """Average client accuracy and loss, each weighted by example count."""
            accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
            losses = [num_examples * m["loss"] for num_examples, m in metrics]
            examples = [num_examples for num_examples, _ in metrics]
            return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

        strategy = fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=1.0,
            evaluate_metrics_aggregation_fn=weighted_average,
        )

        # Compiled once instead of on every round.
        # NOTE(review): these expect "'accuracy': {<round>, <value>}". Confirm
        # this matches how the Flower history is actually written to log.txt.
        accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
        loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")

        for round_num in range(NUM_ROUNDS):
            st.write(f"### Round {round_num + 1} ✅")
            logs = read_log_file2()
            filtered_logs = "\n".join(line for line in logs.splitlines() if pattern.search(line))
            st.text(filtered_logs)

            accuracy_matches = accuracy_pattern.findall(filtered_logs)
            loss_matches = loss_pattern.findall(filtered_logs)
            # Pair matches positionally so every column has the same length.
            # Unequal lengths made pd.DataFrame raise ValueError below.
            paired = list(zip(accuracy_matches, loss_matches))
            rounds = [int(acc[0]) for acc, _ in paired]
            accuracies = [float(acc[1]) for acc, _ in paired]
            losses = [float(loss[1]) for _, loss in paired]

            accuracy_fig = go.Figure()
            accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
            accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')

            loss_fig = go.Figure()
            loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
            loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')

            st.plotly_chart(accuracy_fig)
            st.plotly_chart(loss_fig)

            df = pd.DataFrame({'Round': rounds, 'Accuracy': accuracies, 'Loss': losses})
            st.write("## Training Metrics")
            st.table(df)

            plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
            # One single-round simulation per UI round. client_fn returns the
            # same prebuilt client objects each time, so their models carry over.
            fl.simulation.start_simulation(
                client_fn=client_fn,
                num_clients=NUM_CLIENTS,
                config=fl.server.ServerConfig(num_rounds=1),
                strategy=strategy,
                client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
                ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
            )
            for placeholder, client in zip(plot_placeholders, clients):
                client.plot_metrics(round_num + 1, placeholder)
            st.write(" ")

        st.success("Training completed successfully!")

        st.write("## Final Client Metrics")
        for client in clients:
            st.write(f"### Client {client.client_id}")
            if client.losses and client.accuracies:
                st.write(f"Final Loss: {client.losses[-1]:.4f}")
                st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
                client.plot_metrics(NUM_ROUNDS, st.empty())
            else:
                st.write("No metrics available.")
            st.write(" ")

        # Show the raw log verbatim rather than reflowed as Markdown.
        st.write("## Training Log")
        st.text(read_log_file2())

        st.write("## Training Log Analysis")
        log_lines = read_log_file()
        # Distinct names so the parsed log data does not shadow the client list.
        log_rounds, log_clients, memory_usage = parse_log(log_lines)
        plot_metrics(log_rounds, log_clients, memory_usage)
    else:
        st.write("Click the 'Start Training' button to start the training process.")
# Entry point: build and run the Streamlit page when this file is executed as a script.
if __name__ == "__main__":
    main()