Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
Requirements:
|
5 |
+
streamlit
|
6 |
+
torch
|
7 |
+
pandas
|
8 |
+
transformers
|
9 |
+
"""
|
10 |
+
|
11 |
+
import csv
import io
import os
import tempfile

import streamlit as st
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
|
17 |
+
|
18 |
+
# Page Configuration
# NOTE(review): Streamlit requires set_page_config to be the first st.* call
# in the script — keep it ahead of any other Streamlit statement.
st.set_page_config(
    page_title="SFT Model Builder π",
    page_icon="π€",
    layout="wide",
    initial_sidebar_state="expanded",
)
25 |
+
|
26 |
+
# Help Documentation as a Variable
# Rendered verbatim at the bottom of the page via st.markdown(HELP_DOC).
# NOTE(review): several emoji in this text appear mojibake-garbled ("π", "β",
# "β‘", ...) — likely an encoding artifact from a copy/paste; confirm the
# intended emoji before shipping. Text is kept byte-identical here.
HELP_DOC = """
# SFT Model Builder - Help Guide π

## Overview
This Streamlit app allows users to **download, fine-tune, and test Transformer models** with **Supervised Fine-Tuning (SFT)** using CSV data. It is designed for NLP tasks and can be expanded for **CV and Speech models**.

## Features
- β **Download a pre-trained model** from Hugging Face.
- β **Upload a CSV dataset** for fine-tuning.
- β **Train the model** with multiple epochs and adjustable batch sizes.
- β **Test the fine-tuned model** with text prompts.

## Installation
To run the app, install dependencies:
```bash
pip install -r requirements.txt
```
Then, start the app:
```bash
streamlit run app.py
```

## How to Use
1. **Download Model**: Select a base model (e.g., `distilgpt2`), then click **Download Model**.
2. **Upload CSV**: The CSV must have two columns: `prompt` and `response`.
3. **Fine-Tune Model**: Click **Fine-Tune Model** to start training.
4. **Test Model**: Enter a text prompt and generate responses.

## CSV Format
Example format:
```csv
prompt,response
"What is AI?","AI is artificial intelligence."
"Explain machine learning","Machine learning is a subset of AI."
```

