DonImages committed
Commit a9b26e4 · verified · 1 Parent(s): 962febd

Update app.py

Files changed (1)
  1. app.py +64 -120
app.py CHANGED
@@ -1,122 +1,66 @@
 
  import torch
- from torch import nn, optim
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms, datasets, models
- from PIL import Image
- import json
  import os
- import gradio as gr
-
- # Paths
- image_folder = "Images/"
- metadata_file = "descriptions.json"
-
- # Define the function to load metadata
- def load_metadata(metadata_file):
-     with open(metadata_file, 'r') as f:
-         metadata = json.load(f)
-     return metadata
-
- # Custom Dataset Class
- class ImageDescriptionDataset(Dataset):
-     def __init__(self, image_folder, metadata):
-         self.image_folder = image_folder
-         self.metadata = metadata
-         self.image_names = list(metadata.keys())  # List of image filenames
-         self.transform = transforms.Compose([
-             transforms.Resize((512, 512)),
-             transforms.ToTensor(),
-             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-         ])
-
-     def __len__(self):
-         return len(self.image_names)
-
-     def __getitem__(self, idx):
-         image_name = self.image_names[idx]
-         image_path = os.path.join(self.image_folder, image_name)
-         image = Image.open(image_path).convert("RGB")
-         description = self.metadata[image_name]
-         image = self.transform(image)
-         return image, description
-
- # LoRA Layer Implementation
- class LoRALayer(nn.Module):
-     def __init__(self, original_layer, rank=4):
-         super(LoRALayer, self).__init__()
-         self.original_layer = original_layer
-         self.rank = rank
-         self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
-         self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)
-
-     def forward(self, x):
-         return self.original_layer(x) + self.lora_down(self.lora_up(x))
-
- # LoRA Model Class
- class LoRAModel(nn.Module):
-     def __init__(self):
-         super(LoRAModel, self).__init__()
-         self.backbone = models.resnet18(pretrained=True)  # Base model
-         self.backbone.fc = LoRALayer(self.backbone.fc)  # Replace the final layer with LoRA
-
-     def forward(self, x):
-         return self.backbone(x)
-
- # Training Function
- def train_lora(image_folder, metadata):
-     print("Starting LoRA training process...")
-
-     # Create dataset and dataloader
-     dataset = ImageDescriptionDataset(image_folder, metadata)
-     dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
-
-     # Initialize model, loss function, and optimizer
-     model = LoRAModel()
-     criterion = nn.CrossEntropyLoss()  # Update this if your task changes
-     optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-     # Training loop
-     num_epochs = 5  # Adjust the number of epochs based on your needs
-     for epoch in range(num_epochs):
-         print(f"Epoch {epoch + 1}/{num_epochs}")
-         for batch_idx, (images, descriptions) in enumerate(dataloader):
-             # Convert descriptions to a numerical format (if applicable)
-             labels = torch.randint(0, 100, (images.size(0),))  # Placeholder labels
-
-             # Forward pass
-             outputs = model(images)
-             loss = criterion(outputs, labels)
-
-             # Backward pass
-             optimizer.zero_grad()
-             loss.backward()
-             optimizer.step()
-
-             if batch_idx % 10 == 0:  # Log every 10 batches
-                 print(f"Batch {batch_idx}, Loss: {loss.item()}")
-
-     # Save the trained model
-     model_path = "lora_model.pth"
-     torch.save(model.state_dict(), model_path)
-     print(f"Model saved as {model_path}")
-
-     print("Training completed.")
-     return model_path  # Return the path of the saved model
-
- # Gradio App
- def start_training_gradio():
-     print("Loading metadata and preparing dataset...")
-     metadata = load_metadata(metadata_file)
-     model_path = train_lora(image_folder, metadata)
-     return model_path  # This will return the model file path for download
-
- # Gradio interface
- demo = gr.Interface(
-     fn=start_training_gradio,
-     inputs=None,
-     outputs=gr.File(),
-     title="Train LoRA Model",
-     description="Fine-tune a model using LoRA for consistent image generation."
- )
-
- demo.launch()
 
+ import gradio as gr
  import torch
+ from diffusers import StableDiffusion3Pipeline
  import os
+ import spaces
+
+ # Use the token saved in secrets
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Specify the pre-trained model ID
+ model_id = "stabilityai/stable-diffusion-3.5-large"
+
+ # Global variable for the pipeline (only initialized once)
+ pipeline = None
+
+ # Function for initializing and caching the pipeline
+ def initialize_pipeline():
+     global pipeline
+     if pipeline is None:
+         try:
+             # Load the pipeline with mixed precision (FP16)
+             pipeline = StableDiffusion3Pipeline.from_pretrained(
+                 model_id,
+                 use_auth_token=hf_token,
+                 torch_dtype=torch.float16,  # Use FP16 for mixed precision
+             )
+             # Enable model offloading and attention slicing for memory efficiency
+             pipeline.enable_model_cpu_offload()
+             pipeline.enable_attention_slicing()
+             print("Pipeline initialized and cached.")
+         except Exception as e:
+             # Error handling for model loading issues
+             print(f"Error loading the model: {e}")
+             raise RuntimeError("Failed to initialize the model pipeline.")
+     return pipeline
+
+ # Function for image generation, decorated to use GPU
+ @spaces.GPU(duration=65)
+ def generate_image(prompt):
+     pipe = initialize_pipeline()  # Initialize the pipeline (only once)
+     # Generate the image using the pipeline
+     try:
+         image = pipe(prompt).images[0]
+     except Exception as e:
+         # Catch errors during image generation (e.g., GPU/Memory errors)
+         print(f"Error during image generation: {e}")
+         raise RuntimeError("Image generation failed.")
+     return image
+
+ # Set up Gradio interface with a simple input for text and output for image
+ interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
+
+ # Launch the interface
+ interface.launch()
+
+ # Optimize device and dtype handling for CUDA or CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Additional model validation (this is optional, more for debugging)
+ pipe = initialize_pipeline()  # Ensure the model is initialized and cached
+ if not pipe or not hasattr(pipe, 'transformer'):
+     raise ValueError("Failed to load the model or the transformer component is missing.")
+
+ # Move the pipeline to the correct device (CUDA or CPU)
+ pipe = pipe.to(device)
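Note on the tail of the new version: `interface.launch()` blocks the script by default, so the device/dtype setup, the extra `initialize_pipeline()` validation, and `pipe.to(device)` placed after it will not execute while the app is serving requests. Below is a minimal sketch (not the committed code) of that tail reordered so the setup runs before launch; it assumes the imports and function definitions from the new app.py above, keeps the otherwise-unused `torch_dtype` only because the committed code defines it, and drops the explicit `.to(device)` since `enable_model_cpu_offload()` already manages device placement.

# Sketch only: a possible reordering of the tail of app.py, under the assumptions above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # mirrors the committed code; unused

# Optional warm-up/validation before serving requests, mirroring the committed checks.
pipe = initialize_pipeline()
if not hasattr(pipe, "transformer"):
    raise ValueError("Failed to load the model or the transformer component is missing.")

# Launch last: anything after this call only runs once the server shuts down.
interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
interface.launch()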