fantos commited on
Commit
d0dbba0
·
verified ·
1 Parent(s): da7073d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -301
app.py CHANGED
@@ -5,87 +5,25 @@ import gradio as gr
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
  import os
8
- import io
9
- import base64
10
- import json
11
- from datetime import datetime
12
- import torch.nn.functional as F
13
 
14
- # Force CPU mode for Zero GPU environment
15
- device = torch.device('cpu')
16
- torch.set_num_threads(4) # Optimize CPU performance
17
-
18
- # Optimize memory usage
19
- torch.backends.cudnn.benchmark = False
20
- torch.backends.cudnn.deterministic = True
21
-
22
- # Reduce memory usage for history
23
- MAX_HISTORY_ENTRIES = 5
24
-
25
- # Style presets
26
- STYLE_PRESETS = {
27
- "Sketch": {"line_thickness": 1.0, "contrast": 1.2, "brightness": 1.0},
28
- "Bold": {"line_thickness": 1.5, "contrast": 1.4, "brightness": 0.8},
29
- "Light": {"line_thickness": 0.8, "contrast": 0.9, "brightness": 1.2},
30
- "High Contrast": {"line_thickness": 1.2, "contrast": 1.6, "brightness": 0.7},
31
- }
32
-
33
- # History management
34
- class HistoryManager:
35
- def __init__(self, max_entries=10):
36
- self.max_entries = max_entries
37
- self.history_file = "processing_history.json"
38
- self.history = self.load_history()
39
-
40
- def load_history(self):
41
- try:
42
- if os.path.exists(self.history_file):
43
- with open(self.history_file, 'r') as f:
44
- return json.load(f)
45
- return []
46
- except Exception:
47
- return []
48
-
49
- def save_history(self):
50
- try:
51
- with open(self.history_file, 'w') as f:
52
- json.dump(self.history[-self.max_entries:], f)
53
- except Exception as e:
54
- print(f"Error saving history: {e}")
55
-
56
- def add_entry(self, input_path, settings):
57
- entry = {
58
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
59
- "input_file": os.path.basename(input_path),
60
- "settings": settings
61
- }
62
- self.history.append(entry)
63
- if len(self.history) > self.max_entries:
64
- self.history.pop(0)
65
- self.save_history()
66
-
67
- def get_latest_settings(self):
68
- if self.history:
69
- return self.history[-1]["settings"]
70
- return None
71
-
72
- # Initialize history manager with reduced entries
73
- history_manager = HistoryManager(max_entries=MAX_HISTORY_ENTRIES)
74
 
75
  norm_layer = nn.InstanceNorm2d
76
 
77
  class ResidualBlock(nn.Module):
78
  def __init__(self, in_features):
79
  super(ResidualBlock, self).__init__()
80
-
81
- conv_block = [ nn.ReflectionPad2d(1),
82
- nn.Conv2d(in_features, in_features, 3),
83
- norm_layer(in_features),
84
- nn.ReLU(inplace=True),
85
- nn.ReflectionPad2d(1),
86
- nn.Conv2d(in_features, in_features, 3),
87
- norm_layer(in_features) ]
88
-
89
  self.conv_block = nn.Sequential(*conv_block)
90
 
91
  def forward(self, x):
@@ -96,10 +34,12 @@ class Generator(nn.Module):
96
  super(Generator, self).__init__()
97
 
98
  # Initial convolution block
99
- model0 = [ nn.ReflectionPad2d(3),
100
- nn.Conv2d(input_nc, 64, 7),
101
- norm_layer(64),
102
- nn.ReLU(inplace=True) ]
 
 
103
  self.model0 = nn.Sequential(*model0)
104
 
105
  # Downsampling
@@ -107,9 +47,11 @@ class Generator(nn.Module):
107
  in_features = 64
108
  out_features = in_features*2
109
  for _ in range(2):
110
- model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
111
- norm_layer(out_features),
112
- nn.ReLU(inplace=True) ]
 
 
113
  in_features = out_features
114
  out_features = in_features*2