## Model References
| Model π | Description π | Link π |
|---------|-------------|---------|
| **GPT-2** π€ | Standard NLP model | [Hugging Face](https://huggingface.co/gpt2) |
| **DistilGPT-2** β‘ | Lightweight version of GPT-2 | [Hugging Face](https://huggingface.co/distilgpt2) |
| **EleutherAI Pythia** π¬ | Open-source GPT-like models | [Hugging Face](https://huggingface.co/EleutherAI/pythia-70m) |

## Additional Notes
- This app supports **PyTorch models**.
- Default training parameters: `epochs=3`, `batch_size=4`.
- Fine-tuned models are **saved locally** for future use.

For more details, visit [Hugging Face Models](https://huggingface.co/models). π
"""
|
77 |
+
|
78 |
+
# Custom Dataset for Fine-Tuning
class SFTDataset(Dataset):
    """Torch dataset of prompt/response pairs for causal-LM supervised fine-tuning.

    Each item is the concatenation "prompt response", tokenized to a fixed
    length. Padding positions in ``labels`` are set to -100 so the model's
    built-in cross-entropy loss ignores them (the original copied the pad
    token ids straight into the labels, training the model on padding).
    """

    def __init__(self, data, tokenizer, max_length=128):
        # data: list of {"prompt": str, "response": str} dicts
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt = self.data[idx]["prompt"]
        response = self.data[idx]["response"]
        input_text = f"{prompt} {response}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # squeeze(0) drops only the batch dim; a bare squeeze() could also
        # collapse the sequence dim for max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Clone so masking does not mutate input_ids; -100 is the ignore
        # index used by Hugging Face causal-LM loss computation.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
104 |
+
|
105 |
+
# Model Loader and Trainer Class
class ModelBuilder:
    """Downloads a Hugging Face causal-LM checkpoint and runs SFT on CSV rows."""

    def __init__(self, model_name="distilgpt2"):
        self.model_name = model_name
        self.model = None      # populated by load_model()
        self.tokenizer = None  # populated by load_model()

    def load_model(self):
        """Download model and tokenizer from the Hugging Face hub."""
        # st.spinner is a context manager; the original bare call rendered
        # nothing while the (slow) download ran.
        with st.spinner("Loading model... β³"):
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # GPT-2-family tokenizers ship without a pad token, which the
            # dataset's padding="max_length" requires.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        st.success("Model loaded! β")

    @staticmethod
    def _read_rows(csv_source):
        """Parse `prompt,response` rows from a path or an open file-like object.

        Accepting file-like objects lets callers pass a Streamlit upload
        buffer directly; the original only accepted a filesystem path.
        """
        if hasattr(csv_source, "read"):
            raw = csv_source.read()
            if isinstance(raw, bytes):  # Streamlit uploads read as bytes
                raw = raw.decode("utf-8")
            handle = io.StringIO(raw)
            return [{"prompt": row["prompt"], "response": row["response"]}
                    for row in csv.DictReader(handle)]
        with open(csv_source, "r", newline="") as handle:
            return [{"prompt": row["prompt"], "response": row["response"]}
                    for row in csv.DictReader(handle)]

    def fine_tune(self, csv_path, epochs=3, batch_size=4):
        """Supervised Fine-Tuning with CSV data.

        Args:
            csv_path: path to a CSV file, or an open file-like object, with
                `prompt` and `response` columns.
            epochs: number of full passes over the dataset (default 3).
            batch_size: DataLoader batch size (default 4).
        """
        # Guard: fine-tuning before load_model() would crash on a None model.
        if self.model is None or self.tokenizer is None:
            st.error("Please download the model before fine-tuning.")
            return

        sft_data = self._read_rows(csv_path)
        if not sft_data:
            st.warning("The CSV contained no rows; nothing to train on.")
            return

        dataset = SFTDataset(sft_data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)

        device = self.model.device  # hoisted out of the batch loop
        self.model.train()
        for epoch in range(epochs):
            # `with` makes the spinner actually display during the epoch.
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}... βοΈ"):
                for batch in dataloader:
                    optimizer.zero_grad()
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    labels = batch["labels"].to(device)
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                    )
                    outputs.loss.backward()
                    optimizer.step()
            st.write(f"Epoch {epoch + 1} completed.")
        st.success("Fine-tuning completed! π")
146 |
+
|
147 |
+
# Main UI
st.title("SFT Model Builder π€π")

# Streamlit re-runs this whole script on every interaction, so a bare
# `model_builder = ModelBuilder()` would discard the downloaded model the
# moment any other widget is clicked. Keep the builder in session_state.
if "model_builder" not in st.session_state:
    st.session_state.model_builder = ModelBuilder()
model_builder = st.session_state.model_builder

if st.button("Download Model β¬οΈ"):
    model_builder.load_model()

csv_file = st.file_uploader("Upload CSV for Fine-Tuning", type="csv")
if csv_file and st.button("Fine-Tune Model π"):
    if model_builder.model is None:
        # Fine-tuning without a loaded model would crash on None.
        st.error("Please download the model first.")
    else:
        # file_uploader yields an in-memory buffer, not a path; persist it
        # to a temp file so fine_tune can open it by name.
        with tempfile.NamedTemporaryFile("wb", suffix=".csv",
                                         delete=False) as tmp:
            tmp.write(csv_file.getvalue())
            tmp_path = tmp.name
        try:
            model_builder.fine_tune(tmp_path)
        finally:
            os.remove(tmp_path)  # always clean up, even if training fails

# Render Help Documentation at End
st.markdown(HELP_DOC)
|