yeq6x commited on
Commit
b7cfcd0
·
1 Parent(s): 0d01700
Files changed (2) hide show
  1. app.py +62 -327
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,334 +1,69 @@
1
- from flask import Flask, request, render_template, send_file, jsonify, send_from_directory, session, copy_current_request_context
2
- from flask_socketio import SocketIO, join_room, leave_room, close_room, rooms, disconnect
3
- from flask_cors import CORS
4
- from flask_limiter import Limiter
5
- from flask_limiter.util import get_remote_address
6
- import concurrent.futures
7
-
8
- import io
9
  import os
10
- import argparse
11
  from PIL import Image
12
- import torch
13
- import gc
14
- from peft import PeftModel
15
-
16
- import queue
17
- import threading
18
- import uuid
19
- import concurrent.futures
20
- from scripts.process_utils import *
21
-
22
- from gevent import pywsgi
23
- from geventwebsocket.handler import WebSocketHandler
24
-
25
- app = Flask(__name__)
26
- CORS(app)
27
- socketio = SocketIO(app, cors_allowed_origins="*")
28
- # クライアントIDとルーム情報を保存するグローバル辞書
29
- client_rooms = {}
30
-
31
- # レート制限の設定
32
- limiter = Limiter(
33
- get_remote_address,
34
- app=app,
35
- default_limits=["200 per day", "50 per hour"]
36
- )
37
-
38
- # タスクキューの作成とサイズ制限
39
- MAX_QUEUE_SIZE = 100
40
- task_queue = queue.Queue(maxsize=MAX_QUEUE_SIZE)
41
- active_tasks = {}
42
- task_futures = {}
43
-
44
- # ThreadPoolExecutorの作成
45
- executor = concurrent.futures.ThreadPoolExecutor(max_workers=int(os.environ.get('MAX_WORKERS', 4)))
46
-
47
- gpu_lock = threading.Lock()
48
-
49
- class Task:
50
- def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip, client_id):
51
- self.task_id = task_id
52
- self.mode = mode
53
- self.weight1 = weight1
54
- self.weight2 = weight2
55
- self.file_data = file_data
56
- self.cancel_flag = False
57
- self.client_ip = client_ip
58
- self.is_processing = False
59
- self.client_id = client_id
60
-
61
- # キューの状態を通知
62
- def update_queue_status(message):
63
- socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message}, namespace='/demo')
64
-
65
- def process_task(task):
66
- try:
67
- task.is_processing = True
68
- # ファイルデータをPIL Imageに変換
69
- image = Image.open(io.BytesIO(task.file_data))
70
- image = ensure_rgb(image)
71
-
72
- # キャンセルチェック
73
- if task.cancel_flag:
74
- return
75
-
76
- # 画像処理ロジックを呼び出す
77
- # GPU処理部分
78
- with gpu_lock:
79
- sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)
80
-
81
- # キャンセルチェック
82
- if task.cancel_flag:
83
- return
84
-
85
- # クライアントIDをリクエストヘッダーから取得(クライアント側で設定する必要があります)
86
- client_id = task.client_id
87
- if client_id and client_id in client_rooms:
88
- room = client_rooms[client_id]
89
- # ルームにメッセージをemit
90
- socketio.emit('task_complete', {
91
- 'task_id': task.task_id,
92
- 'sotai_image': sotai_image,
93
- 'sketch_image': sketch_image
94
- }, to=room, namespace='/demo')
95
-
96
- except Exception as e:
97
- print(f"Task error: {str(e)}")
98
- if not task.cancel_flag:
99
- client_id = task.client_id
100
- room = client_rooms[client_id]
101
- socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)}, to=room, namespace='/demo')
102
- finally:
103
- # タスク数をデクリメント
104
- client_ip = task.client_ip
105
- tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
106
- print(f'Task {task.task_id} completed')
107
- task.is_processing = False
108
- if task.task_id in active_tasks.keys():
109
- del active_tasks[task.task_id]
110
- if task.task_id in task_futures.keys():
111
- del task_futures[task.task_id]
112
-
113
- update_queue_status('Task completed or cancelled')
114
-
115
- def worker():
116
- while True:
117
- try:
118
- task = task_queue.get()
119
- if task.task_id in active_tasks.keys():
120
- future = executor.submit(process_task, task)
121
- task_futures[task.task_id] = future
122
- except Exception as e:
123
- print(f"Worker error: {str(e)}")
124
- finally:
125
- # Ensure the task is always removed from the queue
126
- task_queue.task_done()
127
-
128
- # ワーカースレッドの開始
129
- threading.Thread(target=worker, daemon=True).start()
130
-
131
- # グローバル変数を使用して接続数とタスク数を管理
132
- connected_clients = 0
133
- tasks_per_client = {}
134
- @socketio.on('connect', namespace='/demo')
135
- def handle_connect(auth):
136
- client_id = request.sid
137
- room = f"room_{client_id}" # クライアントごとに一意のルーム名を生成
138
- join_room(room)
139
- client_rooms[client_id] = room
140
- print(f"Client {client_id} connected and joined room {room}")
141
-
142
- global connected_clients
143
- connected_clients += 1
144
-
145
- @socketio.on('disconnect' )
146
- def handle_disconnect():
147
- client_id = request.sid
148
- if client_id in client_rooms:
149
- room = client_rooms[client_id]
150
- leave_room(room)
151
- del client_rooms[client_id]
152
- print(f"Client {client_id} disconnected and removed from room {room}")
153
-
154
- global connected_clients
155
- connected_clients -= 1
156
- # キャンセル処理:接続が切断された場合、そのクライアントに関連するタスクをキャンセル。ただし、1番目で処理中のタスクはキャンセルしない
157
- client_ip = get_remote_address()
158
- for task_id, task in active_tasks.items():
159
- if task.client_ip == client_ip and not task.is_processing:
160
- task.cancel_flag = True
161
- if task_id in task_futures:
162
- task_futures[task_id].cancel()
163
- del task_futures[task_id]
164
- del active_tasks[task_id]
165
- tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
166
-
167
- @app.route('/submit_task', methods=['POST'])
168
- @limiter.limit("10 per minute") # 1分間に10回までのリクエストに制限
169
- def submit_task():
170
- if task_queue.full():
171
- return jsonify({'error': 'Task queue is full. Please try again later.'}), 503
172
-
173
- # クライアントIPアドレスを取得
174
- client_ip = get_remote_address()
175
- # 同一IPからの同時タスク数を制限
176
- if tasks_per_client.get(client_ip, 0) >= 2:
177
- return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
178
-
179
- task_id = str(uuid.uuid4())
180
- file = request.files['file']
181
- mode = request.form.get('mode', 'refine')
182
- weight1 = float(request.form.get('weight1', 0.4))
183
- weight2 = float(request.form.get('weight2', 0.3))
184
 