115
  self.model1 = nn.Sequential(*model1)
@@ -124,19 +66,22 @@ class Generator(nn.Module):
124
  model3 = []
125
  out_features = in_features//2
126
  for _ in range(2):
127
- model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
128
- norm_layer(out_features),
129
- nn.ReLU(inplace=True) ]
 
 
130
  in_features = out_features
131
  out_features = in_features//2
132
  self.model3 = nn.Sequential(*model3)
133
 
134
  # Output layer
135
- model4 = [ nn.ReflectionPad2d(3),
136
- nn.Conv2d(64, output_nc, 7)]
 
 
137
  if sigmoid:
138
  model4 += [nn.Sigmoid()]
139
-
140
  self.model4 = nn.Sequential(*model4)
141
 
142
  def forward(self, x):
@@ -147,275 +92,102 @@ class Generator(nn.Module):
147
  out = self.model4(out)
148
  return out
149
 
150
- # Initialize models
151
  def load_models():
152
  try:
153
  print("Initializing models in CPU mode...")
154
  model1 = Generator(3, 1, 3)
155
  model2 = Generator(3, 1, 3)
156
 
 
157
  model1.load_state_dict(torch.load('model.pth', map_location='cpu'))
158
  model2.load_state_dict(torch.load('model2.pth', map_location='cpu'))
159
 
160
  model1.eval()
161
  model2.eval()
162
- torch.set_grad_enabled(False)
163
 
164
- print("Models loaded successfully in CPU mode")
165
  return model1, model2
166
  except Exception as e:
167
- error_msg = f"Error loading models: {str(e)}"
168
- print(error_msg)
169
- raise gr.Error("Failed to initialize models. Please check the model files and system configuration.")
170
 
171
- # Load models
172
  try:
173
  print("Starting model initialization...")
174
  model1, model2 = load_models()
175
  print("Model initialization completed")
176
  except Exception as e:
177
- print(f"Critical error during model initialization: {str(e)}")
178
- raise gr.Error("Failed to start the application due to model initialization error.")
179
-
180
- def apply_preset(preset_name):
181
- """Apply a style preset and return the settings"""
182
- if preset_name in STYLE_PRESETS:
183
- return (
184
- STYLE_PRESETS[preset_name]["line_thickness"],
185
- STYLE_PRESETS[preset_name]["contrast"],
186
- STYLE_PRESETS[preset_name]["brightness"],
187
- True # Enable enhancement for presets
188
- )
189
- return (1.0, 1.0, 1.0, False)
190
 
191
- def enhance_lines(img, contrast=1.0, brightness=1.0):
192
- """Enhance line drawing with contrast and brightness adjustments"""
193
- enhanced = np.array(img)
194
- enhanced = enhanced * contrast
195
- enhanced = np.clip(enhanced + brightness, 0, 1)
196
- return Image.fromarray((enhanced * 255).astype(np.uint8))
197
-
198
- def predict(input_img, version, preset_name, line_thickness=1.0, contrast=1.0,
199
- brightness=1.0, enable_enhancement=False, output_size="Original"):
200
  try:
201
- # Apply preset if selected
202
- if preset_name != "Custom":
203
- line_thickness, contrast, brightness, enable_enhancement = apply_preset(preset_name)
204
-
205
- # Open and process input image
206
  original_img = Image.open(input_img)
207
  original_size = original_img.size
208
-
209
- # Adjust output size
210
- if output_size != "Original":
211
- width, height = map(int, output_size.split("x"))
212
- target_size = (width, height)
213
- else:
214
- target_size = original_size
215
-
216
- # Transform pipeline
217
  transform = transforms.Compose([
218
  transforms.Resize(256, Image.BICUBIC),
219
  transforms.ToTensor(),
220
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
221
  ])
222
 
223
- input_tensor = transform(original_img).unsqueeze(0).to(device)
224
 
225
- # Process through selected model
226
  with torch.no_grad():
227
  if version == 'Simple Lines':
228
  output = model2(input_tensor)
229
  else:
