File size: 4,616 Bytes
4da9684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# %%writefile app.py

import streamlit as st
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
from datasets import load_dataset
from evaluate import load as load_metric
from torch.utils.data import DataLoader
import random

# Every model and batch in this app is moved onto the CPU.
DEVICE: torch.device = torch.device("cpu")
# Default number of federated-learning rounds.
NUM_ROUNDS: int = 3

def load_data(dataset_name, model_name="bert-base-uncased", num_samples=20, batch_size=32):
    """Download *dataset_name*, subsample it and build tokenized DataLoaders.

    Up to ``num_samples`` random rows are drawn from each of the "train" and
    "test" splits, tokenized (with truncation) by *model_name*'s tokenizer, and
    wrapped in DataLoaders that pad each batch dynamically.

    Args:
        dataset_name: Hugging Face Hub dataset id, e.g. "imdb".
        model_name: Checkpoint whose tokenizer is used; defaults to the
            previously hard-coded "bert-base-uncased".
        num_samples: Maximum number of rows kept per split (fewer if the split
            is smaller).
        batch_size: Batch size of both loaders.

    Returns:
        ``(trainloader, testloader)``; only the train loader shuffles. Each batch
        holds the tokenizer outputs plus a ``labels`` column.
    """
    raw_datasets = load_dataset(dataset_name)
    # Only some datasets (e.g. imdb) ship an unlabeled "unsupervised" split;
    # drop it when present instead of raising KeyError on the others.
    raw_datasets.pop("unsupervised", None)
    raw_datasets = raw_datasets.shuffle(seed=42)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def prepare_split(split_name):
        """Subsample, tokenize and trim one split down to model-ready columns."""
        split = raw_datasets[split_name]
        # amazon_polarity keeps its review body in "content" rather than "text".
        text_column = "text" if "text" in split.column_names else "content"
        sample_size = min(num_samples, len(split))
        # Subsample *before* tokenizing: mapping the full split only to keep a
        # handful of rows wastes a lot of CPU time.
        subset = split.select(random.sample(range(len(split)), sample_size))

        def tokenize_function(examples):
            return tokenizer(examples[text_column], truncation=True)

        tokenized = subset.map(tokenize_function, batched=True)
        # Keep only the tokenizer outputs and the label column.
        tokenized = tokenized.remove_columns(
            [name for name in subset.column_names if name != "label"]
        )
        return tokenized.rename_column("label", "labels")

    # Train is prepared first so random.sample is consumed in the original order.
    train_dataset = prepare_split("train")
    test_dataset = prepare_split("test")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator
    )
    testloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=data_collator)
    return trainloader, testloader

def train(net, trainloader, epochs, lr=5e-5):
    """Fine-tune *net* in place on *trainloader* for *epochs* passes.

    Args:
        net: Model whose forward pass accepts a batch dict as keyword arguments
            and returns an object exposing ``.loss`` (as Hugging Face
            ``AutoModelForSequenceClassification`` models do).
        trainloader: Iterable of dicts mapping field names to tensors.
        epochs: Number of full passes over *trainloader*.
        lr: Learning rate; defaults to the previously hard-coded 5e-5.
    """
    # torch.optim.AdamW replaces transformers.AdamW, which is deprecated and has
    # been removed from recent transformers releases. eps and weight_decay are
    # pinned to the old transformers defaults so optimisation behaves as before.
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
            loss = net(**batch).loss
            # Clear gradients before backward so nothing stale from earlier use
            # of *net* leaks into the first update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def test(net, testloader):
    """Evaluate *net* on *testloader* without updating its weights.

    Args:
        net: Model returning ``.loss`` and ``.logits`` when called with a batch
            dict that includes ``labels``.
        testloader: DataLoader whose batches are dicts of tensors with a
            ``labels`` entry.

    Returns:
        ``(loss, accuracy)``: ``loss`` is the per-example average loss over the
        whole loader, and ``accuracy`` is the fraction of correct argmax predictions.

    Raises:
        ValueError: if *testloader* yields no examples.
    """
    metric = load_metric("accuracy")
    total_loss = 0.0
    num_examples = 0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            batch_size = batch["labels"].size(0)
            # Hugging Face classifiers return the batch-*mean* loss. Weight it by
            # batch size so dividing by the example count gives a true
            # per-example average, not one shrunk by roughly the batch size.
            total_loss += outputs.loss.item() * batch_size
            num_examples += batch_size
            predictions = torch.argmax(outputs.logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])
    if num_examples == 0:
        raise ValueError("testloader yielded no examples to evaluate")
    accuracy = metric.compute()["accuracy"]
    return total_loss / num_examples, accuracy

# Module-wide classifier used by main(). It is built once at import time from
# the bert-base-uncased checkpoint with a two-class head, then placed on DEVICE.
net = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
net = net.to(DEVICE)

def _plot_metric(values, ylabel, title, marker=None):
    """Plot *values* against round numbers 1..len(values) and render it in Streamlit."""
    fig, ax = plt.subplots()
    ax.plot(range(1, len(values) + 1), values, marker=marker)
    ax.set_xlabel("Round")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Pass the figure explicitly: st.pyplot() without a figure relies on
    # matplotlib's global figure, a usage Streamlit has deprecated.
    st.pyplot(fig)
    # Release the figure so repeated reruns don't accumulate open figures.
    plt.close(fig)


def main():
    """Render the Streamlit UI and run the simulated federated rounds.

    NOTE(review): no client is actually trained yet. Every client evaluates the
    same global ``net`` on the same test loader, so per-client metrics are
    identical. The "Model" selection is not wired into data loading or ``net``.
    ``net`` has a two-class head, so datasets with more classes (presumably
    ag_news) will likely fail at evaluation — confirm before relying on this UI.
    """
    st.write("## Federated Learning with dynamic models and datasets for mobile devices")
    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    # TODO(review): the chosen model is not used yet; data loading and ``net``
    # still rely on bert-base-uncased.
    st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
    num_clients = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    num_rounds = st.slider("Number of Rounds", min_value=1, max_value=10, value=NUM_ROUNDS)

    if not st.button("Start Training"):
        st.write("Click the 'Start Training' button to start the training process.")
        return

    # Streamlit reruns the whole script on every widget change. Loading only
    # after the button is pressed avoids re-downloading and re-tokenizing each time.
    _, testloader = load_data(dataset_name)

    round_accuracies = []
    round_losses = []
    for round_num in range(1, num_rounds + 1):
        st.write(f"## Round {round_num}")

        st.write("### Training Metrics for Each Client")
        client_losses = []
        client_accuracies = []
        for client in range(1, num_clients + 1):
            # Placeholder for real per-client metrics: evaluates the shared model.
            client_loss, client_accuracy = test(net, testloader)
            st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
            client_losses.append(client_loss)
            client_accuracies.append(client_accuracy)

        # One point per round: the mean over clients (num_clients >= 1 by the slider).
        round_accuracies.append(sum(client_accuracies) / len(client_accuracies))
        round_losses.append(sum(client_losses) / len(client_losses))

        st.write("### Accuracy Over Rounds")
        _plot_metric(round_accuracies, "Accuracy", "Accuracy Over Rounds", marker="o")

        st.write("### Loss Over Rounds")
        _plot_metric(round_losses, "Loss", "Loss Over Rounds")

        st.success(f"Round {round_num} completed successfully!")


if __name__ == "__main__":
    main()