yeq6x commited on
Commit
34fca51
·
1 Parent(s): a742794
Files changed (1) hide show
  1. app.py +29 -58
app.py CHANGED
@@ -7,73 +7,44 @@ from scripts.process_utils import initialize, process_image_as_base64
7
  from scripts.anime import init_model
8
  from scripts.generate_prompt import load_wd14_tagger_model
9
 
10
- import spaces
11
-
12
  # 初期化
13
  initialize(_use_local=False, use_gpu=True, use_dotenv=False)
14
  init_model(use_local=False)
15
  load_wd14_tagger_model()
16
 
17
- @spaces.GPU
18
  def process_image(input_image, mode, weight1, weight2):
19
  # 画像処理ロジック
20
  sotai_image, sketch_image = process_image_as_base64(input_image, mode, weight1, weight2)
21
-
22
- # Base64文字列をPIL Imageに変換
23
- sotai_pil = Image.open(io.BytesIO(base64.b64decode(sotai_image)))
24
- sketch_pil = Image.open(io.BytesIO(base64.b64decode(sketch_image)))
25
-
26
- return sotai_pil, sketch_pil
27
 
28
- # サンプル画像のパスリスト
29
- sample_images = [
30
- 'images/sample1.png',
31
- 'images/sample2.png',
32
- # 'images/sample4.png',
33
- # 'images/sample5.png',
34
- # 'images/sample6.png',
35
- # 'images/sample7.png',
36
- # 'images/sample8.png',
37
- # 'images/sample10.png',
38
- # 'images/sample11.png',
39
- # 'images/sample15.png',
40
- # 'images/sample16.png',
41
- # 'images/sample18.png',
42
- # 'images/sample19.png',
43
- # 'images/sample20.png',
44
- # 'images/sample21.png',
45
- ]
46
 
47
  # Gradio インターフェースの定義
48
- with gr.Blocks() as demo:
49
- gr.Markdown("# Image2Body Demo")
50
-
51
- with gr.Row():
52
- with gr.Column():
53
- input_image = gr.Image(type="pil", label="Input Image")
54
- mode = gr.Radio(["original", "refine"], label="Mode", value="original")
55
- with gr.Row():
56
- weight1 = gr.Slider(0, 2, value=0.6, step=0.05, label="Weight 1 (Sketch)")
57
- weight2 = gr.Slider(0, 1, value=0.05, step=0.025, label="Weight 2 (Body)")
58
- process_btn = gr.Button("Process")
59
-
60
- with gr.Column():
61
- sotai_output = gr.Image(type="pil", label="Sotai (Body) Image")
62
- sketch_output = gr.Image(type="pil", label="Sketch Image")
63
-
64
- gr.Examples(
65
- examples=[[path, "original", 0.6, 0.05] for path in sample_images],
66
- inputs=[input_image, mode, weight1, weight2],
67
- outputs=[sotai_output, sketch_output],
68
- fn=process_image,
69
- cache_examples=True,
70
- )
71
-
72
- process_btn.click(
73
- fn=process_image,
74
- inputs=[input_image, mode, weight1, weight2],
75
- outputs=[sotai_output, sketch_output]
76
- )
77
 
78
- # Spacesへのデプロイ設定
79
- demo.launch()
 
7
  from scripts.anime import init_model
8
  from scripts.generate_prompt import load_wd14_tagger_model
9
 
 
 
10
  # 初期化
11
  initialize(_use_local=False, use_gpu=True, use_dotenv=False)
12
  init_model(use_local=False)
13
  load_wd14_tagger_model()
14
 
 
15
  def process_image(input_image, mode, weight1, weight2):
16
  # 画像処理ロジック
17
  sotai_image, sketch_image = process_image_as_base64(input_image, mode, weight1, weight2)
18
+ return sotai_image, sketch_image
 
 
 
 
 
19
 
20
+ def gradio_process_image(input_image, mode, weight1, weight2):
21
+ # Gradio用の関数:PILイメージを受け取り、Base64文字列を返す
22
+ input_image_bytes = io.BytesIO()
23
+ input_image.save(input_image_bytes, format='PNG')
24
+ input_image_base64 = base64.b64encode(input_image_bytes.getvalue()).decode('utf-8')
25
+
26
+ sotai_base64, sketch_base64 = process_image(input_image_base64, mode, weight1, weight2)
27
+ return sotai_base64, sketch_base64
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Gradio インターフェースの定義
30
+ iface = gr.Interface(
31
+ fn=gradio_process_image,
32
+ inputs=[
33
+ gr.Image(type="pil", label="Input Image"),
34
+ gr.Radio(["original", "refine"], label="Mode", value="original"),
35
+ gr.Slider(0, 2, value=0.6, step=0.05, label="Weight 1 (Sketch)"),
36
+ gr.Slider(0, 1, value=0.05, step=0.025, label="Weight 2 (Body)")
37
+ ],
38
+ outputs=[
39
+ gr.Image(type="pil", label="Sotai (Body) Image"),
40
+ gr.Image(type="pil", label="Sketch Image")
41
+ ],
42
+ title="Image2Body API",
43
+ description="Upload an image and select processing options to generate body and sketch images."
44
+ )
45
+
46
+ # APIとして公開
47
+ app = gr.mount_gradio_app(app, iface, path="/")
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # Hugging Face Spacesでデプロイする場合
50
+ iface.queue().launch()