230
  output = model1(input_tensor)
231
 
232
- # Apply line thickness adjustment
233
  output = output * line_thickness
234
 
235
- # Convert to image
236
- output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
237
-
238
- # Apply enhancements if enabled
239
- if enable_enhancement:
240
- output_img = enhance_lines(output_img, contrast, brightness)
241
-
242
- # Resize to target size
243
- output_img = output_img.resize(target_size, Image.BICUBIC)
244
-
245
- # Save to history
246
- settings = {
247
- "version": version,
248
- "preset": preset_name,
249
- "line_thickness": line_thickness,
250
- "contrast": contrast,
251
- "brightness": brightness,
252
- "enable_enhancement": enable_enhancement,
253
- "output_size": output_size
254
- }
255
- history_manager.add_entry(input_img, settings)
256
 
257
  return output_img
258
 
259
  except Exception as e:
260
- raise gr.Error(f"Error processing image: {str(e)}")
261
-
262
- # Custom CSS
263
- custom_css = """
264
- .gradio-container {
265
- font-family: 'Helvetica Neue', Arial, sans-serif;
266
- max-width: 1200px !important;
267
- margin: auto;
268
- }
269
- .gr-button {
270
- border-radius: 8px;
271
- background: linear-gradient(45deg, #3498db, #2980b9);
272
- border: none;
273
- color: white;
274
- transition: all 0.3s ease;
275
- }
276
- .gr-button:hover {
277
- background: linear-gradient(45deg, #2980b9, #3498db);
278
- transform: translateY(-2px);
279
- box-shadow: 0 4px 12px rgba(0,0,0,0.15);
280
- }
281
- .gr-button.secondary {
282
- background: linear-gradient(45deg, #95a5a6, #7f8c8d);
283
- }
284
- .gr-input {
285
- border-radius: 8px;
286
- border: 2px solid #3498db;
287
- transition: all 0.3s ease;
288
- }
289
- .gr-input:focus {
290
- border-color: #2980b9;
291
- box-shadow: 0 0 0 2px rgba(41,128,185,0.2);
292
- }
293
- .gr-form {
294
- border-radius: 12px;
295
- box-shadow: 0 4px 12px rgba(0,0,0,0.1);
296
- padding: 20px;
297
- }
298
- .gr-header {
299
- text-align: center;
300
- margin-bottom: 2em;
301
- }
302
- """
303
 
304
- # Create Gradio interface
305
- with gr.Blocks(css=custom_css) as iface:
306
- with gr.Row(elem_classes="gr-header"):
307
- gr.Markdown("# 🎨 Advanced Line Drawing Generator")
308
- gr.Markdown("Transform your images into beautiful line drawings with advanced controls")
309
 
310
  with gr.Row():
311
- with gr.Column(scale=1):
312
  input_image = gr.Image(type="filepath", label="Upload Image")
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- with gr.Row():
315
- version = gr.Radio(
316
- choices=['Complex Lines', 'Simple Lines'],
317
- value='Simple Lines',
318
- label="Drawing Style"
319
- )
320
- preset_selector = gr.Dropdown(
321
- choices=["Custom"] + list(STYLE_PRESETS.keys()),
322
- value="Custom",
323
- label="Style Preset"
324
- )
325
-
326
- with gr.Accordion("Advanced Settings", open=False):
327
- output_size = gr.Dropdown(
328
- choices=["Original", "512x512", "1024x1024", "2048x2048"],
329
- value="Original",
330
- label="Output Size"
331
- )
332
-
333
- line_thickness = gr.Slider(
334
- minimum=0.1,
335
- maximum=2.0,
336
- value=1.0,
337
- step=0.1,
338
- label="Line Thickness"
339
- )
340
-
341
- enable_enhancement = gr.Checkbox(
342
- label="Enable Enhancement",
343
- value=False
344
- )
345
-
346
- with gr.Group(visible=False) as enhancement_controls:
347
- contrast = gr.Slider(
348
- minimum=0.5,
349
- maximum=2.0,
350
- value=1.0,
351
- step=0.1,
352
- label="Contrast"
353
- )
354
- brightness = gr.Slider(
355
- minimum=0.5,
356
- maximum=1.5,
357
- value=1.0,
358
- step=0.1,
359
- label="Brightness"
360
- )
361
-
362
- with gr.Column(scale=1):
363
- output_image = gr.Image(type="pil", label="Generated Line Drawing")
364
- with gr.Row():
365
- generate_btn = gr.Button("Generate", variant="primary", size="lg")
366
- clear_btn = gr.Button("Clear", variant="secondary", size="lg")
367
 
