awacke1 commited on
Commit
f243316
·
verified ·
1 Parent(s): bcd8d5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +351 -338
app.py CHANGED
@@ -17,466 +17,479 @@ import math
17
  from PIL import Image
18
  import random
19
  import logging
20
- from datetime import datetime
21
- import pytz
22
- from diffusers import StableDiffusionPipeline
23
- from urllib.parse import quote
24
- import cv2
25
 
26
- # Logging setup
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
30
- # Page Configuration
31
- st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Model Configurations
34
  @dataclass
35
  class ModelConfig:
36
  name: str
37
  base_model: str
38
  size: str
39
  domain: Optional[str] = None
40
- model_type: str = "causal_lm"
41
  @property
42
  def model_path(self):
43
  return f"models/{self.name}"
44
 
45
- @dataclass
46
- class DiffusionConfig:
47
- name: str
48
- base_model: str
49
- size: str
50
- @property
51
- def model_path(self):
52
- return f"diffusion_models/{self.name}"
53
-
54
- # Datasets
55
  class SFTDataset(Dataset):
56
  def __init__(self, data, tokenizer, max_length=128):
57
  self.data = data
58
  self.tokenizer = tokenizer
59
  self.max_length = max_length
 
60
  def __len__(self):
61
  return len(self.data)
 
62
  def __getitem__(self, idx):
63
  prompt = self.data[idx]["prompt"]
64
  response = self.data[idx]["response"]
 
65
  full_text = f"{prompt} {response}"
66
- full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
67
- prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  input_ids = full_encoding["input_ids"].squeeze()
69
  attention_mask = full_encoding["attention_mask"].squeeze()
70
  labels = input_ids.clone()
 
71
  prompt_len = prompt_encoding["input_ids"].shape[1]
72
  if prompt_len < self.max_length:
73
  labels[:prompt_len] = -100
74
- return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
75
-
76
- class DiffusionDataset(Dataset):
77
- def __init__(self, images, texts):
78
- self.images = images
79
- self.texts = texts
80
- def __len__(self):
81
- return len(self.images)
82
- def __getitem__(self, idx):
83
- return {"image": self.images[idx], "text": self.texts[idx]}
84
 
85
- # Model Builders
86
  class ModelBuilder:
87
  def __init__(self):
88
  self.config = None
89
  self.model = None
90
  self.tokenizer = None
91
  self.sft_data = None
 
 
92
  def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
93
- self.model = AutoModelForCausalLM.from_pretrained(model_path)
94
- self.tokenizer = AutoTokenizer.from_pretrained(model_path)
95
- if self.tokenizer.pad_token is None:
96
- self.tokenizer.pad_token = self.tokenizer.eos_token
97
- if config:
98
- self.config = config
 
 
99
  return self
 
100
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
101
  self.sft_data = []
102
  with open(csv_path, "r") as f:
103
  reader = csv.DictReader(f)
104
  for row in reader:
105
  self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
 
106
  dataset = SFTDataset(self.sft_data, self.tokenizer)
107
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
108
  optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
 
109
  self.model.train()
110
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
  self.model.to(device)
112
  for epoch in range(epochs):
113
- total_loss = 0
114
- for batch in dataloader:
115
- optimizer.zero_grad()
116
- input_ids = batch["input_ids"].to(device)
117
- attention_mask = batch["attention_mask"].to(device)
118
- labels = batch["labels"].to(device)
119
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
120
- loss = outputs.loss
121
- loss.backward()
122
- optimizer.step()
123
- total_loss += loss.item()
124
- st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
 
 
 
 
 
125
  return self
126
- def save_model(self, path: str):
127
- os.makedirs(os.path.dirname(path), exist_ok=True)
128
- self.model.save_pretrained(path)
129
- self.tokenizer.save_pretrained(path)
130
- def evaluate(self, prompt: str):
131
- self.model.eval()
132
- with torch.no_grad():
133
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
134
- outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
135
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
136
 
137
- class DiffusionBuilder:
138
- def __init__(self):
139
- self.config = None
140
- self.pipeline = None
141
- def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
142
- self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
143
- self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
144
- if config:
145
- self.config = config
146
- return self
147
- def fine_tune_sft(self, images, texts, epochs=3):
148
- dataset = DiffusionDataset(images, texts)
149
- dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
150
- optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
151
- self.pipeline.unet.train()
152
- for epoch in range(epochs):
153
- total_loss = 0
154
- for batch in dataloader:
155
- optimizer.zero_grad()
156
- image = batch["image"].to(self.pipeline.device)
157
- text = batch["text"]
158
- latents = self.pipeline.vae.encode(image).latent_dist.sample()
159
- noise = torch.randn_like(latents)
160
- timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
161
- noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
162
- text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
163
- pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
164
- loss = torch.nn.functional.mse_loss(pred_noise, noise)
165
- loss.backward()
166
- optimizer.step()
167
- total_loss += loss.item()
168
- st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
169
- return self
170
  def save_model(self, path: str):
