jbilcke-hf HF Staff committed on
Commit
2264c6e
·
1 Parent(s): 7bce2a2

tentative fix

Browse files
vms/ui/app_ui.py CHANGED
@@ -399,12 +399,11 @@ class AppUI:
399
  outputs=[
400
  self.project_tabs["train_tab"].components["status_box"],
401
  self.project_tabs["train_tab"].components["log_box"],
402
- self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None,
403
- self.project_tabs["manage_tab"].components["download_model_btn"],
404
- self.project_tabs["manage_tab"].components["download_checkpoint_btn"]
405
  ]
406
  )
407
 
 
408
  # Button update timer for button components (every 1 second)
409
  button_timer = gr.Timer(value=1)
410
  button_outputs = [
 
399
  outputs=[
400
  self.project_tabs["train_tab"].components["status_box"],
401
  self.project_tabs["train_tab"].components["log_box"],
402
+ self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None
 
 
403
  ]
404
  )
405
 
406
+
407
  # Button update timer for button components (every 1 second)
408
  button_timer = gr.Timer(value=1)
409
  button_outputs = [
vms/ui/project/services/training.py CHANGED
@@ -1823,12 +1823,9 @@ class TrainingService:
1823
  try:
1824
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1825
  if not checkpoints:
1826
- return "📥 Download checkpoints (not available)"
1827
 
1828
- # Get the latest checkpoint by step number
1829
- latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1830
- step_num = int(latest_checkpoint.name.split("_")[-1])
1831
- return f"📥 Download checkpoints (step {step_num})"
1832
  except Exception as e:
1833
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1834
- return "📥 Download checkpoints (not available)"
 
1823
  try:
1824
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1825
  if not checkpoints:
1826
+ return "No checkpoints available"
1827
 
1828
+ return f"💽 Download checkpoints"
 
 
 
1829
  except Exception as e:
1830
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1831
+ return "No checkpoints available"
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -25,50 +25,6 @@ class ManageTab(BaseTab):
25
  self.id = "manage_tab"
26
  self.title = "5️⃣ Storage"
27
 
28
- def get_download_button_text(self) -> str:
29
- """Get the dynamic text for the download button based on current model state"""
30
- try:
31
- model_info = self.app.training.get_model_output_info()
32
- if model_info["path"] and model_info["steps"]:
33
- return f"🧠 Download weights ({model_info['steps']} steps)"
34
- elif model_info["path"]:
35
- return "🧠 Download weights (.safetensors)"
36
- else:
37
- return "🧠 Download weights (not available)"
38
- except Exception as e:
39
- logger.warning(f"Error getting model info for button text: {e}")
40
- return "🧠 Download weights (.safetensors)"
41
-
42
- def get_checkpoint_button_text(self) -> str:
43
- """Get the dynamic text for the download checkpoint button"""
44
- try:
45
- return self.app.training.get_checkpoint_button_text()
46
- except Exception as e:
47
- logger.warning(f"Error getting checkpoint button text: {e}")
48
- return "📥 Download checkpoints (not available)"
49
-
50
- def update_download_button_text(self) -> gr.update:
51
- """Update the download button text"""
52
- return gr.update(value=self.get_download_button_text())
53
-
54
- def update_checkpoint_button_text(self) -> gr.update:
55
- """Update the checkpoint button text"""
56
- return gr.update(value=self.get_checkpoint_button_text())
57
-
58
- def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
59
- """Update both download button texts"""
60
- return (
61
- gr.update(value=self.get_download_button_text()),
62
- gr.update(value=self.get_checkpoint_button_text())
63
- )
64
-
65
- def download_and_update_button(self):
66
- """Handle download and return updated button with current text"""
67
- # Get the safetensors path for download
68
- path = self.app.training.get_model_output_safetensors()
69
- # For DownloadButton, we need to return the file path directly for download
70
- # The button text will be updated on next render
71
- return path
72
 
73
  def create(self, parent=None) -> gr.TabItem:
74
  """Create the Manage tab UI components"""
@@ -90,19 +46,19 @@ class ManageTab(BaseTab):
90
  gr.Markdown("📦 Training dataset download disabled for large datasets")
91
 
92
  self.components["download_model_btn"] = gr.DownloadButton(
93
- self.get_download_button_text(),
94
  variant="secondary",
95
  size="lg"
96
  )
97
 
98
  self.components["download_checkpoint_btn"] = gr.DownloadButton(
99
- self.get_checkpoint_button_text(),
100
  variant="secondary",
101
  size="lg"
102
  )
103
 
104
  self.components["download_output_btn"] = gr.DownloadButton(
105
- "📁 Download output directory (.zip)",
106
  variant="secondary",
107
  size="lg",
108
  visible=False
 
25
  self.id = "manage_tab"
26
  self.title = "5️⃣ Storage"
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def create(self, parent=None) -> gr.TabItem:
30
  """Create the Manage tab UI components"""
 
46
  gr.Markdown("📦 Training dataset download disabled for large datasets")
47
 
48
  self.components["download_model_btn"] = gr.DownloadButton(
49
+ "🧠 Download LoRA weights",
50
  variant="secondary",
51
  size="lg"
52
  )
53
 
54
  self.components["download_checkpoint_btn"] = gr.DownloadButton(
55
+ "💽 Download Checkpoints",
56
  variant="secondary",
57
  size="lg"
58
  )
59
 
60
  self.components["download_output_btn"] = gr.DownloadButton(
61
+ "📁 Download output/ (.zip)",
62
  variant="secondary",
63
  size="lg",
64
  visible=False
vms/ui/project/tabs/train_tab.py CHANGED
@@ -494,12 +494,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
494
  save_iterations, repo_id, progress
495
  )
496
 
497
- # Update download button texts
498
- manage_tab = self.app.tabs["manage_tab"]
499
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
500
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
501
-
502
- return status, logs, download_btn_text, checkpoint_btn_text
503
 
504
  def handle_resume_training(
505
  self, model_type, model_version, training_type,
@@ -511,10 +506,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
511
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
512
 
513
  if not checkpoints:
514
- manage_tab = self.app.tabs["manage_tab"]
515
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
516
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
517
- return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
518
 
519
  self.app.training.append_log(f"Resuming training from latest checkpoint")
520
 
@@ -526,12 +518,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
526
  resume_from_checkpoint="latest"
527
  )
528
 
529
- # Update download button texts
530
- manage_tab = self.app.tabs["manage_tab"]
531
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
532
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
533
-
534
- return status, logs, download_btn_text, checkpoint_btn_text
535
 
536
  def handle_start_from_lora_training(
537
  self, model_type, model_version, training_type,
@@ -542,26 +529,22 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
542
  # Find the latest LoRA weights
543
  lora_weights_path = self.app.output_path / "lora_weights"
544
 
545
- manage_tab = self.app.tabs["manage_tab"]
546
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
547
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
548
-
549
  if not lora_weights_path.exists():
550
- return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
551
 
552
  # Find the latest LoRA checkpoint directory
553
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
554
  key=lambda x: int(x.name), reverse=True)
555
 
556
  if not lora_dirs:
557
- return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
558
 
559
  latest_lora_dir = lora_dirs[0]
560
 
561
  # Verify the LoRA weights file exists
562
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
563
  if not lora_weights_file.exists():
564
- return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
565
 
566
  # Clear checkpoints to start fresh (but keep LoRA weights)
567
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
@@ -582,11 +565,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
582
  save_iterations, repo_id, progress,
583
  )
584
 
585
- # Update download button texts
586
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
587
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
588
-
589
- return status, logs, download_btn_text, checkpoint_btn_text
590
 
591
  def connect_events(self) -> None:
592
  """Connect event handlers to UI components"""
@@ -769,9 +748,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
769
  ],