368
- # Event handlers
369
- enable_enhancement.change(
370
- fn=lambda x: gr.Group(visible=x),
371
- inputs=[enable_enhancement],
372
- outputs=[enhancement_controls]
373
- )
374
-
375
- preset_selector.change(
376
- fn=apply_preset,
377
- inputs=[preset_selector],
378
- outputs=[line_thickness, contrast, brightness, enable_enhancement]
379
- )
380
 
 
381
  generate_btn.click(
382
- fn=predict,
383
- inputs=[
384
- input_image,
385
- version,
386
- preset_selector,
387
- line_thickness,
388
- contrast,
389
- brightness,
390
- enable_enhancement,
391
- output_size
392
- ],
393
  outputs=output_image
394
  )
395
-
396
- clear_btn.click(
397
- fn=lambda: (None, "Simple Lines", "Custom", 1.0, 1.0, 1.0, False, "Original"),
398
- inputs=[],
399
- outputs=[
400
- input_image,
401
- version,
402
- preset_selector,
403
- line_thickness,
404
- contrast,
405
- brightness,
406
- enable_enhancement,
407
- output_size
408
- ]
409
- )
410
 
411
- # Launch the interface with optimized settings
412
  iface.launch(
413
  server_name="0.0.0.0",
414
  server_port=7860,
415
- share=False,
416
- debug=False,
417
- show_error=True,
418
- max_threads=4,
419
- ssr=False, # Disable SSR to prevent Node.js issues
420
- cache_examples=False, # Disable example caching to save memory
421
  )
 
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
  import os
 
 
 
 
 
8
 
9
+ # CPU 전용 설정
10
+ torch.set_num_threads(4) # CPU 스레드 수 제한
11
+ torch.set_grad_enabled(False) # 추론 모드만 사용
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  norm_layer = nn.InstanceNorm2d
14
 
15
  class ResidualBlock(nn.Module):
16
  def __init__(self, in_features):
17
  super(ResidualBlock, self).__init__()
18
+ conv_block = [
19
+ nn.ReflectionPad2d(1),
20
+ nn.Conv2d(in_features, in_features, 3),
21
+ norm_layer(in_features),
22
+ nn.ReLU(inplace=True),
23
+ nn.ReflectionPad2d(1),
24
+ nn.Conv2d(in_features, in_features, 3),
25
+ norm_layer(in_features)
26
+ ]
27
  self.conv_block = nn.Sequential(*conv_block)
28
 
29
  def forward(self, x):
 
34
  super(Generator, self).__init__()
35
 
36
  # Initial convolution block
37
+ model0 = [
38
+ nn.ReflectionPad2d(3),
39
+ nn.Conv2d(input_nc, 64, 7),
40
+ norm_layer(64),
41
+ nn.ReLU(inplace=True)
42
+ ]
43
  self.model0 = nn.Sequential(*model0)
44
 
45
  # Downsampling
 
47
  in_features = 64
48
  out_features = in_features*2
49
  for _ in range(2):
50
+ model1 += [
51
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
52
+ norm_layer(out_features),
53
+ nn.ReLU(inplace=True)
54
+ ]
55
  in_features = out_features
56
  out_features = in_features*2
57
  self.model1 = nn.Sequential(*model1)
 
66
  model3 = []
67
  out_features = in_features//2
68
  for _ in range(2):
69
+ model3 += [
70
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
71
+ norm_layer(out_features),
72
+ nn.ReLU(inplace=True)
73
+ ]
74
  in_features = out_features