171
- os.makedirs(os.path.dirname(path), exist_ok=True)
172
- self.pipeline.save_pretrained(path)
173
- def generate(self, prompt: str):
174
- return self.pipeline(prompt, num_inference_steps=50).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- # Utilities
177
  def get_download_link(file_path, mime_type="text/plain", label="Download"):
178
  with open(file_path, 'rb') as f:
179
  data = f.read()
180
  b64 = base64.b64encode(data).decode()
181
- return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'
182
 
183
  def zip_directory(directory_path, zip_path):
184
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
185
  for root, _, files in os.walk(directory_path):
186
  for file in files:
187
- zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.dirname(directory_path)))
 
 
188
 
189
- def get_model_files(model_type="causal_lm"):
190
- path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
191
- return [d for d in glob.glob(path) if os.path.isdir(d)]
192
 
193
  def get_gallery_files(file_types):
194
- return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
195
-
196
- def generate_filename(text_line):
197
- central = pytz.timezone('US/Central')
198
- timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
199
- safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
200
- return f"{timestamp}_{safe_text}.png"
201
-
202
- def display_search_links(query):
203
- search_urls = {
204
- "ArXiv": f"https://arxiv.org/search/?query={quote(query)}",
205
- "Wikipedia": f"https://en.wikipedia.org/wiki/{quote(query)}",
206
- "Google": f"https://www.google.com/search?q={quote(query)}",
207
- "YouTube": f"https://www.youtube.com/results?search_query={quote(query)}"
208
- }
209
- return ' '.join([f"[{name}]({url})" for name, url in search_urls.items()])
210
-
211
- def detect_cameras():
212
- cameras = []
213
- for i in range(2): # Check first two indices
214
- cap = cv2.VideoCapture(i)
215
- if cap.isOpened():
216
- cameras.append(i)
217
- cap.release()
218
- return cameras
219
-
220
- # Agent Classes
221
- class NLPAgent:
 
 
 
 
 
 
 
222
  def __init__(self, model, tokenizer):
223
  self.model = model
224
  self.tokenizer = tokenizer
225
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226
  self.model.to(self.device)
 
227
  def generate(self, prompt: str) -> str:
228
  self.model.eval()
229
  with torch.no_grad():
230
  inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
231
- outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
 
 
 
 
 
 
232
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
233
  def plan_party(self, task: str) -> pd.DataFrame:
234
- search_result = "Latest trends for 2025: Gold-plated Batman statues, VR superhero battles."
235
- prompt = f"Given this context: '{search_result}'\n{task}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  plan_text = self.generate(prompt)
237
- st.markdown(f"Search Links: {display_search_links('superhero party trends')}", unsafe_allow_html=True)
238
- locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060)}
239
- travel_times = {loc: calculate_cargo_travel_time(coords, locations["Wayne Manor"]) for loc, coords in locations.items() if loc != "Wayne Manor"}
240
- data = [
241
- {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Idea": "Gold-plated Batman statues"},
242
- {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Idea": "VR superhero battles"}
 
243
  ]
244
- return pd.DataFrame(data)
245
-
246
- class CVAgent:
247
- def __init__(self, pipeline):
248
- self.pipeline = pipeline
249
- def generate(self, prompt: str) -> Image.Image:
250
- return self.pipeline(prompt, num_inference_steps=50).images[0]
251
- def enhance_images(self, task: str) -> pd.DataFrame:
252
- search_result = "Latest superhero art trends: Neon outlines, 3D holograms."
253
- prompt = f"Given this context: '{search_result}'\n{task}"
254
- st.markdown(f"Search Links: {display_search_links('superhero art trends')}", unsafe_allow_html=True)
255
  data = [
256
- {"Image Theme": "Batman", "Enhancement": "Neon outlines"},
257
- {"Image Theme": "Iron Man", "Enhancement": "3D holograms"}
 
 
 
 
258
  ]
 
259
  return pd.DataFrame(data)
260
 
261
- def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
262
- def to_radians(degrees: float) -> float:
263
- return degrees * (math.pi / 180)
264
- lat1, lon1 = map(to_radians, origin_coords)
265
- lat2, lon2 = map(to_radians, destination_coords)
266
- EARTH_RADIUS_KM = 6371.0
267
- dlon = lon2 - lon1
268
- dlat = lat2 - lat1
269
- a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
270
- c = 2 * math.asin(math.sqrt(a))
271
- distance = EARTH_RADIUS_KM * c
272
- actual_distance = distance * 1.1
273
- flight_time = (actual_distance / cruising_speed_kmh) + 1.0
274
- return round(flight_time, 2)
275
-
276
  # Main App
277
  st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
278
 
279
- # Sidebar Galleries
280
- st.sidebar.header("Shared Galleries 🎨")
281
- for gallery_type, file_types, emoji in [
282
- ("Images 📸", ["png", "jpg", "jpeg"], "🖼️"),
283
- ("Videos 🎥", ["mp4"], "🎬"),
284
- ("Audio 🎶", ["mp3"], "🎵")
285
- ]:
286
- st.sidebar.subheader(f"{gallery_type} {emoji}")
287
- files = get_gallery_files(file_types)
288
- if files:
289
- cols_num = st.sidebar.slider(f"{gallery_type} Columns", 1, 5, 3, key=f"{gallery_type}_cols")
290
- cols = st.sidebar.columns(cols_num)
291
- for idx, file in enumerate(files[:cols_num * 2]):
292
- with cols[idx % cols_num]:
293
- if "Images" in gallery_type:
294
- st.image(Image.open(file), caption=file, use_column_width=True)
295
- elif "Videos" in gallery_type:
296
- st.video(file)
297
- elif "Audio" in gallery_type:
298
- st.audio(file)
299
 
300
  st.sidebar.subheader("Model Management 🗂️")
301
- model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
302
- model_dirs = get_model_files("causal_lm" if "NLP" in model_type else "diffusion")
303
  selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
304
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
305
- builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
306
- config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
307
- builder.load_model(selected_model, config)
308
- st.session_state['builder'] = builder
309
  st.session_state['model_loaded'] = True
310
  st.rerun()
311
 
312
- # Tabs
313
- tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
314
- "Build Titan 🌱",
315
- "Fine-Tune NLP 🧠",
316
- "Fine-Tune CV 🎨",
317
- "Test Titans 🧪",
318
- "Agentic RAG 🌀",
319
- "Camera Inputs 📷"
320
- ])
321
 
