|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW |
|
from datasets import load_dataset, Dataset |
|
from evaluate import load as load_metric |
|
from torch.utils.data import DataLoader |
|
import pandas as pd |
|
import random |
|
from collections import OrderedDict |
|
import flwr as fl |
|
from logging import INFO, DEBUG |
|
from flwr.common.logger import log |
|
import logging |
|
import streamlit |
|
|
|
|
|
|
|
# Pick the compute device once at import time: CUDA when available, otherwise CPU.
# The training/evaluation code in this file refers to ``DEVICE``; ``device`` is
# kept as an alias so both spellings resolve to the same torch.device object.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = DEVICE

# Have Flower's logger write to ./log.txt; the log readers in this file open
# that same path.
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
|
|
|
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2):
    """Load a labelled text dataset, tokenize it, and draw per-client subsets.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        train_size: Number of training examples sampled for each client.
        test_size: Number of test examples sampled for each client.
        num_clients: How many (train, test) subset pairs to create.

    Returns:
        A tuple ``(train_datasets, test_datasets, data_collator, raw_datasets)``
        where the first two are lists with one tokenized subset per client,
        ``data_collator`` pads batches with the tokenizer, and ``raw_datasets``
        is the shuffled, untokenized dataset dict.

    Raises:
        ValueError: If the training split has neither a ``"text"`` nor a
            ``"content"`` column to tokenize.
    """
    raw_datasets = load_dataset(dataset_name)
    raw_datasets = raw_datasets.shuffle(seed=42)
    # Not every dataset has an "unsupervised" split; drop it only when present
    # instead of failing with KeyError.
    raw_datasets.pop("unsupervised", None)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Columns that exist before tokenization. Everything except the label is
    # raw input and must be removed afterwards so the collator only sees
    # tensors-to-be.
    source_columns = list(raw_datasets["train"].column_names)
    text_column = next(
        (name for name in ("text", "content") if name in source_columns), None
    )
    if text_column is None:
        raise ValueError(
            f"Dataset {dataset_name!r} has no 'text' or 'content' column; "
            f"available columns: {source_columns}"
        )

    def tokenize_function(examples):
        return tokenizer(examples[text_column], truncation=True)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(
        [name for name in source_columns if name != "label"]
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    train_split = tokenized_datasets["train"]
    test_split = tokenized_datasets["test"]
    # random.sample raises when asked for more items than exist, so clamp the
    # requested sizes to the split sizes.
    train_count = min(train_size, len(train_split))
    test_count = min(test_size, len(test_split))

    train_datasets = []
    test_datasets = []

    # Each client gets an independent random draw, so subsets of different
    # clients may overlap.
    for _ in range(num_clients):
        train_datasets.append(
            train_split.select(random.sample(range(len(train_split)), train_count))
        )
        test_datasets.append(
            test_split.select(random.sample(range(len(test_split)), test_count))
        )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    return train_datasets, test_datasets, data_collator, raw_datasets
|
|
|
def train(net, trainloader, epochs):
    """Train ``net`` in place for ``epochs`` passes over ``trainloader``.

    Args:
        net: Model whose forward call accepts a batch as keyword arguments and
            returns an object exposing a ``loss`` tensor.
        trainloader: Iterable of batches; each batch is a mapping from
            argument name to tensor.
        epochs: Number of full passes over ``trainloader``.

    Returns:
        None. The model's parameters are updated in place.
    """
    # Move batches to wherever the model already lives rather than relying on
    # a module-level device constant.
    target_device = next(net.parameters()).device
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are spelled out to keep that optimizer's defaults.
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0
    )
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {key: value.to(target_device) for key, value in batch.items()}
            loss = net(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
|
|
|
def test(net, testloader):
    """Evaluate ``net`` on ``testloader``.

    Args:
        net: Model whose forward call accepts a batch as keyword arguments and
            returns an object exposing ``loss`` and ``logits``.
        testloader: Sized iterable of batches; each batch is a mapping that
            includes a ``"labels"`` tensor.

    Returns:
        A tuple ``(loss, accuracy)``: the mean of the per-batch losses and the
        accuracy computed by the "accuracy" metric over all batches.

    Raises:
        ZeroDivisionError: If ``testloader`` contains no batches.
    """
    metric = load_metric("accuracy")
    # Move batches to wherever the model already lives rather than relying on
    # a module-level device constant.
    target_device = next(net.parameters()).device
    net.eval()
    total_loss = 0
    # Evaluation needs no gradients; one no_grad scope covers the whole loop.
    with torch.no_grad():
        for batch in testloader:
            batch = {key: value.to(target_device) for key, value in batch.items()}
            outputs = net(**batch)
            total_loss += outputs.loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])
    loss = total_loss / len(testloader)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy
