awacke1 committed
Commit a1bc718 · verified · 1 Parent(s): 07943e1

Update app.py

Files changed (1): app.py +21 -11
app.py CHANGED
@@ -49,6 +49,8 @@ if 'builder' not in st.session_state:
     st.session_state['builder'] = None
 if 'model_loaded' not in st.session_state:
     st.session_state['model_loaded'] = False
+if 'active_tab' not in st.session_state:
+    st.session_state['active_tab'] = "Build Titan 🌱"
 
 # Model Configuration Classes
 @dataclass
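Note: the new `active_tab` key follows the same guard-before-write pattern as the existing session-state keys. A behaviour-equivalent sketch that initialises all three defaults in one loop (purely illustrative, not what the commit does):

```python
import streamlit as st

# Illustrative only: same defaults as the diff above, initialised in one pass.
defaults = {
    'builder': None,
    'model_loaded': False,
    'active_tab': "Build Titan 🌱",
}
for key, value in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = value
```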
@@ -191,18 +193,19 @@ class DiffusionBuilder:
         dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
         optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
         self.pipeline.unet.train()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         for epoch in range(epochs):
             with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
                 total_loss = 0
                 for batch in dataloader:
                     optimizer.zero_grad()
-                    image = batch["image"][0].to(self.pipeline.device)
+                    image = batch["image"][0].to(device)
                     text = batch["text"][0]
-                    latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
-                    noise = torch.randn_like(latant)
+                    latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(device)).latent_dist.sample()
+                    noise = torch.randn_like(latents)
                     timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
                     noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
-                    text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
+                    text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(device))[0]
                     pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
                     loss = torch.nn.functional.mse_loss(pred_noise, noise)
                     loss.backward()
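Note: the loop now moves batch tensors to a locally computed `device` instead of `self.pipeline.device`, and also fixes the `latant` typo in the `randn_like` call. Sending tensors to `device` only works if the pipeline weights live on the same device. A minimal sketch of keeping the two in sync, assuming a diffusers `StableDiffusionPipeline` is what `load_model` creates (the helper name `prepare_pipeline` is illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

def prepare_pipeline(base_model="stabilityai/stable-diffusion-2-base"):
    # Pick the same device the training loop above will use...
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ...and move the whole pipeline (unet, vae, text_encoder) onto it,
    # so tensors sent to `device` match the model weights.
    pipeline = StableDiffusionPipeline.from_pretrained(base_model)
    pipeline.to(device)
    return pipeline, device
```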
@@ -225,7 +228,7 @@ def generate_filename(sequence, ext="png"):
     import pytz
     central = pytz.timezone('US/Central')
     dt = datetime.now(central)
-    return f"{dt.strftime('%m-%d-%Y-%I-%M-%p')}.{ext}"
+    return f"{dt.strftime('%m-%d-%Y-%I-%M-%S-%p')}.{ext}"
 
 def get_download_link(file_path, mime_type="text/plain", label="Download"):
     with open(file_path, 'rb') as f:
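Note: the only change here is adding seconds (`%S`) to the timestamp, which avoids filename collisions when several captures land in the same minute; the `sequence` argument is still unused, as before. A standalone sketch of the new behaviour:

```python
from datetime import datetime
import pytz

def generate_filename(sequence, ext="png"):
    central = pytz.timezone('US/Central')
    dt = datetime.now(central)
    return f"{dt.strftime('%m-%d-%Y-%I-%M-%S-%p')}.{ext}"

print(generate_filename("snap"))  # e.g. 02-28-2025-03-07-42-PM.png
```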
@@ -244,8 +247,7 @@ def get_model_files(model_type="causal_lm"):
     return [d for d in glob.glob(path) if os.path.isdir(d)]
 
 def get_gallery_files(file_types):
-    files = sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}"))))  # Remove duplicates and sort
-    return files
+    return sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}"))))  # Remove duplicates and sort
 
 def update_gallery():
     media_files = get_gallery_files(["png"])
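Note: this is a pure refactor that drops the temporary variable; the returned list is unchanged. A quick usage sketch (the file extensions passed in are illustrative):

```python
import glob

def get_gallery_files(file_types):
    return sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}"))))  # Remove duplicates and sort

print(get_gallery_files(["png", "jpg"]))  # e.g. ['demo-01.png', 'demo-02.jpg']
```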
@@ -337,11 +339,19 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
     st.rerun()
 
 # Tabs
-tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs([
+tabs = [
     "Build Titan 🌱", "Camera Snap 📷",
     "Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
     "Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
-])
+]
+tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs(tabs)
+
+# Log Tab Switches
+for i, tab in enumerate(tabs):
+    if st.session_state['active_tab'] != tab and st.session_state.get(f'tab{i}_active', False):
+        logger.info(f"Switched to tab: {tab}")
+        st.session_state['active_tab'] = tab
+    st.session_state[f'tab{i}_active'] = (st.session_state['active_tab'] == tab)
 
 with tab1:
     st.header("Build Titan 🌱")
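Note: `st.tabs` does not report which tab the user clicked, so the logging loop above only fires if something else sets the `tab{i}_active` flags in session state. One hypothetical alternative (not what this commit does) is to drive section selection from a widget whose value Streamlit does persist, e.g. a sidebar radio:

```python
import logging
import streamlit as st

logger = logging.getLogger(__name__)

# Shortened label list for the sketch; the real app uses the eight tab names above.
sections = ["Build Titan 🌱", "Camera Snap 📷"]
choice = st.sidebar.radio("Section", sections, key="active_tab_choice")
if st.session_state.get("active_tab") != choice:
    logger.info(f"Switched to tab: {choice}")
    st.session_state["active_tab"] = choice
```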
@@ -350,9 +360,9 @@ with tab1:
                               ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
                               ["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
     model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
-    domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪")
+    domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪") if model_type == "Causal LM" else None
     if st.button("Download Model ⬇️"):
-        config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain if model_type == "Causal LM" else None)
+        config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain) if model_type == "Causal LM" else DiffusionConfig(name=model_name, base_model=base_model, size="small")
         builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
         builder.load_model(base_model, config)
         builder.save_model(config.model_path)
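Note: splitting the single conditional constructor into two explicit branches assumes the two config dataclasses take different fields (`ModelConfig` has a `domain`, `DiffusionConfig` does not). A minimal sketch of how the branch resolves; the field lists below are inferred from this hunk only, not copied from app.py:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfig:        # fields inferred from this hunk
    name: str
    base_model: str
    size: str
    domain: Optional[str] = None

@dataclass
class DiffusionConfig:    # fields inferred from this hunk
    name: str
    base_model: str
    size: str

model_type = "Causal LM"
model_name, base_model = "tiny-titan-demo", "HuggingFaceTB/SmolLM-135M"
config = (ModelConfig(name=model_name, base_model=base_model, size="small", domain="general")
          if model_type == "Causal LM"
          else DiffusionConfig(name=model_name, base_model=base_model, size="small"))
print(config)
```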
 