322
  with tab1:
323
- st.header("Build Your Titan 🌱")
324
- model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
325
  base_model = st.selectbox(
326
  "Select Tiny Model",
327
- ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["stabilityai/stable-diffusion-2-1", "CompVis/stable-diffusion-v1-4"]
 
328
  )
329
  model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
 
 
330
  if st.button("Download Model ⬇️"):
331
- config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=model_name, base_model=base_model, size="small")
332
- builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
333
  builder.load_model(base_model, config)
334
  builder.save_model(config.model_path)
335
  st.session_state['builder'] = builder
336
  st.session_state['model_loaded'] = True
 
337
  st.rerun()
338
 
339
  with tab2:
340
- st.header("Fine-Tune NLP Titan 🧠 (Word Wizardry!)")
341
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
342
- st.warning("Load an NLP Titan first! ⚠️")
343
  else:
344
- uploaded_csv = st.file_uploader("Upload CSV for NLP SFT", type="csv", key="nlp_csv")
345
- if uploaded_csv and st.button("Tune the Wordsmith 🔧"):
346
- csv_path = f"nlp_sft_data_{int(time.time())}.csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  with open(csv_path, "wb") as f:
348
  f.write(uploaded_csv.read())
349
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
350
- new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
 
 
 
 
 
351
  st.session_state['builder'].config = new_config
352
- st.session_state['builder'].fine_tune_sft(csv_path)
353
- st.session_state['builder'].save_model(new_config.model_path)
 
 
 
354
  zip_path = f"{new_config.model_path}.zip"
355
  zip_directory(new_config.model_path, zip_path)
356
- st.markdown(get_download_link(zip_path, "application/zip", "Download Tuned NLP Titan"), unsafe_allow_html=True)
 
357
 
358
  with tab3:
359
- st.header("Fine-Tune CV Titan 🎨 (Vision Vibes!)")
360
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
361
- st.warning("Load a CV Titan first! ⚠️")
362
  else:
