File size: 5,875 Bytes
3b6ad82
 
 
 
 
 
2cdbfb7
 
d0dbba0
 
 
3b6ad82
cdc202d
 
 
 
 
d0dbba0
 
 
 
 
 
 
 
 
cdc202d
 
 
 
 
 
 
 
 
 
d0dbba0
 
 
 
 
 
cdc202d
 
 
 
 
 
 
d0dbba0
 
 
 
 
cdc202d
 
 
 
 
 
 
 
 
 
 
 
 
 
d0dbba0
 
 
 
 
cdc202d
 
 
 
 
d0dbba0
 
 
 
cdc202d
 
 
 
 
 
 
 
 
 
 
 
d0dbba0
cdc202d
 
 
 
 
 
d0dbba0
cdc202d
 
 
 
 
 
d0dbba0
cdc202d
 
d0dbba0
 
cdc202d
 
 
 
 
 
d0dbba0
 
3b6ad82
d0dbba0
2cdbfb7
d0dbba0
2cdbfb7
 
d0dbba0
2cdbfb7
 
 
 
 
 
d0dbba0
2cdbfb7
d0dbba0
2cdbfb7
 
 
 
 
 
 
 
d0dbba0
 
 
2cdbfb7
 
 
 
d0dbba0
2cdbfb7
d0dbba0
 
 
 
2cdbfb7
 
d0dbba0
2cdbfb7
d0dbba0
 
 
 
 
 
 
 
 
 
 
 
9c857dc
d0dbba0
 
2cdbfb7
d0dbba0
2cdbfb7
d0dbba0
2cdbfb7
d0dbba0
 
2cdbfb7
 
3b6ad82
d0dbba0
8a9898f
9c857dc
 
d0dbba0
8a9898f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import os

# CPU-only configuration
torch.set_num_threads(4)  # limit the number of CPU threads torch uses
torch.set_grad_enabled(False)  # inference only: gradient tracking is switched off globally

# Normalization layer class used by the network definitions below.
norm_layer = nn.InstanceNorm2d

class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip path.

    Both convolutions map ``in_features`` channels to ``in_features``
    channels, so the branch output can be added element-wise to the input.
    """

    def __init__(self, in_features):
        super().__init__()
        # Layer order is pad, conv, norm, relu, pad, conv, norm. The position
        # of each layer inside the Sequential determines its parameter name,
        # so the order must not change if saved weights are to load.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv_block(x)
        return x + branch

class Generator(nn.Module):
    """Convolutional image-to-image network: stem, downsample, residual, upsample, head.

    Args:
        input_nc: number of channels of the input tensor.
        output_nc: number of channels produced by the final convolution.
        n_residual_blocks: how many ``ResidualBlock`` modules sit between
            the downsampling and upsampling stages.
        sigmoid: when true, a ``Sigmoid`` is appended after the final
            convolution.

    The five stages are stored as ``model0`` .. ``model4``; those attribute
    names and the layer order inside each stage determine the parameter names.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stage 0 - stem: reflection-padded 7x7 convolution up to 64 channels.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Stage 1 - downsampling: two stride-2 convolutions, each doubling
        # the channel count (64 -> 128 -> 256).
        channels = 64
        down_layers = []
        for _ in range(2):
            doubled = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        self.model1 = nn.Sequential(*down_layers)

        # Stage 2 - residual blocks at the widest channel count.
        self.model2 = nn.Sequential(
            *(ResidualBlock(channels) for _ in range(n_residual_blocks))
        )

        # Stage 3 - upsampling: two stride-2 transposed convolutions, each
        # halving the channel count (256 -> 128 -> 64).
        up_layers = []
        for _ in range(2):
            halved = channels // 2
            up_layers.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved
        self.model3 = nn.Sequential(*up_layers)

        # Stage 4 - head: reflection-padded 7x7 convolution to output_nc
        # channels, optionally followed by a sigmoid.
        head_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
        ]
        if sigmoid:
            head_layers.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        encoded = self.model1(self.model0(x))
        transformed = self.model2(encoded)
        return self.model4(self.model3(transformed))

