Update app.py
Browse files
app.py
CHANGED
@@ -49,6 +49,8 @@ if 'builder' not in st.session_state:
|
|
49 |
st.session_state['builder'] = None
|
50 |
if 'model_loaded' not in st.session_state:
|
51 |
st.session_state['model_loaded'] = False
|
|
|
|
|
52 |
|
53 |
# Model Configuration Classes
|
54 |
@dataclass
|
@@ -191,18 +193,19 @@ class DiffusionBuilder:
|
|
191 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
192 |
optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
|
193 |
self.pipeline.unet.train()
|
|
|
194 |
for epoch in range(epochs):
|
195 |
with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
|
196 |
total_loss = 0
|
197 |
for batch in dataloader:
|
198 |
optimizer.zero_grad()
|
199 |
-
image = batch["image"][0].to(
|
200 |
text = batch["text"][0]
|
201 |
-
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(
|
202 |
-
noise = torch.randn_like(
|
203 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
204 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
205 |
-
text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(
|
206 |
pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
|
207 |
loss = torch.nn.functional.mse_loss(pred_noise, noise)
|
208 |
loss.backward()
|
@@ -225,7 +228,7 @@ def generate_filename(sequence, ext="png"):
|
|
225 |
import pytz
|
226 |
central = pytz.timezone('US/Central')
|
227 |
dt = datetime.now(central)
|
228 |
-
return f"{dt.strftime('%m-%d-%Y-%I-%M-%p')}.{ext}"
|
229 |
|
230 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
231 |
with open(file_path, 'rb') as f:
|
@@ -244,8 +247,7 @@ def get_model_files(model_type="causal_lm"):
|
|
244 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
245 |
|
246 |
def get_gallery_files(file_types):
|
247 |
-
|
248 |
-
return files
|
249 |
|
250 |
def update_gallery():
|
251 |
media_files = get_gallery_files(["png"])
|
@@ -337,11 +339,19 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
|
337 |
st.rerun()
|
338 |
|
339 |
# Tabs
|
340 |
-
|
341 |
"Build Titan 🌱", "Camera Snap 📷",
|
342 |
"Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
|
343 |
"Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
|
344 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
with tab1:
|
347 |
st.header("Build Titan 🌱")
|
@@ -350,9 +360,9 @@ with tab1:
|
|
350 |
["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
|
351 |
["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
|
352 |
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
353 |
-
domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪")
|
354 |
if st.button("Download Model ⬇️"):
|
355 |
-
config =
|
356 |
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
357 |
builder.load_model(base_model, config)
|
358 |
builder.save_model(config.model_path)
|
|
|
49 |
st.session_state['builder'] = None
|
50 |
if 'model_loaded' not in st.session_state:
|
51 |
st.session_state['model_loaded'] = False
|
52 |
+
if 'active_tab' not in st.session_state:
|
53 |
+
st.session_state['active_tab'] = "Build Titan 🌱"
|
54 |
|
55 |
# Model Configuration Classes
|
56 |
@dataclass
|
|
|
193 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
194 |
optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
|
195 |
self.pipeline.unet.train()
|
196 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
197 |
for epoch in range(epochs):
|
198 |
with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
|
199 |
total_loss = 0
|
200 |
for batch in dataloader:
|
201 |
optimizer.zero_grad()
|
202 |
+
image = batch["image"][0].to(device)
|
203 |
text = batch["text"][0]
|
204 |
+
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(device)).latent_dist.sample()
|
205 |
+
noise = torch.randn_like(latents)
|
206 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
207 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
208 |
+
text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(device))[0]
|
209 |
pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
|
210 |
loss = torch.nn.functional.mse_loss(pred_noise, noise)
|
211 |
loss.backward()
|
|
|
228 |
import pytz
|
229 |
central = pytz.timezone('US/Central')
|
230 |
dt = datetime.now(central)
|
231 |
+
return f"{dt.strftime('%m-%d-%Y-%I-%M-%S-%p')}.{ext}"
|
232 |
|
233 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
234 |
with open(file_path, 'rb') as f:
|
|
|
247 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
248 |
|
249 |
def get_gallery_files(file_types):
|
250 |
+
return sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}")))) # Remove duplicates and sort
|
|
|
251 |
|
252 |
def update_gallery():
|
253 |
media_files = get_gallery_files(["png"])
|
|
|
339 |
st.rerun()
|
340 |
|
341 |
# Tabs
|
342 |
+
tabs = [
|
343 |
"Build Titan 🌱", "Camera Snap 📷",
|
344 |
"Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
|
345 |
"Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
|
346 |
+
]
|
347 |
+
tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs(tabs)
|
348 |
+
|
349 |
+
# Log Tab Switches
|
350 |
+
for i, tab in enumerate(tabs):
|
351 |
+
if st.session_state['active_tab'] != tab and st.session_state.get(f'tab{i}_active', False):
|
352 |
+
logger.info(f"Switched to tab: {tab}")
|
353 |
+
st.session_state['active_tab'] = tab
|
354 |
+
st.session_state[f'tab{i}_active'] = (st.session_state['active_tab'] == tab)
|
355 |
|
356 |
with tab1:
|
357 |
st.header("Build Titan 🌱")
|
|
|
360 |
["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
|
361 |
["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
|
362 |
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
363 |
+
domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪") if model_type == "Causal LM" else None
|
364 |
if st.button("Download Model ⬇️"):
|
365 |
+
config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain) if model_type == "Causal LM" else DiffusionConfig(name=model_name, base_model=base_model, size="small")
|
366 |
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
367 |
builder.load_model(base_model, config)
|
368 |
builder.save_model(config.model_path)
|