fantos commited on
Commit
dc5358b
·
verified ·
1 Parent(s): 2cdbfb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
  import os
8
- from huggingface_hub import hf_hub_download
9
  import torch.nn.functional as F
10
 
11
  # Check for CUDA availability but fallback to CPU
@@ -88,21 +88,26 @@ class Generator(nn.Module):
88
 
89
  # Initialize models
90
  def load_models():
91
- model1 = Generator(3, 1, 3).to(device)
92
- model2 = Generator(3, 1, 3).to(device)
93
-
94
- # Download models from HuggingFace Hub
95
- model1_path = hf_hub_download(repo_id="your-hf-repo/line-drawing", filename="model.pth")
96
- model2_path = hf_hub_download(repo_id="your-hf-repo/line-drawing", filename="model2.pth")
97
-
98
- model1.load_state_dict(torch.load(model1_path, map_location=device))
99
- model2.load_state_dict(torch.load(model2_path, map_location=device))
100
-
101
- model1.eval()
102
- model2.eval()
103
- return model1, model2
 
104
 
105
- model1, model2 = load_models()
 
 
 
 
106
 
107
  def apply_style_transfer(img, strength=1.0):
108
  """Apply artistic style transfer effect"""
 
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
  import os
8
+
9
  import torch.nn.functional as F
10
 
11
  # Check for CUDA availability but fallback to CPU
 
88
 
89
  # Initialize models
90
  def load_models():
91
+ try:
92
+ model1 = Generator(3, 1, 3).to(device)
93
+ model2 = Generator(3, 1, 3).to(device)
94
+
95
+ # Load local model files
96
+ model1.load_state_dict(torch.load('model.pth', map_location=device))
97
+ model2.load_state_dict(torch.load('model2.pth', map_location=device))
98
+
99
+ model1.eval()
100
+ model2.eval()
101
+ return model1, model2
102
+ except Exception as e:
103
+ print(f"Error loading models: {str(e)}")
104
+ raise gr.Error("Failed to load models. Please check if model files exist in the correct location.")
105
 
106
+ try:
107
+ model1, model2 = load_models()
108
+ except Exception as e:
109
+ print(f"Model initialization failed: {str(e)}")
110
+ model1 = model2 = None
111
 
112
  def apply_style_transfer(img, strength=1.0):
113
  """Apply artistic style transfer effect"""