# CPU-only model loading
def load_models():
    """Build both generators and load their weights onto the CPU.

    The weights are read from ``model.pth`` and ``model2.pth`` (paths
    relative to the current working directory). Each network is a
    ``Generator(3, 1, 3)`` and is switched to eval mode after loading.

    Returns:
        tuple: ``(model1, model2)`` where ``model1`` holds the weights from
        ``model.pth`` and ``model2`` those from ``model2.pth``.

    Raises:
        gr.Error: if building a network or loading a checkpoint fails. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        print("Initializing models in CPU mode...")

        loaded = []
        for weights_path in ('model.pth', 'model2.pth'):
            net = Generator(3, 1, 3)
            # NOTE(security): torch.load unpickles the file, so these
            # checkpoints must come from a trusted source. Consider passing
            # weights_only=True if the installed torch version supports it.
            state = torch.load(weights_path, map_location='cpu')
            net.load_state_dict(state)
            net.eval()
            loaded.append(net)

        model1, model2 = loaded
        print("Models loaded successfully")
        return model1, model2
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        # Chain the original failure so the traceback shows the root cause.
        raise gr.Error("Failed to initialize models. Please check model files.") from e

# Load both models once at import time; process_image reads the resulting
# module-level names model1 and model2.
try:
    print("Starting model initialization...")
    model1, model2 = load_models()
except Exception as e:
    print(f"Critical error: {str(e)}")
    # Chain the original failure so the startup traceback shows the root cause.
    raise gr.Error("Failed to start the application") from e
else:
    print("Model initialization completed")

def process_image(input_img, version, line_thickness=1.0):
    """Turn an uploaded image into a line drawing with the selected generator.

    Args:
        input_img: path of the image file to process (passed to
            ``PIL.Image.open``). ``None`` means nothing was uploaded.
        version: ``'Simple Lines'`` selects ``model2``; any other value
            selects ``model1``.
        line_thickness: scalar the network output is multiplied by before
            it is clamped to ``[0, 1]``.

    Returns:
        A PIL image resized back to the uploaded image's original size.

    Raises:
        gr.Error: if no image was supplied, or if decoding, inference or
            post-processing fails (the underlying exception is chained).
    """
    # Fail early with a readable message instead of the AttributeError that
    # Image.open(None) would otherwise produce.
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Load and preprocess the image. The context manager releases the
        # file handle. The conversion forces three channels, which both
        # Normalize and the generators' first convolution require; RGBA,
        # greyscale and palette uploads would otherwise fail.
        with Image.open(input_img) as opened:
            original_img = opened.convert('RGB')
        original_size = original_img.size

        transform = transforms.Compose([
            transforms.Resize(256, Image.BICUBIC),
            transforms.ToTensor(),
            # Maps the [0, 1] tensor values to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Add the batch dimension the network expects.
        input_tensor = transform(original_img).unsqueeze(0)

        # Run the selected model.
        with torch.no_grad():
            if version == 'Simple Lines':
                output = model2(input_tensor)
            else:
                output = model1(input_tensor)

            output = output * line_thickness

        # Build the result image: drop singleton dimensions, clamp to the
        # valid range, and scale back to the size of the upload.
        output_img = transforms.ToPILImage()(output.squeeze().clamp(0, 1))
        output_img = output_img.resize(original_size, Image.BICUBIC)

        return output_img

    except Exception as e:
        raise gr.Error(f"이미지 처리 에러: {str(e)}") from e

# Simple UI: inputs in the left column, generated image in the right column,
# and a single button that runs process_image.
with gr.Blocks() as iface:
    gr.Markdown("# ✨ Magic Drawings")
    gr.Markdown("Transform your photos into magical line art with AI")
    
    with gr.Row():
        with gr.Column():
            # type="filepath": process_image receives a path and opens it itself.
            input_image = gr.Image(type="filepath", label="Upload Image")
            # The selected string is passed to process_image as `version`.
            version = gr.Radio(
                choices=['Complex Lines', 'Simple Lines'],
                value='Simple Lines',
                label="Art Style"
            )
            # Passed to process_image as `line_thickness`.
            line_thickness = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Line Thickness"
            )
            
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Art")
    
    generate_btn = gr.Button("Generate Magic", variant="primary")
    
    # Event handlers: the order of `inputs` matches process_image's parameters.
    generate_btn.click(
        fn=process_image,
        inputs=[input_image, version, line_thickness],
        outputs=output_image
    )

# Launch the app. This runs at module level, so it starts as soon as the
# file is executed.
iface.launch(
    server_name="0.0.0.0",  # presumably to listen on all network interfaces - confirm for the deployment
    server_port=7860,
    share=False
)