awacke1 committed
Commit 5858718 · verified · 1 Parent(s): b4eeb2f

Update app.py

Files changed (1):
  1. app.py +59 -29
app.py CHANGED
@@ -16,8 +16,13 @@ import zipfile
 import math
 from PIL import Image
 import random
+import logging
 
-# Page Configuration with a Dash of Humor
+# Set up logging for feedback
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Page Configuration with Humor
 st.set_page_config(
     page_title="SFT Tiny Titans 🚀",
     page_icon="🤖",
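
With logging.basicConfig(level=logging.INFO) and no explicit format, the stdlib's default handler writes records to stderr as LEVEL:logger_name:message. A minimal sketch of what the added logger calls emit (the prompt text is illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("Evaluating prompt: What is AI?")
    # stderr: INFO:__main__:Evaluating prompt: What is AI?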
@@ -42,7 +47,7 @@ class ModelConfig:
     def model_path(self):
         return f"models/{self.name}"
 
-# Custom Dataset for SFT (Fixed)
+# Custom Dataset for SFT
 class SFTDataset(Dataset):
     def __init__(self, data, tokenizer, max_length=128):
         self.data = data
@@ -56,7 +61,6 @@ class SFTDataset(Dataset):
         prompt = self.data[idx]["prompt"]
         response = self.data[idx]["response"]
 
-        # Tokenize the full sequence once
         full_text = f"{prompt} {response}"
         full_encoding = self.tokenizer(
             full_text,
@@ -66,23 +70,21 @@ class SFTDataset(Dataset):
             return_tensors="pt"
         )
 
-        # Tokenize prompt separately to get its length
         prompt_encoding = self.tokenizer(
             prompt,
             max_length=self.max_length,
-            padding=False, # No padding here, just to get length
+            padding=False,
             truncation=True,
             return_tensors="pt"
         )
 
         input_ids = full_encoding["input_ids"].squeeze()
         attention_mask = full_encoding["attention_mask"].squeeze()
-        labels = input_ids.clone() # Clone to avoid modifying input_ids
+        labels = input_ids.clone()
 
-        # Mask prompt tokens in labels
-        prompt_len = prompt_encoding["input_ids"].shape[1] # Actual length of prompt
+        prompt_len = prompt_encoding["input_ids"].shape[1]
         if prompt_len < self.max_length:
-            labels[:prompt_len] = -100 # Ignore prompt in loss
+            labels[:prompt_len] = -100
 
         return {
             "input_ids": input_ids,
@@ -133,7 +135,6 @@ class ModelBuilder:
                 attention_mask = batch["attention_mask"].to(device)
                 labels = batch["labels"].to(device)
 
-                # Debug shapes
                 assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
 
                 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
@@ -152,12 +153,37 @@ class ModelBuilder:
         self.tokenizer.save_pretrained(path)
         st.success(f"Model saved at {path}! ✅ May the force be with it.")
 
-    def evaluate(self, prompt: str):
+    def evaluate(self, prompt: str, status_container=None):
+        """Evaluate with feedback"""
         self.model.eval()
-        with torch.no_grad():
-            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
-            outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
-            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if status_container:
+            status_container.write("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
+        logger.info(f"Evaluating prompt: {prompt}")
+
+        try:
+            with torch.no_grad():
+                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
+                if status_container:
+                    status_container.write(f"Tokenized input shape: {inputs['input_ids'].shape} 📏")
+
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=50,
+                    do_sample=True,
+                    top_p=0.95,
+                    temperature=0.7
+                )
+                if status_container:
+                    status_container.write("Generation complete! Decoding response... 🗣")
+
+                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                logger.info(f"Generated response: {result}")
+                return result
+        except Exception as e:
+            logger.error(f"Evaluation error: {str(e)}")
+            if status_container:
+                status_container.error(f"Oops! Something broke: {str(e)} 💥 (Titan tripped over a wire!)")
+            return f"Error: {str(e)}"
 
 # Utility Functions with Wit
 def get_download_link(file_path, mime_type="text/plain", label="Download"):
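
The new status_container parameter is duck-typed: evaluate only calls write() and error() on it, so it accepts an st.empty() placeholder in the app or any shim elsewhere. A sketch under that assumption (ConsoleStatus is hypothetical, not part of app.py):

    class ConsoleStatus:
        """Stand-in for st.empty(): prints instead of rendering."""
        def write(self, msg): print(msg)
        def error(self, msg): print("ERROR:", msg)
        def empty(self): pass

    # builder = ...  # a trained ModelBuilder, as constructed elsewhere in app.py
    # print(builder.evaluate("What is AI?", ConsoleStatus()))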
@@ -209,14 +235,14 @@ img_files = get_gallery_files(["png", "jpg", "jpeg"])
 if img_files:
     img_cols = st.sidebar.slider("Image Columns 📸", 1, 5, 3)
     cols = st.sidebar.columns(img_cols)
-    for idx, img_file in enumerate(img_files[:img_cols * 2]): # Limit to 2 rows
+    for idx, img_file in enumerate(img_files[:img_cols * 2]):
         with cols[idx % img_cols]:
             st.image(Image.open(img_file), caption=f"{img_file} 🖼", use_column_width=True)
 
 st.sidebar.subheader("CSV Gallery 📊")
 csv_files = get_gallery_files(["csv"])
 if csv_files:
-    for csv_file in csv_files[:5]: # Limit to 5
+    for csv_file in csv_files[:5]:
         st.sidebar.markdown(get_download_link(csv_file, "text/csv", f"{csv_file} 📊"), unsafe_allow_html=True)
 
 st.sidebar.subheader("Model Management 🗂️")
@@ -302,19 +328,25 @@ with tab3:
     else:
         if st.session_state['builder'].sft_data:
             st.write("Testing with SFT Data:")
-            for item in st.session_state['builder'].sft_data[:3]:
-                prompt = item["prompt"]
-                expected = item["response"]
-                generated = st.session_state['builder'].evaluate(prompt)
-                st.write(f"**Prompt**: {prompt}")
-                st.write(f"**Expected**: {expected}")
-                st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
-                st.write("---")
+            with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its brain muscles!)"):
+                for item in st.session_state['builder'].sft_data[:3]:
+                    prompt = item["prompt"]
+                    expected = item["response"]
+                    status_container = st.empty()
+                    generated = st.session_state['builder'].evaluate(prompt, status_container)
+                    st.write(f"**Prompt**: {prompt}")
+                    st.write(f"**Expected**: {expected}")
+                    st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
+                    st.write("---")
+                    status_container.empty() # Clear status after each test
 
     test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
     if st.button("Run Test ▶️"):
-        result = st.session_state['builder'].evaluate(test_prompt)
-        st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
+        with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
+            status_container = st.empty()
+            result = st.session_state['builder'].evaluate(test_prompt, status_container)
+            st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
+            status_container.empty()
 
     if st.button("Export Titan Files 📦"):
         config = st.session_state['builder'].config
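
Both branches above follow the same Streamlit pattern: st.spinner shows a transient busy indicator around the block while an st.empty() placeholder streams per-step status, then placeholder.empty() clears it. A minimal standalone sketch of that pattern:

    import time
    import streamlit as st

    with st.spinner("Working... ⏳"):
        status = st.empty()              # reusable placeholder slot
        for step in ("tokenizing", "generating", "decoding"):
            status.write(f"Step: {step}")
            time.sleep(0.5)
        status.empty()                   # clear the transient status line
    st.write("Done!")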
@@ -361,11 +393,9 @@ with tab4:
    try:
        from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool

-       # Load a tiny model (default to SmolLM-135M for speed)
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
        model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

-       # Define Agentic RAG agent with a witty twist
        agent = CodeAgent(
            model=model,
            tokenizer=tokenizer,
 