770
  outputs=[
771
  self.components["status_box"],
772
- self.components["log_box"],
773
- self.app.tabs["manage_tab"].components["download_model_btn"],
774
- self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
775
  ]
776
  )
777
 
@@ -791,9 +768,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
791
  ],
792
  outputs=[
793
  self.components["status_box"],
794
- self.components["log_box"],
795
- self.app.tabs["manage_tab"].components["download_model_btn"],
796
- self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
797
  ]
798
  )
799
 
@@ -813,9 +788,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
813
  ],
814
  outputs=[
815
  self.components["status_box"],
816
- self.components["log_box"],
817
- self.app.tabs["manage_tab"].components["download_model_btn"],
818
- self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
819
  ]
820
  )
821
 
@@ -831,9 +804,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
831
  self.components["current_task_box"],
832
  self.components["start_btn"],
833
  self.components["stop_btn"],
834
- third_btn,
835
- self.app.tabs["manage_tab"].components["download_model_btn"],
836
- self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
837
  ]
838
  )
839
 
@@ -845,9 +816,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
845
  self.components["current_task_box"],
846
  self.components["start_btn"],
847
  self.components["stop_btn"],
848
- third_btn,
849
- self.app.tabs["manage_tab"].components["download_model_btn"],
850
- self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
851
  ]