185
- # ファイルタイプの制限
186
- allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
187
- if '.' not in file.filename or file.filename.rsplit('.', 1)[1].lower() not in allowed_extensions:
188
- return jsonify({'error': 'Invalid file type'}), 415
189
-
190
- # ファイルデータをバイト列として保存
191
- file_data = file.read()
192
 
193
- client_id = request.headers.get('X-Client-ID')
194
- task = Task(task_id, mode, weight1, weight2, file_data, client_ip, client_id)
195
- task_queue.put(task)
196
- active_tasks[task_id] = task
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- # 同一IPからのタスク数をインクリメント
199
- tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) + 1
200
-
201
- update_queue_status('Task submitted') # すべてに通知
202
-
203
- queue_size = task_queue.qsize()
204
- task_order = get_active_task_order(task_id)
205
- return jsonify({'task_id': task_id, 'task_order': task_order, 'queue_size': queue_size})
206
-
207
- @app.route('/cancel_task/<task_id>', methods=['POST'])
208
- def cancel_task(task_id):
209
- # クライアントIPアドレスを取得
210
- client_ip = get_remote_address()
211
-
212
- if task_id in active_tasks.keys():
213
- task = active_tasks[task_id]
214
- # タスクの所有者を確認(IPアドレスで簡易的に判断)
215
- if task.client_ip != client_ip:
216
- return jsonify({'error': 'Unauthorized to cancel this task'}), 403
217
- task.cancel_flag = True
218
- if task_id in task_futures.keys():
219
- task_futures[task_id].cancel()
220
- del task_futures[task_id]
221
- del active_tasks[task_id]
222
- # タスク数をデクリメント
223
- tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
224
- update_queue_status('Task cancelled')
225
- return jsonify({'message': 'Task cancellation requested'})
226
- else:
227
- for task in list(task_queue.queue):
228
- if task.task_id == task_id and task.client_ip == client_ip:
229
- task.cancel_flag = True
230
- # タスク数をデクリメント
231
- tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
232
- return jsonify({'message': 'Task cancellation requested for queued task'})
233
- return jsonify({'error': 'Task not found'}), 404
234
-
235
- @app.route('/task_status/<task_id>', methods=['GET'])
236
- def task_status(task_id):
237
- try:
238
- if task_id in active_tasks.keys():
239
- task = active_tasks[task_id]
240
- return jsonify({'task_id': task_id, 'is_processing': task.is_processing})
241
- else:
242
- return jsonify({'task_id': task_id, 'is_processing': False})
243
- except Exception as e:
244
- return jsonify({'error': str(e)}), 500
245
-
246
- def get_active_task_order(task_id):
247
- try:
248
- if task_id not in active_tasks.keys():
249
- return 0
250
- if active_tasks[task_id].is_processing:
251
- return 0
252
- processing_task_ids = [tid for tid, task in active_tasks.items() if task.is_processing]
253
- non_processing_task_ids = [tid for tid, task in active_tasks.items() if not task.is_processing]
254
- if len(processing_task_ids) == 0:
255
- task_order = 0
256
- else:
257
- task_order = non_processing_task_ids.index(task_id) + 1
258
- return task_order
259
- except Exception as e:
260
- print(f"Error getting task order: {str(e)}")
261
-
262
- # get_task_orderイベントハンドラー
263
- @app.route('/get_task_order/<task_id>', methods=['GET'])
264
- def handle_get_task_order(task_id):
265
- if task_id in active_tasks.keys():
266
- return jsonify({'task_order': get_active_task_order(task_id)})
267
- else:
268
- return jsonify({'task_order': 0})
269
-
270
- # Flaskルート
271
- # ルートパスのGETリクエストに対するハンドラ
272
- @app.route('/', methods=['GET'])
273
- def root():
274
- return render_template("index.html")
275
-
276
- # process_refined のエンドポイント
277
- @app.route('/process_refined', methods=['POST'])
278
- def process_refined():
279
- file = request.files['file']
280
- weight1 = float(request.form.get('weight1', 0.4))
281
- weight2 = float(request.form.get('weight2', 0.3))
282
-
283
- image = ensure_rgb(Image.open(file.stream))
284
- sotai_image, sketch_image = process_image_as_base64(image, "refine", weight1, weight2)
285
-
286
- return jsonify({
287
- 'sotai_image': sotai_image,
288
- 'sketch_image': sketch_image
289
- })
290
-
291
- @app.route('/process_original', methods=['POST'])
292
- def process_original():
293
- file = request.files['file']
294
-
295
- image = ensure_rgb(Image.open(file.stream))
296
- sotai_image, sketch_image = process_image_as_base64(image, "original")
297
-
298
- return jsonify({
299
- 'sotai_image': sotai_image,
300
- 'sketch_image': sketch_image
301
- })
302
-
303
- @app.route('/process_sketch', methods=['POST'])
304
- def process_sketch():
305
- file = request.files['file']
306
-
307
- image = ensure_rgb(Image.open(file.stream))
308
- sotai_image, sketch_image = process_image_as_base64(image, "sketch")
309
 