|
|
|
class CustomClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one model and its train/test loaders.

    Besides the Flower callbacks it keeps a per-round history of the loss and
    accuracy measured after each ``fit`` call, which ``plot_metrics`` renders.
    """

    def __init__(self, net, trainloader, testloader, client_id):
        """Store the model, data loaders and id; start with empty histories."""
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id
        # One entry is appended per fit() call.
        self.losses = []
        self.accuracies = []

    def get_parameters(self, config):
        """Return the model's state-dict values as a list of NumPy arrays.

        ``config`` is accepted for the Flower interface and is not used.
        """
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in state-dict order) into the model."""
        keys = self.net.state_dict().keys()
        # torch.tensor keeps each array's dtype and handles 0-d arrays, unlike
        # the torch.Tensor constructor which always builds a float tensor.
        state_dict = OrderedDict(
            (key, torch.tensor(value)) for key, value in zip(keys, parameters)
        )
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train for one epoch from ``parameters`` and report the result.

        Returns:
            ``(updated_parameters, num_train_examples, metrics)`` where
            ``metrics`` holds the post-training ``loss`` and ``accuracy``
            measured on this client's test loader.
        """
        log(INFO, f"Client {self.client_id} is starting fit()")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        # Cast to plain floats, matching what evaluate() reports.
        metrics = {"loss": float(loss), "accuracy": float(accuracy)}
        return self.get_parameters(config={}), len(self.trainloader.dataset), metrics

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on this client's test loader.

        Returns:
            ``(loss, num_test_examples, metrics)`` with ``accuracy`` and
            ``loss`` in ``metrics``.
        """
        log(INFO, f"Client {self.client_id} is starting evaluate()")
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}

    def plot_metrics(self, round_num, plot_placeholder):
        """Render the latest metrics and the loss/accuracy history.

        Args:
            round_num: Round number shown in the heading.
            plot_placeholder: Streamlit element to draw into; callers in this
                file pass ``st.empty()``.

        Nothing is drawn until at least one ``fit`` call has recorded metrics.
        """
        if not (self.losses and self.accuracies):
            return

        # An st.empty() placeholder shows only the most recent element written
        # to it, so group the heading, text and chart in a container to keep
        # all of them visible.
        panel = plot_placeholder.container()
        panel.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
        panel.write(f"Loss: {self.losses[-1]:.4f}")
        panel.write(f"Accuracy: {self.accuracies[-1]:.4f}")

        fig, ax1 = plt.subplots()

        # Loss on the left axis.
        color = 'tab:red'
        ax1.set_xlabel('Round')
        ax1.set_ylabel('Loss', color=color)
        ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
        ax1.tick_params(axis='y', labelcolor=color)

        # Accuracy on a second y-axis sharing the same x-axis.
        ax2 = ax1.twinx()
        color = 'tab:blue'
        ax2.set_ylabel('Accuracy', color=color)
        ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
        ax2.tick_params(axis='y', labelcolor=color)

        fig.tight_layout()
        panel.pyplot(fig)
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
|
import matplotlib.pyplot as plt |
|
import re |
|
|
|
def read_log_file(log_path='./log.txt'):
    """Return the lines of the log file at ``log_path``.

    Args:
        log_path: Path of the log file to read.

    Returns:
        A list with one string per line, line endings included.

    Raises:
        FileNotFoundError: If ``log_path`` does not exist.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the locale,
    # and replace undecodable bytes instead of aborting on a stray byte.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as file:
        return file.readlines()
|
|
|
def parse_log(log_lines):
    """Extract round markers, per-client messages and memory readings.

    Each line is classified as at most one of, in this priority order:
    a round marker, a client message, or a memory reading.

    Args:
        log_lines: Iterable of log lines (strings).

    Returns:
        A tuple ``(rounds, clients, memory_usage)``:
        ``rounds`` is the list of round numbers in order of appearance;
        ``clients`` maps client id to ``{'rounds': [...], 'messages': [...]}``
        where ``rounds[i]`` is the round current when ``messages[i]`` (a
        ``(level, text)`` pair) was seen, or ``None`` before any round marker;
        ``memory_usage`` is the list of memory readings in GB as floats.
    """
    # Accepts "ROUND 3", "[ROUND 3]" and the compact "ROUND3ROUND 3" form; the
    # first number after the word ROUND is the round.
    round_pattern = re.compile(r'ROUND\s*(\d+)')
    # NOTE(review): this expects pre-formatted "Client <id> | LEVEL | text"
    # lines. The messages CustomClient logs ("Client <id> is starting fit()")
    # have no " | " after the id and will not match - confirm the intended
    # log format.
    client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
    # Integer readings such as "2GB" are accepted as well as "1.5GB".
    memory_pattern = re.compile(r'memory used=(\d+(?:\.\d+)?)GB')

    rounds = []
    clients = {}
    memory_usage = []
    current_round = None

    for line in log_lines:
        round_match = round_pattern.search(line)
        if round_match:
            current_round = int(round_match.group(1))
            rounds.append(current_round)
            continue

        client_match = client_pattern.search(line)
        if client_match:
            client_id = int(client_match.group(1))
            entry = clients.setdefault(client_id, {'rounds': [], 'messages': []})
            entry['rounds'].append(current_round)
            entry['messages'].append((client_match.group(2), client_match.group(3)))
            continue

        memory_match = memory_pattern.search(line)
        if memory_match:
            memory_usage.append(float(memory_match.group(1)))

    return rounds, clients, memory_usage
|
|
|
def _metric_series(rounds, messages, metric_name):
    """Collect ``(x_values, y_values)`` for ``metric_name=<number>`` in DEBUG messages.

    ``rounds`` and ``messages`` are parallel lists; each value is paired with
    the round of the message it came from so x and y always have equal length.
    """
    pattern = re.compile(re.escape(metric_name) + r'=([\d\.]+)')
    x_values = []
    y_values = []
    for round_num, (level, msg) in zip(rounds, messages):
        if level != 'DEBUG':
            continue
        found = pattern.search(msg)
        if found is None:
            continue
        try:
            value = float(found.group(1))
        except ValueError:
            # Matched text such as "1.2.3" is not a number; skip it.
            continue
        x_values.append(round_num)
        y_values.append(value)
    # Messages seen before any round marker carry round None, which cannot be
    # plotted; fall back to the 1-based position of each value.
    if any(x is None for x in x_values):
        x_values = list(range(1, len(y_values) + 1))
    return x_values, y_values


def _show_line_chart(x_values, y_values, xlabel, ylabel, label):
    """Draw one labelled line chart in Streamlit and release the figure."""
    fig, ax = plt.subplots()
    ax.plot(x_values, y_values, label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    st.pyplot(fig)
    # Close explicitly; pyplot keeps every open figure alive otherwise.
    plt.close(fig)


def plot_metrics(rounds, clients, memory_usage):
    """Render parsed log data in Streamlit.

    Args:
        rounds: Round numbers found in the log (not used for plotting here).
        clients: Mapping of client id to ``{'rounds': [...], 'messages': [...]}``
            with ``messages`` holding ``(level, text)`` pairs.
        memory_usage: Memory readings in GB, plotted against their index.

    Shows a memory chart, then for every client its INFO and DEBUG messages
    and, when DEBUG messages contain ``loss=`` / ``accuracy=`` values, a chart
    of each.
    """
    st.write("## Metrics Overview")

    st.write("### Memory Usage")
    _show_line_chart(
        range(len(memory_usage)), memory_usage,
        'Step', 'Memory Usage (GB)', 'Memory Usage (GB)',
    )

    for client_id, data in clients.items():
        st.write(f"### Client {client_id} Metrics")

        st.write("#### INFO Messages")
        for level, msg in data['messages']:
            if level == 'INFO':
                st.write(msg)

        st.write("#### DEBUG Messages")
        for level, msg in data['messages']:
            if level == 'DEBUG':
                st.write(msg)

        loss_rounds, losses = _metric_series(data['rounds'], data['messages'], 'loss')
        acc_rounds, accuracies = _metric_series(data['rounds'], data['messages'], 'accuracy')

        if losses:
            _show_line_chart(loss_rounds, losses, 'Round', 'Loss', 'Loss')

        if accuracies:
            _show_line_chart(acc_rounds, accuracies, 'Round', 'Accuracy', 'Accuracy')
|
|
|
|
|
def read_log_file2():
    """Return the entire contents of ``./log.txt`` as a single string.

    Raises:
        FileNotFoundError: If ``./log.txt`` does not exist.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the locale,
    # and replace undecodable bytes instead of aborting on a stray byte.
    with open("./log.txt", "r", encoding="utf-8", errors="replace") as file:
        return file.read()
