DonImages commited on
Commit
8894709
·
verified ·
1 Parent(s): b940d83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -25
app.py CHANGED
@@ -75,31 +75,31 @@ def train_lora(image_folder, metadata):
75
  criterion = nn.CrossEntropyLoss() # Update this if your task changes
76
  optimizer = optim.Adam(model.parameters(), lr=0.001)
77
 
78
- # Training loop
79
- num_epochs = 5 # Adjust the number of epochs based on your needs
80
- for epoch in range(num_epochs):
81
- print(f"Epoch {epoch + 1}/{num_epochs}")
82
- for batch_idx, (images, descriptions) in enumerate(dataloader):
83
- # Convert descriptions to a numerical format (if applicable)
84
- labels = torch.randint(0, 100, (images.size(0),)) # Placeholder labels
85
-
86
- # Forward pass
87
- outputs = model(images)
88
- loss = criterion(outputs, labels)
89
-
90
- # Backward pass
91
- optimizer.zero_grad()
92
- loss.backward()
93
- optimizer.step()
94
-
95
- if batch_idx % 10 == 0: # Log every 10 batches
96
- print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
-
98
- # Save the trained model
99
- torch.save(model.state_dict(), "lora_model.pth")
100
- print("Model saved as lora_model.pth")
101
-
102
- print("Training completed.")
103
 
104
  # Gradio App
105
  def start_training_gradio():
 
75
  criterion = nn.CrossEntropyLoss() # Update this if your task changes
76
  optimizer = optim.Adam(model.parameters(), lr=0.001)
77
 
78
+ # Training loop
79
+ num_epochs = 5 # Adjust the number of epochs based on your needs
80
+ for epoch in range(num_epochs):
81
+ print(f"Epoch {epoch + 1}/{num_epochs}")
82
+ for batch_idx, (images, descriptions) in enumerate(dataloader):
83
+ # Convert descriptions to a numerical format (if applicable)
84
+ labels = torch.randint(0, 100, (images.size(0),)) # Placeholder labels
85
+
86
+ # Forward pass
87
+ outputs = model(images)
88
+ loss = criterion(outputs, labels)
89
+
90
+ # Backward pass
91
+ optimizer.zero_grad()
92
+ loss.backward()
93
+ optimizer.step()
94
+
95
+ if batch_idx % 10 == 0: # Log every 10 batches
96
+ print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
+
98
+ # Save the trained model
99
+ torch.save(model.state_dict(), "lora_model.pth")
100
+ print("Model saved as lora_model.pth")
101
+
102
+ print("Training completed.")
103
 
104
  # Gradio App
105
  def start_training_gradio():