852
  )
853
 
@@ -1201,12 +1170,7 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
1201
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
1202
  current_task = self.app.log_parser.get_current_task_display()
1203
 
1204
- # Update download button texts
1205
- manage_tab = self.app.tabs["manage_tab"]
1206
- download_btn_text = gr.update(value=manage_tab.get_download_button_text())
1207
- checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
1208
-
1209
- return message, logs, current_task, download_btn_text, checkpoint_btn_text
1210
 
1211
  def get_button_updates(self):
1212
  """Get button updates (with variant property)"""
 
494
  save_iterations, repo_id, progress
495
  )
496
 
497
+ return status, logs
 
 
 
 
 
498
 
499
  def handle_resume_training(
500
  self, model_type, model_version, training_type,
 
506
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
507
 
508
  if not checkpoints:
509
+ return "No checkpoints found to resume from", "Please start a new training session instead"
 
 
 
510
 
511
  self.app.training.append_log(f"Resuming training from latest checkpoint")
512
 
 
518
  resume_from_checkpoint="latest"
519
  )
520
 
521
+ return status, logs
 
 
 
 
 
522
 
523
  def handle_start_from_lora_training(
524
  self, model_type, model_version, training_type,
 
529
  # Find the latest LoRA weights
530
  lora_weights_path = self.app.output_path / "lora_weights"
531
 
 
 
 
 
532
  if not lora_weights_path.exists():
533
+ return "No LoRA weights found", "Please train a model first or start a new training session"
534
 
535
  # Find the latest LoRA checkpoint directory
536
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
537
  key=lambda x: int(x.name), reverse=True)
538
 
539
  if not lora_dirs:
540
+ return "No LoRA weight directories found", "Please train a model first or start a new training session"
541
 
542
  latest_lora_dir = lora_dirs[0]
543
 
544
  # Verify the LoRA weights file exists
545
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
546
  if not lora_weights_file.exists():
547
+ return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
548
 
549
  # Clear checkpoints to start fresh (but keep LoRA weights)
550
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
 
565
  save_iterations, repo_id, progress,
566
  )
567
 
568
+ return status, logs
 
 
 
 
569
 
570
  def connect_events(self) -> None:
571
  """Connect event handlers to UI components"""
 
748
  ],
749
  outputs=[
750
  self.components["status_box"],
751
+ self.components["log_box"]
 
 
752
  ]
753
  )
754
 
 
768
  ],
769
  outputs=[
770
  self.components["status_box"],
771
+ self.components["log_box"]
 
 
772
  ]
773
  )
774
 
 
788
  ],
789
  outputs=[
790
  self.components["status_box"],
791
+ self.components["log_box"]
 
 
792
  ]
793
  )
794
 
 
804
  self.components["current_task_box"],
805
  self.components["start_btn"],
806
  self.components["stop_btn"],
807
+ third_btn
 
 
808
  ]
809
  )
810
 
 
816
  self.components["current_task_box"],
817
  self.components["start_btn"],
818
  self.components["stop_btn"],
819
+ third_btn
 
 
820
  ]
821
  )
822
 
 
1170
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
1171
  current_task = self.app.log_parser.get_current_task_display()
1172
 
1173
+ return message, logs, current_task
 
 
 
 
 
1174
 
1175
  def get_button_updates(self):
1176
  """Get button updates (with variant property)"""