Kiwinicki commited on
Commit
fdbc146
·
verified ·
1 Parent(s): f60843d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -52
app.py CHANGED
@@ -10,28 +10,28 @@ import torchvision.transforms as transforms
10
 
11
  photos_folder = "Photos"
12
 
13
- # Pobierz model i config
14
  repo_id = "Kiwinicki/sat2map-generator"
15
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
16
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
17
  model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
18
 
19
- # Dodaj ścieżkę do modelu
20
  sys.path.append(os.path.dirname(model_path))
21
  from model import Generator
22
 
23
- # Załaduj konfigurację
24
  with open(config_path, "r") as f:
25
  config_dict = json.load(f)
26
  cfg = OmegaConf.create(config_dict)
27
 
28
- # Inicjalizacja modelu
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  generator = Generator(cfg).to(device)
31
  generator.load_state_dict(torch.load(generator_path, map_location=device))
32
  generator.eval()
33
 
34
- # Transformacje
35
  transform = transforms.Compose([
36
  transforms.Resize((256, 256)),
37
  transforms.ToTensor(),
@@ -39,86 +39,86 @@ transform = transforms.Compose([
39
  ])
40
 
41
  def process_image(image):
42
- # Konwersja do tensora
 
 
 
43
  image_tensor = transform(image).unsqueeze(0).to(device)
44
 
45
- # Inferencja
46
  with torch.no_grad():
47
  output_tensor = generator(image_tensor)
48
 
49
- # Przygotowanie wyjścia
50
  output_image = output_tensor.squeeze(0).cpu()
51
- output_image = output_image * 0.5 + 0.5 # Denormalizacja
52
  output_image = transforms.ToPILImage()(output_image)
53
 
54
  return output_image
55
 
56
-
57
  def load_images_from_folder(folder):
58
  images = []
 
 
 
 
59
  for filename in os.listdir(folder):
60
  if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
61
  img_path = os.path.join(folder, filename)
62
- img = Image.open(img_path)
63
- images.append((img, filename))
 
 
 
64
  return images
65
 
66
-
67
-
68
- def load_image_from_gallery(images, index):
69
- if images and 0 <= index < len(images):
70
- image = images[index]
71
- if isinstance(image, tuple):
72
- image = image[0]
73
- return image
74
- return None
75
-
76
-
77
- def gallery_click_event(images, evt: gr.SelectData):
78
- index = evt.index
79
- selected_img = load_image_from_gallery(images, index)
80
- return selected_img
81
-
82
-
83
- def clear_image():
84
- return None
85
-
86
-
87
  def app():
88
  images = load_images_from_folder(photos_folder)
 
89
 
90
- with gr.Blocks(css=".container { background-color: white; }") as demo:
91
  with gr.Row():
92
  with gr.Column():
93
- selected_image = gr.Image(label="Input Image", type="pil")
94
  clear_button = gr.Button("Clear")
95
 
96
  with gr.Column():
97
- image_gallery = gr.Gallery(label="Image Gallery", elem_id="gallery", type="pil", value=[img for img, _ in images])
 
 
 
 
 
 
98
 
99
  with gr.Column():
100
- result_image = gr.Image(label="Result Image", type="pil")
101
-
102
- image_gallery.select(
103
- fn=gallery_click_event,
104
- inputs=image_gallery,
105
- outputs=selected_image
 
 
 
 
 
106
  )
107
-
108
- selected_image.change(
 
109
  fn=process_image,
110
- inputs=selected_image,
111
- outputs=result_image
112
  )
113
-
 
114
  clear_button.click(
115
- fn=clear_image,
116
- inputs=None,
117
- outputs=selected_image
118
  )
119
 
120
  demo.launch()
121
 
122
-
123
  if __name__ == "__main__":
124
- app()
 
10
 
11
  photos_folder = "Photos"
12
 
13
+ # Download model and config
14
  repo_id = "Kiwinicki/sat2map-generator"
15
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
16
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
17
  model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
18
 
19
+ # Add path to model
20
  sys.path.append(os.path.dirname(model_path))
21
  from model import Generator
22
 
23
+ # Load configuration
24
  with open(config_path, "r") as f:
25
  config_dict = json.load(f)
26
  cfg = OmegaConf.create(config_dict)
27
 
28
+ # Initialize model
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  generator = Generator(cfg).to(device)
31
  generator.load_state_dict(torch.load(generator_path, map_location=device))
32
  generator.eval()
33
 
34
+ # Transformations
35
  transform = transforms.Compose([
36
  transforms.Resize((256, 256)),
37
  transforms.ToTensor(),
 
39
  ])
40
 
41
  def process_image(image):
42
+ if image is None:
43
+ return None
44
+
45
+ # Convert to tensor
46
  image_tensor = transform(image).unsqueeze(0).to(device)
47
 
48
+ # Inference
49
  with torch.no_grad():
50
  output_tensor = generator(image_tensor)
51
 
52
+ # Prepare output
53
  output_image = output_tensor.squeeze(0).cpu()
54
+ output_image = output_image * 0.5 + 0.5 # Denormalization
55
  output_image = transforms.ToPILImage()(output_image)
56
 
57
  return output_image
58
 
 
59
  def load_images_from_folder(folder):
60
  images = []
61
+ if not os.path.exists(folder):
62
+ os.makedirs(folder)
63
+ return images
64
+
65
  for filename in os.listdir(folder):
66
  if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
67
  img_path = os.path.join(folder, filename)
68
+ try:
69
+ img = Image.open(img_path)
70
+ images.append((img, filename))
71
+ except Exception as e:
72
+ print(f"Error loading {filename}: {e}")
73
  return images
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def app():
76
  images = load_images_from_folder(photos_folder)
77
+ gallery_images = [img[0] for img in images] if images else []
78
 
79
+ with gr.Blocks() as demo:
80
  with gr.Row():
81
  with gr.Column():
82
+ input_image = gr.Image(label="Input Image", type="pil")
83
  clear_button = gr.Button("Clear")
84
 
85
  with gr.Column():
86
+ gallery = gr.Gallery(
87
+ label="Image Gallery",
88
+ value=gallery_images,
89
+ columns=3,
90
+ rows=2,
91
+ height="auto"
92
+ ).style(grid=3)
93
 
94
  with gr.Column():
95
+ output_image = gr.Image(label="Result Image", type="pil")
96
+
97
+ # Handle gallery selection
98
+ def on_select(evt: gr.SelectData):
99
+ if 0 <= evt.index < len(images):
100
+ return images[evt.index][0]
101
+ return None
102
+
103
+ gallery.select(
104
+ fn=on_select,
105
+ outputs=input_image
106
  )
107
+
108
+ # Process image when input changes
109
+ input_image.change(
110
  fn=process_image,
111
+ inputs=input_image,
112
+ outputs=output_image
113
  )
114
+
115
+ # Clear button functionality
116
  clear_button.click(
117
+ fn=lambda: None,
118
+ outputs=input_image
 
119
  )
120
 
121
  demo.launch()
122
 
 
123
  if __name__ == "__main__":
124
+ app()