310
- return jsonify({
311
- 'sotai_image': sotai_image,
312
- 'sketch_image': sketch_image
313
- })
314
-
315
- # グローバルエラーハンドラー
316
- @app.errorhandler(Exception)
317
- def handle_exception(e):
318
- # ログにエラーを記録
319
- app.logger.error(f"Unhandled exception: {str(e)}")
320
- return jsonify({'error': 'An unexpected error occurred'}), 500
321
-
322
- if __name__ == '__main__':
323
- parser = argparse.ArgumentParser(description='Server options.')
324
- parser.add_argument('--use_local', action='store_true', help='Use local model')
325
- parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
326
- parser.add_argument('--use_dotenv', action='store_true', help='Use .env file for environment variables')
327
-
328
- args = parser.parse_args()
329
 
330
- initialize(args.use_local, args.use_gpu, args.use_dotenv)
331
-
332
- port = int(os.environ.get('PORT', 7860))
333
- server = pywsgi.WSGIServer(('0.0.0.0', port), app, handler_class=WebSocketHandler)
334
- server.serve_forever()
 
 
 
 
1
+ import gradio as gr
 
 
 
 
 
 
 
2
  import os
3
+ import io
4
  from PIL import Image
5
+ import base64
6
+ from scripts.process_utils import initialize, process_image_as_base64
7
+ from scripts.anime import init_model
8
+ from scripts.generate_prompt import load_wd14_tagger_model
9
+
10
+ # 初期化
11
+ initialize(use_local=False, use_gpu=True)
12
+ init_model(use_local=False)
13
+ load_wd14_tagger_model()
14
+
15
+ def process_image(input_image, mode, weight1, weight2):
16
+ # 画像処理ロジック
17
+ sotai_image, sketch_image = process_image_as_base64(input_image, mode, weight1, weight2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Base64文字列をPIL Imageに変換
20
+ sotai_pil = Image.open(io.BytesIO(base64.b64decode(sotai_image)))
21
+ sketch_pil = Image.open(io.BytesIO(base64.b64decode(sketch_image)))
 
 
 
 
22
 
23
+ return sotai_pil, sketch_pil
24
+
25
+ def gradio_process_image(input_image, mode, weight1, weight2):
26
+ sotai_image, sketch_image = process_image(input_image, mode, weight1, weight2)
27
+ return sotai_image, sketch_image
28
+
29
+ # サンプル画像のパスリスト
30
+ sample_images = [
31
+ 'images/sample1.png',
32
+ 'images/sample2.png',
33
+ 'images/sample4.png',
34
+ # ... 他のサンプル画像
35
+ ]
36
+
37
+ # Gradio インターフェースの定義
38
+ with gr.Blocks() as demo:
39
+ gr.Markdown("# Image2Body Test")
40
 
41
+ with gr.Row():
42
+ with gr.Column():
43
+ input_image = gr.Image(type="pil", label="Input Image")
44
+ mode = gr.Radio(["original", "refine"], label="Mode", value="original")
45
+ with gr.Row():
46
+ weight1 = gr.Slider(0, 2, value=0.6, step=0.05, label="Weight 1 (Sketch)")
47
+ weight2 = gr.Slider(0, 1, value=0.05, step=0.025, label="Weight 2 (Body)")
48
+ process_btn = gr.Button("Process")
49
+
50
+ with gr.Column():
51
+ sotai_output = gr.Image(type="pil", label="Sotai (Body) Image")
52
+ sketch_output = gr.Image(type="pil", label="Sketch Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ gr.Examples(
55
+ examples=sample_images,
56
+ inputs=input_image,
57
+ outputs=[sotai_output, sketch_output],
58
+ fn=gradio_process_image,
59
+ cache_examples=True,
60
+ )
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ process_btn.click(
63
+ fn=gradio_process_image,
64
+ inputs=[input_image, mode, weight1, weight2],
65
+ outputs=[sotai_output, sketch_output]
66
+ )
67
+
68
+ # Spacesへのデプロイ設定
69
+ demo.launch()
requirements.txt CHANGED
@@ -2,6 +2,7 @@
2
  torch==2.2.0
3
  torchvision==0.17.0
4
  torchaudio==2.2.0
 
5
  diffusers==0.27.0 # pth file cannot be loaded in the latest version
6
  Flask==3.0.3
7
  Flask-Cors==4.0.0
 
2
  torch==2.2.0
3
  torchvision==0.17.0
4
  torchaudio==2.2.0
5
+ transformers
6
  diffusers==0.27.0 # pth file cannot be loaded in the latest version
7
  Flask==3.0.3
8
  Flask-Cors==4.0.0