awacke1 committed · Commit 4cb9027 · verified · 1 Parent(s): cd06e37

Update app.py

Files changed (1):
  1. app.py +146 -358
app.py CHANGED
@@ -1,33 +1,19 @@
#!/usr/bin/env python3
import os
-import shutil
-import glob
import base64
import streamlit as st
import pandas as pd
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from torch.utils.data import Dataset, DataLoader
import csv
import time
from dataclasses import dataclass
-from typing import Optional, Tuple
-import zipfile
-import math
from PIL import Image
-import random
-import logging
from datetime import datetime
import pytz
-from diffusers import StableDiffusionPipeline
-from urllib.parse import quote
-import cv2
+from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
+import av

-# Logging setup
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+# Minimal initial imports to reduce startup delay

-# Page Configuration
st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="wide", initial_sidebar_state="expanded")

# Model Configurations
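A note on the import hunk above: the heavy dependencies (torch, transformers, diffusers, cv2) leave the top-level imports, so only lightweight modules plus streamlit_webrtc and av are paid for at startup. A minimal sketch of the deferred-import pattern the new code relies on (the function name here is illustrative, not part of the commit):

def load_titan(model_path: str):
    # The heavy import happens on first call only; Python caches the
    # module in sys.modules, so repeated calls are cheap.
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained(model_path)

Since Streamlit runs the whole script top to bottom before rendering, keeping these imports out of module scope lets the UI appear before torch has ever been imported.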
@@ -35,8 +21,6 @@ st.set_page_config(page_title="SFT Tiny Titans 🚀", page_icon="🤖", layout="
class ModelConfig:
    name: str
    base_model: str
-    size: str
-    domain: Optional[str] = None
    model_type: str = "causal_lm"
    @property
    def model_path(self):
@@ -46,132 +30,45 @@ class ModelConfig:
class DiffusionConfig:
    name: str
    base_model: str
-    size: str
    @property
    def model_path(self):
        return f"diffusion_models/{self.name}"

-# Datasets
-class SFTDataset(Dataset):
-    def __init__(self, data, tokenizer, max_length=128):
-        self.data = data
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-    def __len__(self):
-        return len(self.data)
-    def __getitem__(self, idx):
-        prompt = self.data[idx]["prompt"]
-        response = self.data[idx]["response"]
-        full_text = f"{prompt} {response}"
-        full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
-        prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
-        input_ids = full_encoding["input_ids"].squeeze()
-        attention_mask = full_encoding["attention_mask"].squeeze()
-        labels = input_ids.clone()
-        prompt_len = prompt_encoding["input_ids"].shape[1]
-        if prompt_len < self.max_length:
-            labels[:prompt_len] = -100
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
-
-class DiffusionDataset(Dataset):
-    def __init__(self, images, texts):
-        self.images = images
-        self.texts = texts
-    def __len__(self):
-        return len(self.images)
-    def __getitem__(self, idx):
-        return {"image": self.images[idx], "text": self.texts[idx]}
-
-# Model Builders
+# Lazy-loaded Builders
class ModelBuilder:
    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
-        self.sft_data = None
-    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
+    def load_model(self, model_path: str, config: ModelConfig):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import torch
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
-        if config:
-            self.config = config
-        return self
-    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
-        self.sft_data = []
-        with open(csv_path, "r") as f:
-            reader = csv.DictReader(f)
-            for row in reader:
-                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
-        dataset = SFTDataset(self.sft_data, self.tokenizer)
-        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
-        self.model.train()
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(device)
-        for epoch in range(epochs):
-            total_loss = 0
-            for batch in dataloader:
-                optimizer.zero_grad()
-                input_ids = batch["input_ids"].to(device)
-                attention_mask = batch["attention_mask"].to(device)
-                labels = batch["labels"].to(device)
-                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-                loss = outputs.loss
-                loss.backward()
-                optimizer.step()
-                total_loss += loss.item()
-            st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
-        return self
-    def save_model(self, path: str):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        self.model.save_pretrained(path)
-        self.tokenizer.save_pretrained(path)
+        self.config = config
+        self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    def evaluate(self, prompt: str):
+        import torch
        self.model.eval()
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
-            outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
+            outputs = self.model.generate(**inputs, max_new_tokens=50)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

class DiffusionBuilder:
    def __init__(self):
        self.config = None
        self.pipeline = None
-    def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
+    def load_model(self, model_path: str, config: DiffusionConfig):
+        from diffusers import StableDiffusionPipeline
+        import torch
        self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
-        self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
-        if config:
-            self.config = config
-        return self
-    def fine_tune_sft(self, images, texts, epochs=3):
-        dataset = DiffusionDataset(images, texts)
-        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
-        optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
-        self.pipeline.unet.train()
-        for epoch in range(epochs):
-            total_loss = 0
-            for batch in dataloader:
-                optimizer.zero_grad()
-                image = batch["image"].to(self.pipeline.device)
-                text = batch["text"]
-                latents = self.pipeline.vae.encode(image).latent_dist.sample()
-                noise = torch.randn_like(latents)
-                timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
-                noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
-                text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
-                pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
-                loss = torch.nn.functional.mse_loss(pred_noise, noise)
-                loss.backward()
-                optimizer.step()
-                total_loss += loss.item()
-            st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
-        return self
-    def save_model(self, path: str):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        self.pipeline.save_pretrained(path)
+        self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        self.config = config
    def generate(self, prompt: str):
-        return self.pipeline(prompt, num_inference_steps=50).images[0]
+        return self.pipeline(prompt, num_inference_steps=20).images[0]

# Utilities
def get_download_link(file_path, mime_type="text/plain", label="Download"):
@@ -180,300 +77,195 @@ def get_download_link(file_path, mime_type="text/plain", label="Download"):
    b64 = base64.b64encode(data).decode()
    return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} 📥</a>'

-def zip_directory(directory_path, zip_path):
-    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
-        for root, _, files in os.walk(directory_path):
-            for file in files:
-                zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.dirname(directory_path)))
-
-def get_model_files(model_type="causal_lm"):
-    path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
-    return [d for d in glob.glob(path) if os.path.isdir(d)]
-
-def get_gallery_files(file_types):
-    return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
-
def generate_filename(text_line):
    central = pytz.timezone('US/Central')
    timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
    safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
    return f"{timestamp}_{safe_text}.png"

-def display_search_links(query):
-    search_urls = {
-        "ArXiv": f"https://arxiv.org/search/?query={quote(query)}",
-        "Wikipedia": f"https://en.wikipedia.org/wiki/{quote(query)}",
-        "Google": f"https://www.google.com/search?q={quote(query)}",
-        "YouTube": f"https://www.youtube.com/results?search_query={quote(query)}"
-    }
-    return ' '.join([f"[{name}]({url})" for name, url in search_urls.items()])
-
-def detect_cameras():
-    cameras = []
-    for i in range(2):  # Check first two indices
-        cap = cv2.VideoCapture(i)
-        if cap.isOpened():
-            cameras.append(i)
-            cap.release()
-    return cameras
-
-# Agent Classes
-class NLPAgent:
-    def __init__(self, model, tokenizer):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-    def generate(self, prompt: str) -> str:
-        self.model.eval()
-        with torch.no_grad():
-            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
-            outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
-            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-    def plan_party(self, task: str) -> pd.DataFrame:
-        search_result = "Latest trends for 2025: Gold-plated Batman statues, VR superhero battles."
-        prompt = f"Given this context: '{search_result}'\n{task}"
-        plan_text = self.generate(prompt)
-        st.markdown(f"Search Links: {display_search_links('superhero party trends')}", unsafe_allow_html=True)
-        locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060)}
-        travel_times = {loc: calculate_cargo_travel_time(coords, locations["Wayne Manor"]) for loc, coords in locations.items() if loc != "Wayne Manor"}
-        data = [
-            {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Idea": "Gold-plated Batman statues"},
-            {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Idea": "VR superhero battles"}
-        ]
-        return pd.DataFrame(data)
-
-class CVAgent:
-    def __init__(self, pipeline):
-        self.pipeline = pipeline
-    def generate(self, prompt: str) -> Image.Image:
-        return self.pipeline(prompt, num_inference_steps=50).images[0]
-    def enhance_images(self, task: str) -> pd.DataFrame:
-        search_result = "Latest superhero art trends: Neon outlines, 3D holograms."
-        prompt = f"Given this context: '{search_result}'\n{task}"
-        st.markdown(f"Search Links: {display_search_links('superhero art trends')}", unsafe_allow_html=True)
-        data = [
-            {"Image Theme": "Batman", "Enhancement": "Neon outlines"},
-            {"Image Theme": "Iron Man", "Enhancement": "3D holograms"}
-        ]
-        return pd.DataFrame(data)
-
-def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
-    def to_radians(degrees: float) -> float:
-        return degrees * (math.pi / 180)
-    lat1, lon1 = map(to_radians, origin_coords)
-    lat2, lon2 = map(to_radians, destination_coords)
-    EARTH_RADIUS_KM = 6371.0
-    dlon = lon2 - lon1
-    dlat = lat2 - lat1
-    a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
-    c = 2 * math.asin(math.sqrt(a))
-    distance = EARTH_RADIUS_KM * c
-    actual_distance = distance * 1.1
-    flight_time = (actual_distance / cruising_speed_kmh) + 1.0
-    return round(flight_time, 2)
+def get_gallery_files(file_types):
+    import glob  # lazy import, matching the deferred-import style used elsewhere in this commit
+    return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
+
+# Video Transformer for WebRTC
+class VideoSnapshot(VideoTransformerBase):
+    def __init__(self):
+        self.snapshot = None
+    def transform(self, frame):
+        img = frame.to_ndarray(format="bgr24")
+        self.snapshot = img  # store the latest frame so take_snapshot() has data
+        return img
+    def take_snapshot(self):
+        if self.snapshot is not None:
+            return Image.fromarray(self.snapshot[:, :, ::-1].copy())  # BGR -> RGB for PIL

# Main App
-st.title("SFT Tiny Titans 🚀 (Small but Mighty!)")
+st.title("SFT Tiny Titans 🚀 (Lean & Mean!)")

# Sidebar Galleries
-st.sidebar.header("Shared Galleries 🎨")
+st.sidebar.header("Media Gallery 🎨")
for gallery_type, file_types, emoji in [
    ("Images 📸", ["png", "jpg", "jpeg"], "🖼️"),
-    ("Videos 🎥", ["mp4"], "🎬"),
-    ("Audio 🎶", ["mp3"], "🎵")
+    ("Videos 🎥", ["mp4"], "🎬")
]:
    st.sidebar.subheader(f"{gallery_type} {emoji}")
    files = get_gallery_files(file_types)
    if files:
-        cols_num = st.sidebar.slider(f"{gallery_type} Columns", 1, 5, 3, key=f"{gallery_type}_cols")
-        cols = st.sidebar.columns(cols_num)
-        for idx, file in enumerate(files[:cols_num * 2]):
-            with cols[idx % cols_num]:
+        cols = st.sidebar.columns(3)
+        for idx, file in enumerate(files[:6]):
+            with cols[idx % 3]:
                if "Images" in gallery_type:
-                    st.image(Image.open(file), caption=file, use_column_width=True)
+                    st.image(Image.open(file), caption=file.split('/')[-1], use_column_width=True)
                elif "Videos" in gallery_type:
                    st.video(file)
-                elif "Audio" in gallery_type:
-                    st.audio(file)

-st.sidebar.subheader("Model Management 🗂️")
+# Sidebar Model Management
+st.sidebar.subheader("Model Hub 🗂️")
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
-model_dirs = get_model_files("causal_lm" if "NLP" in model_type else "diffusion")
-selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
+model_options = ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["stabilityai/stable-diffusion-2-1", "CompVis/stable-diffusion-v1-4"]
+selected_model = st.sidebar.selectbox("Select Model", ["None"] + model_options)
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
    builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
-    config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
-    builder.load_model(selected_model, config)
+    config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
+    with st.spinner("Loading... ⏳"):
+        builder.load_model(selected_model, config)
    st.session_state['builder'] = builder
    st.session_state['model_loaded'] = True
-    st.rerun()

# Tabs
-tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
+tab1, tab2, tab3, tab4 = st.tabs([
    "Build Titan 🌱",
-    "Fine-Tune NLP 🧠",
-    "Fine-Tune CV 🎨",
+    "Fine-Tune Titans 🔧",
    "Test Titans 🧪",
-    "Agentic RAG 🌀",
-    "Camera Inputs 📷"
+    "Camera Snap 📷"
])

with tab1:
-    st.header("Build Your Titan 🌱")
+    st.header("Build Titan 🌱 (Start Small!)")
    model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
-    base_model = st.selectbox(
-        "Select Tiny Model",
-        ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if "NLP" in model_type else ["stabilityai/stable-diffusion-2-1", "CompVis/stable-diffusion-v1-4"]
-    )
-    model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
+    base_model = st.selectbox("Select Model", model_options, key="build_model")
    if st.button("Download Model ⬇️"):
-        config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=model_name, base_model=base_model, size="small")
+        config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
        builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
-        builder.load_model(base_model, config)
-        builder.save_model(config.model_path)
+        with st.spinner("Fetching... ⏳"):
+            builder.load_model(base_model, config)
        st.session_state['builder'] = builder
        st.session_state['model_loaded'] = True
-        st.rerun()
+        st.success("Titan ready! 🎉")

with tab2:
-    st.header("Fine-Tune NLP Titan 🧠 (Word Wizardry!)")
-    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
-        st.warning("Load an NLP Titan first! ⚠️")
-    else:
-        uploaded_csv = st.file_uploader("Upload CSV for NLP SFT", type="csv", key="nlp_csv")
-        if uploaded_csv and st.button("Tune the Wordsmith 🔧"):
-            csv_path = f"nlp_sft_data_{int(time.time())}.csv"
-            with open(csv_path, "wb") as f:
-                f.write(uploaded_csv.read())
-            new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
-            new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
-            st.session_state['builder'].config = new_config
-            st.session_state['builder'].fine_tune_sft(csv_path)
-            st.session_state['builder'].save_model(new_config.model_path)
-            zip_path = f"{new_config.model_path}.zip"
-            zip_directory(new_config.model_path, zip_path)
-            st.markdown(get_download_link(zip_path, "application/zip", "Download Tuned NLP Titan"), unsafe_allow_html=True)
-
-with tab3:
-    st.header("Fine-Tune CV Titan 🎨 (Vision Vibes!)")
-    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
-        st.warning("Load a CV Titan first! ⚠️")
-    else:
-        uploaded_files = st.file_uploader("Upload Images/Videos", type=["png", "jpg", "jpeg", "mp4", "mp3"], accept_multiple_files=True, key="cv_upload")
-        text_input = st.text_area("Enter Text (one line per image)", "Batman Neon\nIron Man Hologram\nThor Lightning", key="cv_text")
-        if uploaded_files and st.button("Tune the Visionary 🖌️"):
-            images = [Image.open(f) for f in uploaded_files if f.type.startswith("image")]
-            texts = text_input.splitlines()
-            if len(images) > len(texts):
-                texts.extend([""] * (len(images) - len(texts)))
-            elif len(texts) > len(images):
-                texts = texts[:len(images)]
-            st.session_state['builder'].fine_tune_sft(images, texts)
-            new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
-            new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
-            st.session_state['builder'].config = new_config
-            st.session_state['builder'].save_model(new_config.model_path)
-            for img, text in zip(images, texts):
-                filename = generate_filename(text)
-                img.save(filename)
-                st.image(img, caption=filename)
-            zip_path = f"{new_config.model_path}.zip"
-            zip_directory(new_config.model_path, zip_path)
-            st.markdown(get_download_link(zip_path, "application/zip", "Download Tuned CV Titan"), unsafe_allow_html=True)
-
-with tab4:
-    st.header("Test Titans 🧪 (Brains & Eyes!)")
+    st.header("Fine-Tune Titans 🔧 (Sharpen Up!)")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Load a Titan first! ⚠️")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
-            st.subheader("NLP Test 🧠")
-            test_prompt = st.text_area("Enter NLP Prompt", "Plan a superhero party!", key="nlp_test")
-            if st.button("Test NLP Titan ▶️"):
-                result = st.session_state['builder'].evaluate(test_prompt)
-                st.write(f"**Response**: {result}")
+            st.subheader("NLP Tune 🧠")
+            uploaded_csv = st.file_uploader("Upload CSV", type="csv", key="nlp_csv")
+            if uploaded_csv and st.button("Tune NLP 🔄"):
+                from torch.utils.data import Dataset, DataLoader
+                import torch
+                class SFTDataset(Dataset):
+                    def __init__(self, data, tokenizer):
+                        self.data = data
+                        self.tokenizer = tokenizer
+                    def __len__(self):
+                        return len(self.data)
+                    def __getitem__(self, idx):
+                        prompt = self.data[idx]["prompt"]
+                        response = self.data[idx]["response"]
+                        inputs = self.tokenizer(f"{prompt} {response}", return_tensors="pt", padding="max_length", max_length=128, truncation=True)
+                        labels = inputs["input_ids"].clone()
+                        labels[0, :len(self.tokenizer(prompt)["input_ids"])] = -100  # mask prompt tokens so the loss covers only the response
+                        return {"input_ids": inputs["input_ids"][0], "attention_mask": inputs["attention_mask"][0], "labels": labels[0]}
+                data = []
+                with open("temp.csv", "wb") as f:
+                    f.write(uploaded_csv.read())
+                with open("temp.csv", "r") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        data.append({"prompt": row["prompt"], "response": row["response"]})
+                dataset = SFTDataset(data, st.session_state['builder'].tokenizer)
+                dataloader = DataLoader(dataset, batch_size=2)
+                optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
+                st.session_state['builder'].model.train()
+                for _ in range(3):  # Simplified epochs
+                    for batch in dataloader:
+                        optimizer.zero_grad()
+                        outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
+                        outputs.loss.backward()
+                        optimizer.step()
+                st.success("NLP tuned! 🎉")
        elif isinstance(st.session_state['builder'], DiffusionBuilder):
-            st.subheader("CV Test 🎨")
-            test_prompt = st.text_area("Enter CV Prompt", "Superhero in neon style", key="cv_test")
-            if st.button("Test CV Titan ▶️"):
-                image = st.session_state['builder'].generate(test_prompt)
-                st.image(image, caption="Generated Image")
-
-        cameras = detect_cameras()
-        if cameras:
-            st.subheader("Camera Snapshot Test 📷")
-            camera_idx = st.selectbox("Select Camera", cameras, key="camera_select")
-            snapshot_text = st.text_input("Snapshot Text", "Camera Snap", key="snap_text")
-            if st.button("Capture Snapshot 📸"):
-                cap = cv2.VideoCapture(camera_idx)
-                ret, frame = cap.read()
-                if ret:
-                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    img = Image.fromarray(rgb_frame)
-                    filename = generate_filename(snapshot_text)
-                    img.save(filename)
-                    st.image(img, caption=filename)
-                cap.release()
+            st.subheader("CV Tune 🎨")
+            uploaded_files = st.file_uploader("Upload Images", type=["png", "jpg"], accept_multiple_files=True, key="cv_upload")
+            text_input = st.text_area("Text (one per image)", "Bat Neon\nIron Glow", key="cv_text")
+            if uploaded_files and st.button("Tune CV 🔄"):
+                import torch
+                import numpy as np
+                images = [Image.open(f).convert("RGB").resize((512, 512)) for f in uploaded_files]  # fixed size keeps the VAE happy
+                texts = text_input.splitlines()[:len(images)]
+                optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
+                st.session_state['builder'].pipeline.unet.train()
+                for _ in range(3):  # Simplified epochs
+                    for img, text in zip(images, texts):
+                        optimizer.zero_grad()
+                        pixel_values = (torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0).to(st.session_state['builder'].pipeline.device)  # scale to [-1, 1] as the VAE expects
+                        latents = st.session_state['builder'].pipeline.vae.encode(pixel_values).latent_dist.sample()
+                        noise = torch.randn_like(latents)
+                        timesteps = torch.randint(0, 1000, (1,), device=latents.device)
+                        noisy_latents = st.session_state['builder'].pipeline.scheduler.add_noise(latents, noise, timesteps)
+                        text_emb = st.session_state['builder'].pipeline.text_encoder(st.session_state['builder'].pipeline.tokenizer(text, return_tensors="pt").input_ids.to(st.session_state['builder'].pipeline.device))[0]
+                        pred_noise = st.session_state['builder'].pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample
+                        loss = torch.nn.functional.mse_loss(pred_noise, noise)
+                        loss.backward()
+                        optimizer.step()
+                for img, text in zip(images, texts):
+                    filename = generate_filename(text)
+                    img.save(filename)
+                st.success("CV tuned! 🎉")

-with tab5:
-    st.header("Agentic RAG 🌀 (Smart Plans & Visions!)")
+with tab3:
+    st.header("Test Titans 🧪 (Showtime!)")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Load a Titan first! ⚠️")
    else:
        if isinstance(st.session_state['builder'], ModelBuilder):
-            st.subheader("NLP RAG Party 🧠")
-            if st.button("Run NLP RAG Demo 🎉"):
-                agent = NLPAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
-                task = "Plan a luxury superhero-themed party at Wayne Manor."
-                plan_df = agent.plan_party(task)
-                st.dataframe(plan_df)
+            st.subheader("NLP Test 🧠")
+            prompt = st.text_area("Prompt", "What’s a superhero party?", key="nlp_test")
+            if st.button("Test NLP ▶️"):
+                result = st.session_state['builder'].evaluate(prompt)
+                st.write(f"**Answer**: {result}")
        elif isinstance(st.session_state['builder'], DiffusionBuilder):
-            st.subheader("CV RAG Enhance 🎨")
-            if st.button("Run CV RAG Demo 🖌️"):
-                agent = CVAgent(st.session_state['builder'].pipeline)
-                task = "Enhance superhero images with 2025 trends."
-                enhance_df = agent.enhance_images(task)
-                st.dataframe(enhance_df)
+            st.subheader("CV Test 🎨")
+            prompt = st.text_area("Prompt", "Neon Batman", key="cv_test")
+            if st.button("Test CV ▶️"):
+                with st.spinner("Generating... ⏳"):
+                    img = st.session_state['builder'].generate(prompt)
+                st.image(img, caption="Generated Art")

-with tab6:
-    st.header("Camera Inputs 📷 (Live Feed Fun!)")
-    cameras = detect_cameras()
-    if not cameras:
-        st.warning("No cameras detected! ⚠️")
-    else:
-        st.write(f"Detected {len(cameras)} cameras!")
-        for idx in cameras:
-            st.subheader(f"Camera {idx}")
-            cap = cv2.VideoCapture(idx)
-            if st.button(f"Capture from Camera {idx} 📸", key=f"cap_{idx}"):
-                ret, frame = cap.read()
-                if ret:
-                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    img = Image.fromarray(rgb_frame)
-                    filename = generate_filename(f"Camera_{idx}_snap")
-                    img.save(filename)
-                    st.image(img, caption=filename)
-            cap.release()
+with tab4:
+    st.header("Camera Snap 📷 (Live Action!)")
+    ctx = webrtc_streamer(key="camera", video_transformer_factory=VideoSnapshot, rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
+    if ctx.video_transformer:
+        snapshot_text = st.text_input("Snapshot Text", "Live Snap")
+        if st.button("Snap It! 📸"):
+            snapshot = ctx.video_transformer.take_snapshot()
+            if snapshot:
+                filename = generate_filename(snapshot_text)
+                snapshot.save(filename)
+                st.image(snapshot, caption=filename)
+                st.success("Snapped! 🎉")

-# Preload demo files
-demo_images = ["20250319_010000_AM_Batman.png", "20250319_010001_AM_IronMan.png", "20250319_010002_AM_Thor.png"]
-demo_videos = ["20250319_010000_AM_Batman.mp4", "20250319_010001_AM_IronMan.mp4", "20250319_010002_AM_Thor.mp4"]
-for img in demo_images:
+# Demo Dataset
+st.subheader("Demo CV Dataset 🎨")
+demo_texts = ["Bat Neon", "Iron Glow", "Thor Spark"]
+demo_images = [generate_filename(t) for t in demo_texts]
+for img, text in zip(demo_images, demo_texts):
    if not os.path.exists(img):
        Image.new("RGB", (100, 100)).save(img)
-for vid in demo_videos:
-    if not os.path.exists(vid):
-        with open(vid, "wb") as f:
-            f.write(b"")  # Dummy file
-
-# Demo SFT Dataset
-st.subheader("Diffusion SFT Demo Dataset 🎨")
-demo_texts = ["Batman Neon", "Iron Man Hologram", "Thor Lightning"]
-demo_code = "\n".join([f"{i+1}. {text} -> {demo_images[i]}" for i, text in enumerate(demo_texts)])
-st.code(demo_code, language="text")
+st.code("\n".join([f"{i+1}. {t} -> {img}" for i, (t, img) in enumerate(zip(demo_texts, demo_images))]), language="text")
if st.button("Download Demo CSV 📝"):
-    csv_path = f"demo_diffusion_sft_{int(time.time())}.csv"
+    csv_path = f"demo_cv_{int(time.time())}.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "text"])
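The -100 sentinel written into labels by the new inline SFTDataset is what makes this response-only supervised fine-tuning: Hugging Face causal-LM models compute cross-entropy with ignore_index=-100, so masked prompt positions contribute nothing to the loss. A tiny illustration with made-up token ids:

import torch

input_ids = torch.tensor([101, 102, 103, 104, 105, 106])  # hypothetical ids: first 4 tokens = prompt, last 2 = response
labels = input_ids.clone()
labels[:4] = -100  # prompt positions are ignored by the loss
# model(input_ids=input_ids.unsqueeze(0), labels=labels.unsqueeze(0)).loss
# now averages over the two response tokens only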
 
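Because torch, transformers, and diffusers are now imported inside handlers, a missing package surfaces only when the corresponding tab is used, not at startup. A small launch-time guard one could add (the package list is inferred from the imports in this commit, not part of it):

import importlib.util

for pkg in ("torch", "transformers", "diffusers"):
    if importlib.util.find_spec(pkg) is None:
        print(f"Warning: {pkg} is not installed; the tab that needs it will fail on first use")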