363
- uploaded_files = st.file_uploader("Upload Images/Videos", type=["png", "jpg", "jpeg", "mp4", "mp3"], accept_multiple_files=True, key="cv_upload")
364
- text_input = st.text_area("Enter Text (one line per image)", "Batman Neon\nIron Man Hologram\nThor Lightning", key="cv_text")
365
- if uploaded_files and st.button("Tune the Visionary 🖌️"):
366
- images = [Image.open(f) for f in uploaded_files if f.type.startswith("image")]
367
- texts = text_input.splitlines()
368
- if len(images) > len(texts):
369
- texts.extend([""] * (len(images) - len(texts)))
370
- elif len(texts) > len(images):
371
- texts = texts[:len(images)]
372
- st.session_state['builder'].fine_tune_sft(images, texts)
373
- new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
374
- new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
375
- st.session_state['builder'].config = new_config
376
- st.session_state['builder'].save_model(new_config.model_path)
377
- for img, text in zip(images, texts):
378
- filename = generate_filename(text)
379
- img.save(filename)
380
- st.image(img, caption=filename)
381
- zip_path = f"{new_config.model_path}.zip"
382
- zip_directory(new_config.model_path, zip_path)
383
- st.markdown(get_download_link(zip_path, "application/zip", "Download Tuned CV Titan"), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
  with tab4:
386
- st.header("Test Titans 🧪 (Brains & Eyes!)")
387
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
388
- st.warning("Load a Titan first! ⚠️")
389
- else:
390
- if isinstance(st.session_state['builder'], ModelBuilder):
391
- st.subheader("NLP Test 🧠")
392
- test_prompt = st.text_area("Enter NLP Prompt", "Plan a superhero party!", key="nlp_test")
393
- if st.button("Test NLP Titan ▶️"):
394
- result = st.session_state['builder'].evaluate(test_prompt)
395
- st.write(f"**Response**: {result}")
396
- elif isinstance(st.session_state['builder'], DiffusionBuilder):
397
- st.subheader("CV Test 🎨")
398
- test_prompt = st.text_area("Enter CV Prompt", "Superhero in neon style", key="cv_test")
399
- if st.button("Test CV Titan ▶️"):
400
- image = st.session_state['builder'].generate(test_prompt)
401
- st.image(image, caption="Generated Image")
402
-
403
- cameras = detect_cameras()
404
- if cameras:
405
- st.subheader("Camera Snapshot Test 📷")
406
- camera_idx = st.selectbox("Select Camera", cameras, key="camera_select")
407
- snapshot_text = st.text_input("Snapshot Text", "Camera Snap", key="snap_text")
408
- if st.button("Capture Snapshot 📸"):
409
- cap = cv2.VideoCapture(camera_idx)
410
- ret, frame = cap.read()
411
- if ret:
412
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
413
- img = Image.fromarray(rgb_frame)
414
- filename = generate_filename(snapshot_text)
415
- img.save(filename)
416
- st.image(img, caption=filename)
417
- cap.release()
418
-
419
- with tab5:
420
- st.header("Agentic RAG 🌀 (Smart Plans & Visions!)")
421
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
422
- st.warning("Load a Titan first! ⚠️")
423
  else:
424
- if isinstance(st.session_state['builder'], ModelBuilder):
425
- st.subheader("NLP RAG Party 🧠")
426
- if st.button("Run NLP RAG Demo 🎉"):
427
- agent = NLPAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
428
- task = "Plan a luxury superhero-themed party at Wayne Manor."
429
- plan_df = agent.plan_party(task)
430
- st.dataframe(plan_df)
431
- elif isinstance(st.session_state['builder'], DiffusionBuilder):
432
- st.subheader("CV RAG Enhance 🎨")
433
- if st.button("Run CV RAG Demo 🖌️"):
434
- agent = CVAgent(st.session_state['builder'].pipeline)
435
- task = "Enhance superhero images with 2025 trends."
436
- enhance_df = agent.enhance_images(task)
437
- st.dataframe(enhance_df)
438
-
439
- with tab6:
440
- st.header("Camera Inputs 📷 (Live Feed Fun!)")
441
- cameras = detect_cameras()
442
- if not cameras:
443
- st.warning("No cameras detected! ⚠️")
444
- else:
445
- st.write(f"Detected {len(cameras)} cameras!")
446
- for idx in cameras:
447
- st.subheader(f"Camera {idx}")
448
- cap = cv2.VideoCapture(idx)
449
- if st.button(f"Capture from Camera {idx} 📸", key=f"cap_{idx}"):
450
- ret, frame = cap.read()
451
- if ret:
452
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
453
- img = Image.fromarray(rgb_frame)
454
- filename = generate_filename(f"Camera_{idx}_snap")
455
- img.save(filename)
456
- st.image(img, caption=filename)
457
- cap.release()
458
-
459
- # Preload demo files
460
- demo_images = ["20250319_010000_AM_Batman.png", "20250319_010001_AM_IronMan.png", "20250319_010002_AM_Thor.png"]
461
- demo_videos = ["20250319_010000_AM_Batman.mp4", "20250319_010001_AM_IronMan.mp4", "20250319_010002_AM_Thor.mp4"]
462
- for img in demo_images:
463
- if not os.path.exists(img):
464
- Image.new("RGB", (100, 100)).save(img)
465
- for vid in demo_videos:
466
- if not os.path.exists(vid):
467
- with open(vid, "wb") as f:
468
- f.write(b"") # Dummy file
469
-
470
- # Demo SFT Dataset
471
- st.subheader("Diffusion SFT Demo Dataset 🎨")
472
- demo_texts = ["Batman Neon", "Iron Man Hologram", "Thor Lightning"]
473
- demo_code = "\n".join([f"{i+1}. {text} -> {demo_images[i]}" for i, text in enumerate(demo_texts)])
474
- st.code(demo_code, language="text")
475
- if st.button("Download Demo CSV 📝"):
476
- csv_path = f"demo_diffusion_sft_{int(time.time())}.csv"
477
- with open(csv_path, "w", newline="") as f:
478
- writer = csv.writer(f)
479
- writer.writerow(["image", "text"])
480
- for img, text in zip(demo_images, demo_texts):
481
- writer.writerow([img, text])
482
- st.markdown(get_download_link(csv_path, "text/csv", "Download Demo CSV"), unsafe_allow_html=True)
 
17
  from PIL import Image
18
  import random
19
  import logging
 
 
 
 
 
20
 
21
+ # Set up logging for feedback
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
+ # Page Configuration with Humor
26
+ st.set_page_config(
27
+ page_title="SFT Tiny Titans 🚀",
28
+ page_icon="🤖",
29
+ layout="wide",
30
+ initial_sidebar_state="expanded",
31
+ menu_items={
32
+ 'Get Help': 'https://huggingface.co/awacke1',
33
+ 'Report a bug': 'https://huggingface.co/spaces/awacke1',
34
+ 'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
35
+ }
36
+ )
37
 
38
+ # Model Configuration Class
39
  @dataclass
40
  class ModelConfig:
41
  name: str
42
  base_model: str
43
  size: str
44
  domain: Optional[str] = None
45
+
46
  @property
47
  def model_path(self):
48
  return f"models/{self.name}"
49
 
50
+ # Custom Dataset for SFT
 
 
 
 
 
 
 
 
 
51
  class SFTDataset(Dataset):
52
  def __init__(self, data, tokenizer, max_length=128):
53
  self.data = data
54
  self.tokenizer = tokenizer
55
  self.max_length = max_length
56
+
57
  def __len__(self):
58
  return len(self.data)
59
+
60
  def __getitem__(self, idx):
61
  prompt = self.data[idx]["prompt"]
62
  response = self.data[idx]["response"]
63
+
64
  full_text = f"{prompt} {response}"
65
+ full_encoding = self.tokenizer(
66
+ full_text,
67
+ max_length=self.max_length,
68
+ padding="max_length",
69
+ truncation=True,
70
+ return_tensors="pt"
71
+ )
72
+
73
+ prompt_encoding = self.tokenizer(
74
+ prompt,
75
+ max_length=self.max_length,
76
+ padding=False,
77
+ truncation=True,
78
+ return_tensors="pt"
79
+ )
80
+
81
  input_ids = full_encoding["input_ids"].squeeze()
82
  attention_mask = full_encoding["attention_mask"].squeeze()
83
  labels = input_ids.clone()
84
+
85
  prompt_len = prompt_encoding["input_ids"].shape[1]
86
  if prompt_len < self.max_length:
87
  labels[:prompt_len] = -100
88
+
89
+ return {
90
+ "input_ids": input_ids,
91
+ "attention_mask": attention_mask,
92
+ "labels": labels
93
+ }
 
 
 
 
94
 
95
+ # Model Builder Class with Easter Egg Jokes
96
  class ModelBuilder:
97
  def __init__(self):
98
  self.config = None
99
  self.model = None
100
  self.tokenizer = None
101
  self.sft_data = None
102
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
103
+
104
  def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
105
+ with st.spinner(f"Loading {model_path}... (Patience, young padawan!)"):
106
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
107
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
108
+ if self.tokenizer.pad_token is None:
109
+ self.tokenizer.pad_token = self.tokenizer.eos_token
110
+ if config:
111
+ self.config = config
112
+ st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
113
  return self
114
+
115
  def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
116
  self.sft_data = []
117
  with open(csv_path, "r") as f:
118
  reader = csv.DictReader(f)
119
  for row in reader:
120
  self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
121
+
122
  dataset = SFTDataset(self.sft_data, self.tokenizer)
123
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
124
  optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
125
+
126
  self.model.train()
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  self.model.to(device)
129
  for epoch in range(epochs):
130
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️ (The AI is lifting weights!)"):
131
+ total_loss = 0
132
+ for batch in dataloader:
133
+ optimizer.zero_grad()
134
+ input_ids = batch["input_ids"].to(device)
135
+ attention_mask = batch["attention_mask"].to(device)
136
+ labels = batch["labels"].to(device)
137
+
138
+ assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
139
+
140
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
141
+ loss = outputs.loss
142
+ loss.backward()
143
+ optimizer.step()
144
+ total_loss += loss.item()
145
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
146
+ st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
147
  return self
 
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def save_model(self, path: str):
150
+ with st.spinner("Saving model... 💾 (Packing the AI’s suitcase!)"):
151
+ os.makedirs(os.path.dirname(path), exist_ok=True)
152
+ self.model.save_pretrained(path)
153
+ self.tokenizer.save_pretrained(path)
154
+ st.success(f"Model saved at {path}! ✅ May the force be with it.")
155
+
156
+ def evaluate(self, prompt: str, status_container=None):
157
+ """Evaluate with feedback"""
158
+ self.model.eval()
159
+ if status_container:
160
+ status_container.write("Preparing to evaluate... 🧠 (Titan’s warming up its circuits!)")
161
+ logger.info(f"Evaluating prompt: {prompt}")
162
+
163
+ try:
164
+ with torch.no_grad():
165
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
166
+ if status_container:
167
+ status_container.write(f"Tokenized input shape: {inputs['input_ids'].shape} 📏")
168
+
169
+ outputs = self.model.generate(
170
+ **inputs,
171
+ max_new_tokens=50,
172
+ do_sample=True,
173
+ top_p=0.95,
174
+ temperature=0.7
175
+ )
176
+ if status_container:
177
+ status_container.write("Generation complete! Decoding response... 🗣")
178
+
179
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
180
+ logger.info(f"Generated response: {result}")
181
+ return result
182
+ except Exception as e:
183
+ logger.error(f"Evaluation error: {str(e)}")
184
+ if status_container:
185
+ status_container.error(f"Oops! Something broke: {str(e)} 💥 (Titan tripped over a wire!)")
186
+ return f"Error: {str(e)}"
187
 
188
+ # Utility Functions with Wit
189
  def get_download_link(file_path, mime_type="text/plain", label="Download"):
190
  with open(file_path, 'rb') as f:
191
  data = f.read()
192
  b64 = base64.b64encode(data).decode()
193
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥 (Grab it before it runs away!)</a>'
194
 
195
  def zip_directory(directory_path, zip_path):
196
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
197
  for root, _, files in os.walk(directory_path):
198
  for file in files:
199
+ file_path = os.path.join(root, file)
200
+ arcname = os.path.relpath(file_path, os.path.dirname(directory_path))
201
+ zipf.write(file_path, arcname)
202
 
203
+ def get_model_files():
204
+ return [d for d in glob.glob("models/*") if os.path.isdir(d)]
 
205
 
206
  def get_gallery_files(file_types):
207
+ files = []
208
+ for ext in file_types:
209
+ files.extend(glob.glob(f"*.{ext}"))
210
+ return sorted(files)
211
+
212
+ # Cargo Travel Time Tool
213
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
214
+ def to_radians(degrees: float) -> float:
215
+ return degrees * (math.pi / 180)
216
+ lat1, lon1 = map(to_radians, origin_coords)
217
+ lat2, lon2 = map(to_radians, destination_coords)
218
+ EARTH_RADIUS_KM = 6371.0
219
+ dlon = lon2 - lon1
220
+ dlat = lat2 - lat1
221
+ a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
222
+ c = 2 * math.asin(math.sqrt(a))
223
+ distance = EARTH_RADIUS_KM * c
224
+ actual_distance = distance * 1.1
225
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
226
+ return round(flight_time, 2)
227
+
228
+ # Mock Search Tool for RAG
229
+ def mock_duckduckgo_search(query: str) -> str:
230
+ """Simulate a search result for luxury superhero party trends"""
231
+ if "superhero party trends" in query.lower():
232
+ return """
233
+ Latest trends for 2025:
234
+ - Luxury decorations: Gold-plated Batman statues, holographic Avengers displays.
235
+ - Entertainment: Live stunt shows with Iron Man suits, VR superhero battles.
236
+ - Catering: Gourmet kryptonite-green cocktails, Thor’s hammer-shaped appetizers.
237
+ """
238
+ return "No relevant results found."
239
+
240
+ # Simple Agent Class for Demo
241
+ class PartyPlannerAgent:
242
  def __init__(self, model, tokenizer):
243
  self.model = model
244
  self.tokenizer = tokenizer
245
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
246
  self.model.to(self.device)
247
+
248
  def generate(self, prompt: str) -> str:
249
  self.model.eval()
250
  with torch.no_grad():
251
  inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
252
+ outputs = self.model.generate(
253
+ **inputs,
254
+ max_new_tokens=100,
255
+ do_sample=True,
256
+ top_p=0.95,
257
+ temperature=0.7
258
+ )
259
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
260
+
261
  def plan_party(self, task: str) -> pd.DataFrame:
262
+ # Mock search for context
263
+ search_result = mock_duckduckgo_search("latest superhero party trends")
264
+
265
+ # Locations and coordinates
266
+ locations = {
267
+ "Wayne Manor": (42.3601, -71.0589),
268
+ "New York": (40.7128, -74.0060),
269
+ "Los Angeles": (34.0522, -118.2437),
270
+ "London": (51.5074, -0.1278)
271
+ }
272
+
273
+ # Calculate travel times
274
+ wayne_coords = locations["Wayne Manor"]
275
+ travel_times = {
276
+ loc: calculate_cargo_travel_time(coords, wayne_coords)
277
+ for loc, coords in locations.items() if loc != "Wayne Manor"
278
+ }
279
+
280
+ # Generate luxury ideas with the SFT model
281
+ prompt = f"""
282
+ Given this context from a search: "{search_result}"
283
+ Plan a luxury superhero-themed party at Wayne Manor. Suggest luxury decorations, entertainment, and catering ideas.
284
+ """
285
  plan_text = self.generate(prompt)
286
+
287
+ # Parse plan into structured data (simplified)
288
+ catchphrases = [
289
+ "To the Batmobile!",
290
+ "Avengers, assemble!",
291
+ "I am Iron Man!",
292
+ "By the power of Grayskull!"
293
  ]
294
+
 
 
 
 
 
 
 
 
 
 
295
  data = [
296
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
297
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
298
+ {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
299
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
300
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
301
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
302
  ]
303
+
304
  return pd.DataFrame(data)
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  # Main App
307
  st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
308
 
309
+ # Sidebar with Galleries
310
+ st.sidebar.header("Galleries & Shenanigans 🎨")
311
+ st.sidebar.subheader("Image Gallery 📸")
312
+ img_files = get_gallery_files(["png", "jpg", "jpeg"])
313
+ if img_files:
314
+ img_cols = st.sidebar.slider("Image Columns 📸", 1, 5, 3)
315
+ cols = st.sidebar.columns(img_cols)
316
+ for idx, img_file in enumerate(img_files[:img_cols * 2]):
317
+ with cols[idx % img_cols]:
318
+ st.image(Image.open(img_file), caption=f"{img_file} 🖼", use_column_width=True)
319
+
320
+ st.sidebar.subheader("CSV Gallery 📊")
321
+ csv_files = get_gallery_files(["csv"])
322
+ if csv_files:
323
+ for csv_file in csv_files[:5]:
324
+ st.sidebar.markdown(get_download_link(csv_file, "text/csv", f"{csv_file} 📊"), unsafe_allow_html=True)
 
 
 
 
325
 
326
  st.sidebar.subheader("Model Management 🗂️")
327
+ model_dirs = get_model_files()
 
328
  selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
329
  if selected_model != "None" and st.sidebar.button("Load Model 📂"):
330
+ if 'builder' not in st.session_state:
331
+ st.session_state['builder'] = ModelBuilder()
332
+ config = ModelConfig(name=os.path.basename(selected_model), base_model="unknown", size="small", domain="general")
333
+ st.session_state['builder'].load_model(selected_model, config)
334
  st.session_state['model_loaded'] = True
335
  st.rerun()
336
 
337
+ # Main UI with Tabs
338
+ tab1, tab2, tab3, tab4 = st.tabs(["Build Tiny Titan 🌱", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
 
 
 
 
 
 
 
339
 
340
  with tab1:
341
+ st.header("Build Tiny Titan 🌱 (Assemble Your Mini-Mecha!)")
 
342
  base_model = st.selectbox(
343
  "Select Tiny Model",
344
+ ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"],
345
+ help="Pick a pint-sized powerhouse (<1 GB)! SmolLM-135M (~270 MB), SmolLM-360M (~720 MB), Qwen1.5-0.5B (~1 GB)"
346
  )
347
  model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
348
+ domain = st.text_input("Target Domain", "general")
349
+
350
  if st.button("Download Model ⬇️"):
351
+ config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
352
+ builder = ModelBuilder()
353
  builder.load_model(base_model, config)
354
  builder.save_model(config.model_path)
355
  st.session_state['builder'] = builder
356
  st.session_state['model_loaded'] = True
357
+ st.success(f"Model downloaded and saved to {config.model_path}! 🎉 (Tiny but feisty!)")
358
  st.rerun()
359
 
360
  with tab2:
361
+ st.header("Fine-Tune Titan 🔧 (Teach Your Titan Some Tricks!)")
362
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
363
+ st.warning("Please build or load a Titan first! ⚠️ (No Titan, no party!)")
364
  else:
365
+ if st.button("Generate Sample CSV 📝"):
366
+ sample_data = [
367
+ {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
368
+ {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
369
+ {"prompt": "What is a neural network?", "response": "A neural network is a brainy AI mimicking human noggins."},
370
+ ]
371
+ csv_path = f"sft_data_{int(time.time())}.csv"
372
+ with open(csv_path, "w", newline="") as f:
373
+ writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
374
+ writer.writeheader()
375
+ writer.writerows(sample_data)
376
+ st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
377
+ st.success(f"Sample CSV generated as {csv_path}! ✅ (Fresh from the data oven!)")
378
+
379
+ uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
380
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
381
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
382
  with open(csv_path, "wb") as f:
383
  f.write(uploaded_csv.read())
384
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
385
+ new_config = ModelConfig(
386
+ name=new_model_name,
387
+ base_model=st.session_state['builder'].config.base_model,
388
+ size="small",
389
+ domain=st.session_state['builder'].config.domain
390
+ )
391
  st.session_state['builder'].config = new_config
392
+ with st.status("Fine-tuning Titan... ⏳ (Whipping it into shape!)", expanded=True) as status:
393
+ st.session_state['builder'].fine_tune_sft(csv_path)
394
+ st.session_state['builder'].save_model(new_config.model_path)
395
+ status.update(label="Fine-tuning completed! 🎉 (Titan’s ready to rumble!)", state="complete")
396
+
397
  zip_path = f"{new_config.model_path}.zip"
398
  zip_directory(new_config.model_path, zip_path)
399
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
400
+ st.rerun()
401
 
402
  with tab3:
403
+ st.header("Test Titan 🧪 (Put Your Titan to the Test!)")
404
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
405
+ st.warning("Please build or load a Titan first! ⚠️ (No Titan, no test drive!)")
406
  else:
407
+ if st.session_state['builder'].sft_data:
408
+ st.write("Testing with SFT Data:")
409
+ with st.spinner("Running SFT data tests... ⏳ (Titan’s flexing its brain muscles!)"):
410
+ for item in st.session_state['builder'].sft_data[:3]:
411
+ prompt = item["prompt"]
412
+ expected = item["response"]
413
+ status_container = st.empty()
414
+ generated = st.session_state['builder'].evaluate(prompt, status_container)
415
+ st.write(f"**Prompt**: {prompt}")
416
+ st.write(f"**Expected**: {expected}")
417
+ st.write(f"**Generated**: {generated} (Titan says: '{random.choice(['Bleep bloop!', 'I am groot!', '42!'])}')")
418
+ st.write("---")
419
+ status_container.empty() # Clear status after each test
420
+
421
+ test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
422
+ if st.button("Run Test ▶️"):
423
+ with st.spinner("Testing your prompt... ⏳ (Titan’s pondering deeply!)"):
424
+ status_container = st.empty()
425
+ result = st.session_state['builder'].evaluate(test_prompt, status_container)
426
+ st.write(f"**Generated Response**: {result} (Titan’s wisdom unleashed!)")
427
+ status_container.empty()
428
+
429
+ if st.button("Export Titan Files 📦"):
430
+ config = st.session_state['builder'].config
431
+ app_code = f"""
432
+ import streamlit as st
433
+ from transformers import AutoModelForCausalLM, AutoTokenizer
434
+
435
+ model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
436
+ tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")
437
+
438
+ st.title("Tiny Titan Demo")
439
+ input_text = st.text_area("Enter prompt")
440
+ if st.button("Generate"):
441
+ inputs = tokenizer(input_text, return_tensors="pt")
442
+ outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
443
+ st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
444
+ """
445
+ with open("titan_app.py", "w") as f:
446
+ f.write(app_code)
447
+ reqs = "streamlit\ntorch\ntransformers\n"
448
+ with open("titan_requirements.txt", "w") as f:
449
+ f.write(reqs)
450
+ readme = f"""
451
+ # Tiny Titan Demo
452
+
453
+ ## How to run
454
+ 1. Install requirements: `pip install -r titan_requirements.txt`
455
+ 2. Run the app: `streamlit run titan_app.py`
456
+ 3. Input a prompt and click "Generate". Watch the magic unfold! 🪄
457
+ """
458
+ with open("titan_README.md", "w") as f:
459
+ f.write(readme)
460
+
461
+ st.markdown(get_download_link("titan_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
462
+ st.markdown(get_download_link("titan_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
463
+ st.markdown(get_download_link("titan_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
464
+ st.success("Titan files exported! ✅ (Ready to conquer the galaxy!)")
465
 
466
  with tab4:
467
+ st.header("Agentic RAG Party 🌐 (Party Like It’s 2099!)")
468
+ st.write("This demo uses your SFT-tuned Tiny Titan to plan a superhero party with mock retrieval!")
469
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
471
+ st.warning("Please build or load a Titan first! ⚠️ (No Titan, no party!)")
472
  else:
473
+ if st.button("Run Agentic RAG Demo 🎉"):
474
+ with st.spinner("Loading your SFT-tuned Titan... ⏳ (Titan’s suiting up!)"):
475
+ agent = PartyPlannerAgent(
476
+ model=st.session_state['builder'].model,
477
+ tokenizer=st.session_state['builder'].tokenizer
478
+ )
479
+ st.write("Agent ready! 🦸‍♂️ (Time to plan an epic bash!)")
480
+
481
+ task = """
482
+ Plan a luxury superhero-themed party at Wayne Manor (42.3601° N, 71.0589° W).
483
+ Use mock search results for the latest superhero party trends, refine for luxury elements
484
+ (decorations, entertainment, catering), and calculate cargo travel times from key locations
485
+ (New York: 40.7128° N, 74.0060° W; LA: 34.0522° N, 118.2437° W; London: 51.5074° N, 0.1278° W)
486
+ to Wayne Manor. Create a plan with at least 6 entries in a pandas dataframe.
487
+ """
488
+ with st.spinner("Planning the ultimate superhero bash... ⏳ (Calling all caped crusaders!)"):
489
+ try:
490
+ plan_df = agent.plan_party(task)
491
+ st.write("Agentic RAG Party Plan:")
492
+ st.dataframe(plan_df)
493
+ st.write("Party on, Wayne! 🦸‍♂️🎉")
494
+ except Exception as e:
495
+ st.error(f"Error planning party: {str(e)} (Even Superman has kryptonite days!)")