|
|
|
def main():
    """Streamlit page: configure, run and inspect a federated training session.

    Lets the user pick a dataset, a model, and the number of clients and
    rounds; shows each client's editable train/test samples; and, once
    "Start Training" is pressed, runs one Flower simulation per round,
    plotting client metrics after each round and the parsed training log at
    the end.
    """
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased"])

    num_clients = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    num_rounds = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)

    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=num_clients)

    # Size the classification head from the dataset's label feature when it
    # exposes a class count; otherwise keep the previous fixed value of 2.
    try:
        num_labels = raw_datasets["train"].features["label"].num_classes
    except (KeyError, AttributeError):
        num_labels = 2

    trainloaders = []
    testloaders = []
    clients = []

    for i in range(num_clients):
        st.write(f"### Client {i+1} Datasets")

        train_df = pd.DataFrame(train_datasets[i])
        test_df = pd.DataFrame(test_datasets[i])

        # The "Words" tables show a fresh random sample of raw rows, not the
        # rows behind this client's token table.
        st.write("#### Train Dataset (Words)")
        st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
        st.write("#### Train Dataset (Tokens)")
        edited_train_df = st.data_editor(train_df, key=f"train_{i}")

        st.write("#### Test Dataset (Words)")
        st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
        st.write("#### Test Dataset (Tokens)")
        edited_test_df = st.data_editor(test_df, key=f"test_{i}")

        # Train on whatever the user left in the editors.
        edited_train_dataset = Dataset.from_pandas(edited_train_df)
        edited_test_dataset = Dataset.from_pandas(edited_test_df)

        trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
        testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)

        trainloaders.append(trainloader)
        testloaders.append(testloader)

        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
        client = CustomClient(net, trainloader, testloader, client_id=i+1)
        clients.append(client)

    if st.button("Start Training"):
        def client_fn(cid):
            # Flower passes the client id as a string index.
            return clients[int(cid)]

        def weighted_average(metrics):
            """Average accuracy and loss weighted by each client's example count."""
            total_examples = sum(num_examples for num_examples, _ in metrics)
            if total_examples == 0:
                # Nothing was evaluated; avoid dividing by zero.
                return {"accuracy": 0.0, "loss": 0.0}
            accuracy_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
            loss_sum = sum(num_examples * m["loss"] for num_examples, m in metrics)
            return {"accuracy": accuracy_sum / total_examples, "loss": loss_sum / total_examples}

        strategy = fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=1.0,
            evaluate_metrics_aggregation_fn=weighted_average,
        )

        gpu_count = 1 if torch.cuda.is_available() else 0

        # One single-round simulation per UI round so metrics can be drawn
        # between rounds.
        for round_num in range(num_rounds):
            st.write(f"### Round {round_num + 1}")
            st.markdown(read_log_file2())
            plot_placeholders = [st.empty() for _ in range(num_clients)]

            fl.simulation.start_simulation(
                client_fn=client_fn,
                num_clients=num_clients,
                config=fl.server.ServerConfig(num_rounds=1),
                strategy=strategy,
                client_resources={"num_cpus": 1, "num_gpus": gpu_count},
                ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": gpu_count}
            )

            for i, client in enumerate(clients):
                client.plot_metrics(round_num + 1, plot_placeholders[i])
            st.write(" ")

        st.success("Training completed successfully!")

        st.write("## Final Client Metrics")
        for client in clients:
            st.write(f"### Client {client.client_id}")
            if client.losses and client.accuracies:
                st.write(f"Final Loss: {client.losses[-1]:.4f}")
                st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}")
                client.plot_metrics(num_rounds, st.empty())
            else:
                st.write("No metrics available.")

            st.write(" ")

        st.write("## Training Log")

        st.write("## Training Log Analysis")

        log_lines = read_log_file()
        # Use a separate name for the parsed per-client log data so the list
        # of client objects captured by client_fn is not rebound.
        rounds, log_clients, memory_usage = parse_log(log_lines)

        plot_metrics(rounds, log_clients, memory_usage)

    else:
        st.write("Click the 'Start Training' button to start the training process.")
|
|
|
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|