75
  out_features = in_features//2
76
  self.model3 = nn.Sequential(*model3)
77
 
78
  # Output layer
79
+ model4 = [
80
+ nn.ReflectionPad2d(3),
81
+ nn.Conv2d(64, output_nc, 7)
82
+ ]
83
  if sigmoid:
84
  model4 += [nn.Sigmoid()]
 
85
  self.model4 = nn.Sequential(*model4)
86
 
87
  def forward(self, x):
 
92
  out = self.model4(out)
93
  return out
94
 
95
+ # CPU 전용 모델 로드
96
  def load_models():
97
  try:
98
  print("Initializing models in CPU mode...")
99
  model1 = Generator(3, 1, 3)
100
  model2 = Generator(3, 1, 3)
101
 
102
+ # Load models in CPU mode
103
  model1.load_state_dict(torch.load('model.pth', map_location='cpu'))
104
  model2.load_state_dict(torch.load('model2.pth', map_location='cpu'))
105
 
106
  model1.eval()
107
  model2.eval()
 
108
 
109
+ print("Models loaded successfully")
110
  return model1, model2
111
  except Exception as e:
112
+ print(f"Error loading models: {str(e)}")
113
+ raise gr.Error("Failed to initialize models. Please check model files.")
 
114
 
 
115
  try:
116
  print("Starting model initialization...")
117
  model1, model2 = load_models()
118
  print("Model initialization completed")
119
  except Exception as e:
120
+ print(f"Critical error: {str(e)}")
121
+ raise gr.Error("Failed to start the application")
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ def process_image(input_img, version, line_thickness=1.0):
 
 
 
 
 
 
 
 
124
  try:
125
+ # 이미지 로드 전처리
 
 
 
 
126
  original_img = Image.open(input_img)
127
  original_size = original_img.size
128
+
 
 
 
 
 
 
 
 
129
  transform = transforms.Compose([
130
  transforms.Resize(256, Image.BICUBIC),
131
  transforms.ToTensor(),
132
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
133
  ])
134
 
135
+ input_tensor = transform(original_img).unsqueeze(0)
136
 
137
+ # 모델 처리
138
  with torch.no_grad():
139
  if version == 'Simple Lines':
140
  output = model2(input_tensor)
141
  else:
142
  output = model1(input_tensor)
143
 
 
144
  output = output * line_thickness
145
 
146
+ # 결과 이미지 생성
147
+ output_img = transforms.ToPILImage()(output.squeeze().clamp(0, 1))
148
+ output_img = output_img.resize(original_size, Image.BICUBIC)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  return output_img
151
 
152
  except Exception as e:
153
+ raise gr.Error(f"이미지 처리 에러: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # Simple UI
156
+ with gr.Blocks() as iface:
157
+ gr.Markdown("# ✨ Magic Drawings")
158
+ gr.Markdown("Transform your photos into magical line art with AI")
 
159
 
160
  with gr.Row():
161
+ with gr.Column():
162
  input_image = gr.Image(type="filepath", label="Upload Image")
163
+ version = gr.Radio(
164
+ choices=['Complex Lines', 'Simple Lines'],
165
+ value='Simple Lines',
166
+ label="Art Style"
167
+ )
168
+ line_thickness = gr.Slider(
169
+ minimum=0.1,
170
+ maximum=2.0,
171
+ value=1.0,
172
+ step=0.1,
173
+ label="Line Thickness"
174
+ )
175
 
176
+ with gr.Column():
177
+ output_image = gr.Image(type="pil", label="Generated Art")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
+ generate_btn = gr.Button("Generate Magic", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ # Event handlers
182
  generate_btn.click(
183
+ fn=process_image,
184
+ inputs=[input_image, version, line_thickness],
 
 
 
 
 
 
 
 
 
185
  outputs=output_image
186
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ # 실행
189
  iface.launch(
190
  server_name="0.0.0.0",
191
  server_port=7860,
192
+ share=False
 
 
 
 
 
193
  )