Update app.py
Browse files
app.py
CHANGED
@@ -16,8 +16,13 @@ import zipfile
|
|
16 |
import math
|
17 |
from PIL import Image
|
18 |
import random
|
|
|
19 |
|
20 |
-
#
|
|
|
|
|
|
|
|
|
21 |
st.set_page_config(
|
22 |
page_title="SFT Tiny Titans 🚀",
|
23 |
page_icon="🤖",
|
@@ -42,7 +47,7 @@ class ModelConfig:
|
|
42 |
def model_path(self):
|
43 |
return f"models/{self.name}"
|
44 |
|
45 |
-
# Custom Dataset for SFT
|
46 |
class SFTDataset(Dataset):
|
47 |
def __init__(self, data, tokenizer, max_length=128):
|
48 |
self.data = data
|
@@ -56,7 +61,6 @@ class SFTDataset(Dataset):
|
|
56 |
prompt = self.data[idx]["prompt"]
|
57 |
response = self.data[idx]["response"]
|
58 |
|
59 |
-
# Tokenize the full sequence once
|
60 |
full_text = f"{prompt} {response}"
|
61 |
full_encoding = self.tokenizer(
|
62 |
full_text,
|
@@ -66,23 +70,21 @@ class SFTDataset(Dataset):
|
|
66 |
return_tensors="pt"
|
67 |
)
|
68 |
|
69 |
-
# Tokenize prompt separately to get its length
|
70 |
prompt_encoding = self.tokenizer(
|
71 |
prompt,
|
72 |
max_length=self.max_length,
|
73 |
-
padding=False,
|
74 |
truncation=True,
|
75 |
return_tensors="pt"
|
76 |
)
|
77 |
|
78 |
input_ids = full_encoding["input_ids"].squeeze()
|
79 |
attention_mask = full_encoding["attention_mask"].squeeze()
|
80 |
-
labels = input_ids.clone()
|
81 |
|
82 |
-
|
83 |
-
prompt_len = prompt_encoding["input_ids"].shape[1] # Actual length of prompt
|
84 |
if prompt_len < self.max_length:
|
85 |
-
labels[:prompt_len] = -100
|
86 |
|
87 |
return {
|
88 |
"input_ids": input_ids,
|
@@ -133,7 +135,6 @@ class ModelBuilder:
|
|
133 |
attention_mask = batch["attention_mask"].to(device)
|
134 |
labels = batch["labels"].to(device)
|
135 |
|
136 |
-
# Debug shapes
|
137 |
assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
|
138 |
|
139 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
@@ -152,12 +153,37 @@ class ModelBuilder:
|
|
152 |
self.tokenizer.save_pretrained(path)
|
153 |
st.success(f"Model saved at {path}! ✅ May the force be with it.")
|
154 |
|
155 |
-
def evaluate(self, prompt: str):
|
|
|
156 |
self.model.eval()
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
# Utility Functions with Wit
|
163 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
@@ -209,14 +235,14 @@ img_files = get_gallery_files(["png", "jpg", "jpeg"])
|
|
209 |
if img_files:
|
210 |
img_cols = st.sidebar.slider("Image Columns 📸", 1, 5, 3)
|
211 |
cols = st.sidebar.columns(img_cols)
|
212 |
-
for idx, img_file in enumerate(img_files[:img_cols * 2]):
|
213 |
with cols[idx % img_cols]:
|
214 |
st.image(Image.open(img_file), caption=f"{img_file} 🖼", use_column_width=True)
|
215 |
|
216 |
st.sidebar.subheader("CSV Gallery 📊")
|
217 |
csv_files = get_gallery_files(["csv"])
|
218 |
if csv_files:
|
219 |
-
for csv_file in csv_files[:5]:
|
220 |
st.sidebar.markdown(get_download_link(csv_file, "text/csv", f"{csv_file} 📊"), unsafe_allow_html=True)
|
221 |
|
222 |
st.sidebar.subheader("Model Management 🗂️")
|
@@ -302,19 +328,25 @@ with tab3:
|
|
302 |
else:
|
303 |
if st.session_state['builder'].sft_data:
|
304 |
st.write("Testing with SFT Data:")
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
|
|
|
|
313 |
|
314 |
test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
|
315 |
if st.button("Run Test ▶️"):
|
316 |
-
|
317 |
-
|
|
|
|
|
|
|
318 |
|
319 |
if st.button("Export Titan Files 📦"):
|
320 |
config = st.session_state['builder'].config
|
@@ -361,11 +393,9 @@ with tab4:
|
|
361 |
try:
|
362 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool
|
363 |
|
364 |
-
# Load a tiny model (default to SmolLM-135M for speed)
|
365 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
366 |
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
367 |
|
368 |
-
# Define Agentic RAG agent with a witty twist
|
369 |
agent = CodeAgent(
|
370 |
model=model,
|
371 |
tokenizer=tokenizer,
|
|
|
16 |
import math
|
17 |
from PIL import Image
|
18 |
import random
|
19 |
+
import logging
|
20 |
|
21 |
+
# Set up logging for feedback
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
# Page Configuration with Humor
|
26 |
st.set_page_config(
|
27 |
page_title="SFT Tiny Titans 🚀",
|
28 |
page_icon="🤖",
|
|
|
47 |
def model_path(self):
|
48 |
return f"models/{self.name}"
|
49 |
|
50 |
+
# Custom Dataset for SFT
|
51 |
class SFTDataset(Dataset):
|
52 |
def __init__(self, data, tokenizer, max_length=128):
|
53 |
self.data = data
|
|
|
61 |
prompt = self.data[idx]["prompt"]
|
62 |
response = self.data[idx]["response"]
|
63 |
|
|
|
64 |
full_text = f"{prompt} {response}"
|
65 |
full_encoding = self.tokenizer(
|
66 |
full_text,
|
|
|
70 |
return_tensors="pt"
|
71 |
)
|
72 |
|
|
|
73 |
prompt_encoding = self.tokenizer(
|
74 |
prompt,
|
75 |
max_length=self.max_length,
|
76 |
+
padding=False,
|
77 |
truncation=True,
|
78 |
return_tensors="pt"
|
79 |
)
|
80 |
|
81 |
input_ids = full_encoding["input_ids"].squeeze()
|
82 |
attention_mask = full_encoding["attention_mask"].squeeze()
|
83 |
+
labels = input_ids.clone()
|
84 |
|
85 |
+
prompt_len = prompt_encoding["input_ids"].shape[1]
|
|
|
86 |
if prompt_len < self.max_length:
|
87 |
+
labels[:prompt_len] = -100
|
88 |
|
89 |
return {
|
90 |
"input_ids": input_ids,
|
|
|
135 |
attention_mask = batch["attention_mask"].to(device)
|
136 |
labels = batch["labels"].to(device)
|
137 |
|
|
|
138 |
assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
|
139 |
|
140 |
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
|
|
153 |
self.tokenizer.save_pretrained(path)
|
154 |
st.success(f"Model saved at {path}! ✅ May the force be with it.")
|
155 |
|
156 |
+
def evaluate(self, prompt: str, status_container=None, *,
             max_input_length: int = 128, max_new_tokens: int = 50,
             top_p: float = 0.95, temperature: float = 0.7) -> str:
    """Generate a sampled completion for ``prompt`` and return the decoded text.

    The model is switched to eval mode, the prompt is tokenized (truncated to
    ``max_input_length`` tokens) and moved to the model's device, and
    ``self.model.generate`` is called with nucleus sampling. The first
    returned sequence is decoded with special tokens skipped.

    Args:
        prompt: Text to feed to the model.
        status_container: Optional progress sink exposing ``write()`` and
            ``error()``; callers in this file pass an ``st.empty()``
            placeholder. ``None`` disables UI feedback.
        max_input_length: Truncation limit passed to the tokenizer.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.

    Returns:
        The decoded text of ``outputs[0]``
        (NOTE(review): for causal LMs this presumably includes the prompt
        itself - confirm before comparing against expected responses).
        If tokenization, generation or decoding raises, the exception is
        logged and the string ``"Error: <message>"`` is returned instead of
        propagating, so the calling UI code never has to catch.
    """

    def report(message: str) -> None:
        # Identity check rather than truthiness: a placeholder object must
        # not be skipped just because it happens to evaluate as falsy.
        if status_container is not None:
            status_container.write(message)

    self.model.eval()
    report("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
    logger.info("Evaluating prompt: %s", prompt)

    try:
        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=max_input_length,
                truncation=True,
            ).to(self.model.device)
            report(f"Tokenized input shape: {inputs['input_ids'].shape} 📏")

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
            )
            report("Generation complete! Decoding response... 🗣")

            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Generated response: %s", result)
            return result
    except Exception as e:
        # Deliberate catch-all at the UI boundary: report the failure and
        # hand back an error string. logger.exception keeps the traceback
        # that a plain logger.error call would drop.
        logger.exception("Evaluation error: %s", e)
        if status_container is not None:
            status_container.error(f"Oops! Something broke: {e} 💥 (Titan tripped over a wire!)")
        return f"Error: {e}"
|
187 |
|
188 |
# Utility Functions with Wit
|
189 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
|
|
235 |
if img_files:
|
236 |
img_cols = st.sidebar.slider("Image Columns 📸", 1, 5, 3)
|
237 |
cols = st.sidebar.columns(img_cols)
|
238 |
+
for idx, img_file in enumerate(img_files[:img_cols * 2]):
|
239 |
with cols[idx % img_cols]:
|
240 |
st.image(Image.open(img_file), caption=f"{img_file} 🖼", use_column_width=True)
|
241 |
|
242 |
st.sidebar.subheader("CSV Gallery 📊")
|
243 |
csv_files = get_gallery_files(["csv"])
|
244 |
if csv_files:
|
245 |
+
for csv_file in csv_files[:5]:
|
246 |
st.sidebar.markdown(get_download_link(csv_file, "text/csv", f"{csv_file} 📊"), unsafe_allow_html=True)
|
247 |
|
248 |
st.sidebar.subheader("Model Management 🗂️")
|
|
|
328 |
else:
|
329 |
if st.session_state['builder'].sft_data:
|
330 |
st.write("Testing with SFT Data:")
|
331 |
+
with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its brain muscles!)"):
|
332 |
+
for item in st.session_state['builder'].sft_data[:3]:
|
333 |
+
prompt = item["prompt"]
|
334 |
+
expected = item["response"]
|
335 |
+
status_container = st.empty()
|
336 |
+
generated = st.session_state['builder'].evaluate(prompt, status_container)
|
337 |
+
st.write(f"**Prompt**: {prompt}")
|
338 |
+
st.write(f"**Expected**: {expected}")
|
339 |
+
st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
|
340 |
+
st.write("---")
|
341 |
+
status_container.empty() # Clear status after each test
|
342 |
|
343 |
test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
|
344 |
if st.button("Run Test ▶️"):
|
345 |
+
with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
|
346 |
+
status_container = st.empty()
|
347 |
+
result = st.session_state['builder'].evaluate(test_prompt, status_container)
|
348 |
+
st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
|
349 |
+
status_container.empty()
|
350 |
|
351 |
if st.button("Export Titan Files 📦"):
|
352 |
config = st.session_state['builder'].config
|
|
|
393 |
try:
|
394 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool
|
395 |
|
|
|
396 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
397 |
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
|
398 |
|
|
|
399 |
agent = CodeAgent(
|
400 |
model=model,
|
401 |
tokenizer=tokenizer,
|