Spaces:
Runtime error
Runtime error
bug naosita
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory, session
|
2 |
from flask_socketio import SocketIO, join_room, leave_room, close_room, rooms, disconnect
|
3 |
from flask_cors import CORS
|
4 |
from flask_limiter import Limiter
|
@@ -23,9 +23,10 @@ from gevent import pywsgi
|
|
23 |
from geventwebsocket.handler import WebSocketHandler
|
24 |
|
25 |
app = Flask(__name__)
|
26 |
-
app.secret_key = 'user'
|
27 |
CORS(app)
|
28 |
socketio = SocketIO(app, cors_allowed_origins="*")
|
|
|
|
|
29 |
|
30 |
# レート制限の設定
|
31 |
limiter = Limiter(
|
@@ -41,10 +42,12 @@ active_tasks = {}
|
|
41 |
task_futures = {}
|
42 |
|
43 |
# ThreadPoolExecutorの作成
|
44 |
-
executor = concurrent.futures.ThreadPoolExecutor(max_workers=
|
|
|
|
|
45 |
|
46 |
class Task:
|
47 |
-
def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip):
|
48 |
self.task_id = task_id
|
49 |
self.mode = mode
|
50 |
self.weight1 = weight1
|
@@ -53,16 +56,14 @@ class Task:
|
|
53 |
self.cancel_flag = False
|
54 |
self.client_ip = client_ip
|
55 |
self.is_processing = False
|
|
|
56 |
|
57 |
# キューの状態を通知
|
58 |
def update_queue_status(message):
|
59 |
-
socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
|
60 |
|
61 |
def process_task(task):
|
62 |
try:
|
63 |
-
update_queue_status("sid")
|
64 |
-
update_queue_status(str("sid" in session))
|
65 |
-
client_id = session.get("sid")
|
66 |
task.is_processing = True
|
67 |
# ファイルデータをPIL Imageに変換
|
68 |
image = Image.open(io.BytesIO(task.file_data))
|
@@ -73,40 +74,49 @@ def process_task(task):
|
|
73 |
return
|
74 |
|
75 |
# 画像処理ロジックを呼び出す
|
76 |
-
|
|
|
|
|
77 |
|
78 |
# キャンセルチェック
|
79 |
if task.cancel_flag:
|
80 |
return
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
except Exception as e:
|
88 |
print(f"Task error: {str(e)}")
|
89 |
if not task.cancel_flag:
|
90 |
-
|
|
|
|
|
91 |
finally:
|
92 |
-
task.is_processing = False
|
93 |
-
if task.task_id in active_tasks:
|
94 |
-
del active_tasks[task.task_id]
|
95 |
-
if task.task_id in task_futures:
|
96 |
-
del task_futures[task.task_id]
|
97 |
-
|
98 |
# タスク数をデクリメント
|
99 |
client_ip = task.client_ip
|
100 |
tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
update_queue_status('Task completed or cancelled')
|
103 |
|
104 |
def worker():
|
105 |
while True:
|
106 |
try:
|
107 |
task = task_queue.get()
|
108 |
-
if task.task_id in active_tasks:
|
109 |
-
print(f"Processing task {task.task_id}")
|
110 |
future = executor.submit(process_task, task)
|
111 |
task_futures[task.task_id] = future
|
112 |
except Exception as e:
|
@@ -121,23 +131,31 @@ threading.Thread(target=worker, daemon=True).start()
|
|
121 |
# グローバル変数を使用して接続数とタスク数を管理
|
122 |
connected_clients = 0
|
123 |
tasks_per_client = {}
|
124 |
-
@socketio.on('connect', namespace='/')
|
125 |
def handle_connect(auth):
|
126 |
-
|
127 |
-
|
128 |
-
join_room(
|
|
|
|
|
|
|
129 |
global connected_clients
|
130 |
connected_clients += 1
|
131 |
|
132 |
-
@socketio.on('disconnect'
|
133 |
def handle_disconnect():
|
134 |
-
client_id = request.sid
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
136 |
global connected_clients
|
137 |
connected_clients -= 1
|
138 |
# キャンセル処理:接続が切断された場合、そのクライアントに関連するタスクをキャンセル。ただし、1番目で処理中のタスクはキャンセルしない
|
139 |
client_ip = get_remote_address()
|
140 |
-
for task_id, task in
|
141 |
if task.client_ip == client_ip and not task.is_processing:
|
142 |
task.cancel_flag = True
|
143 |
if task_id in task_futures:
|
@@ -154,8 +172,6 @@ def submit_task():
|
|
154 |
|
155 |
# クライアントIPアドレスを取得
|
156 |
client_ip = get_remote_address()
|
157 |
-
# client_id = session.get("sid")
|
158 |
-
|
159 |
# 同一IPからの同時タスク数を制限
|
160 |
if tasks_per_client.get(client_ip, 0) >= 2:
|
161 |
return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
|
@@ -174,30 +190,32 @@ def submit_task():
|
|
174 |
# ファイルデータをバイト列として保存
|
175 |
file_data = file.read()
|
176 |
|
177 |
-
|
|
|
178 |
task_queue.put(task)
|
179 |
active_tasks[task_id] = task
|
180 |
|
181 |
# 同一IPからのタスク数をインクリメント
|
182 |
tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) + 1
|
183 |
|
184 |
-
update_queue_status('Task submitted')
|
185 |
|
186 |
queue_size = task_queue.qsize()
|
187 |
-
|
|
|
188 |
|
189 |
@app.route('/cancel_task/<task_id>', methods=['POST'])
|
190 |
def cancel_task(task_id):
|
191 |
# クライアントIPアドレスを取得
|
192 |
client_ip = get_remote_address()
|
193 |
|
194 |
-
if task_id in active_tasks:
|
195 |
task = active_tasks[task_id]
|
196 |
# タスクの所有者を確認(IPアドレスで簡易的に判断)
|
197 |
if task.client_ip != client_ip:
|
198 |
return jsonify({'error': 'Unauthorized to cancel this task'}), 403
|
199 |
task.cancel_flag = True
|
200 |
-
if task_id in task_futures:
|
201 |
task_futures[task_id].cancel()
|
202 |
del task_futures[task_id]
|
203 |
del active_tasks[task_id]
|
@@ -216,21 +234,27 @@ def cancel_task(task_id):
|
|
216 |
|
217 |
@app.route('/task_status/<task_id>', methods=['GET'])
|
218 |
def task_status(task_id):
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
|
|
223 |
|
224 |
def get_active_task_order(task_id):
|
225 |
-
|
226 |
-
return
|
227 |
|
228 |
# get_task_orderイベントハンドラー
|
229 |
@app.route('/get_task_order/<task_id>', methods=['GET'])
|
230 |
def handle_get_task_order(task_id):
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
|
|
234 |
|
235 |
# Flaskルート
|
236 |
# ルートパスのGETリクエストに対するハンドラ
|
|
|
1 |
+
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory, session, copy_current_request_context
|
2 |
from flask_socketio import SocketIO, join_room, leave_room, close_room, rooms, disconnect
|
3 |
from flask_cors import CORS
|
4 |
from flask_limiter import Limiter
|
|
|
23 |
from geventwebsocket.handler import WebSocketHandler
|
24 |
|
25 |
app = Flask(__name__)
|
|
|
26 |
CORS(app)
|
27 |
socketio = SocketIO(app, cors_allowed_origins="*")
|
28 |
+
# クライアントIDとルーム情報を保存するグローバル辞書
|
29 |
+
client_rooms = {}
|
30 |
|
31 |
# レート制限の設定
|
32 |
limiter = Limiter(
|
|
|
42 |
task_futures = {}
|
43 |
|
44 |
# ThreadPoolExecutorの作成
|
45 |
+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) # 8vCPUのインスタンスを使用
|
46 |
+
|
47 |
+
gpu_lock = threading.Lock()
|
48 |
|
49 |
class Task:
|
50 |
+
def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip, client_id):
|
51 |
self.task_id = task_id
|
52 |
self.mode = mode
|
53 |
self.weight1 = weight1
|
|
|
56 |
self.cancel_flag = False
|
57 |
self.client_ip = client_ip
|
58 |
self.is_processing = False
|
59 |
+
self.client_id = client_id
|
60 |
|
61 |
# キューの状態を通知
|
62 |
def update_queue_status(message):
|
63 |
+
socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message}, namespace='/demo')
|
64 |
|
65 |
def process_task(task):
|
66 |
try:
|
|
|
|
|
|
|
67 |
task.is_processing = True
|
68 |
# ファイルデータをPIL Imageに変換
|
69 |
image = Image.open(io.BytesIO(task.file_data))
|
|
|
74 |
return
|
75 |
|
76 |
# 画像処理ロジックを呼び出す
|
77 |
+
# GPU処理部分
|
78 |
+
with gpu_lock:
|
79 |
+
sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)
|
80 |
|
81 |
# キャンセルチェック
|
82 |
if task.cancel_flag:
|
83 |
return
|
84 |
+
|
85 |
+
# クライアントIDをリクエストヘッダーから取得(クライアント側で設定する必要があります)
|
86 |
+
client_id = task.client_id
|
87 |
+
if client_id and client_id in client_rooms:
|
88 |
+
room = client_rooms[client_id]
|
89 |
+
# ルームにメッセージをemit
|
90 |
+
socketio.emit('task_complete', {
|
91 |
+
'task_id': task.task_id,
|
92 |
+
'sotai_image': sotai_image,
|
93 |
+
'sketch_image': sketch_image
|
94 |
+
}, to=room, namespace='/demo')
|
95 |
+
|
96 |
except Exception as e:
|
97 |
print(f"Task error: {str(e)}")
|
98 |
if not task.cancel_flag:
|
99 |
+
client_id = task.client_id
|
100 |
+
room = client_rooms[client_id]
|
101 |
+
socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)}, to=room, namespace='/demo')
|
102 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
# タスク数をデクリメント
|
104 |
client_ip = task.client_ip
|
105 |
tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
|
106 |
+
print(f'Task {task.task_id} completed')
|
107 |
+
task.is_processing = False
|
108 |
+
if task.task_id in active_tasks.keys():
|
109 |
+
del active_tasks[task.task_id]
|
110 |
+
if task.task_id in task_futures.keys():
|
111 |
+
del task_futures[task.task_id]
|
112 |
+
|
113 |
update_queue_status('Task completed or cancelled')
|
114 |
|
115 |
def worker():
|
116 |
while True:
|
117 |
try:
|
118 |
task = task_queue.get()
|
119 |
+
if task.task_id in active_tasks.keys():
|
|
|
120 |
future = executor.submit(process_task, task)
|
121 |
task_futures[task.task_id] = future
|
122 |
except Exception as e:
|
|
|
131 |
# グローバル変数を使用して接続数とタスク数を管理
|
132 |
connected_clients = 0
|
133 |
tasks_per_client = {}
|
134 |
+
@socketio.on('connect', namespace='/demo')
|
135 |
def handle_connect(auth):
|
136 |
+
client_id = request.sid
|
137 |
+
room = f"room_{client_id}" # クライアントごとに一意のルーム名を生成
|
138 |
+
join_room(room)
|
139 |
+
client_rooms[client_id] = room
|
140 |
+
print(f"Client {client_id} connected and joined room {room}")
|
141 |
+
|
142 |
global connected_clients
|
143 |
connected_clients += 1
|
144 |
|
145 |
+
@socketio.on('disconnect' )
|
146 |
def handle_disconnect():
|
147 |
+
client_id = request.sid
|
148 |
+
if client_id in client_rooms:
|
149 |
+
room = client_rooms[client_id]
|
150 |
+
leave_room(room)
|
151 |
+
del client_rooms[client_id]
|
152 |
+
print(f"Client {client_id} disconnected and removed from room {room}")
|
153 |
+
|
154 |
global connected_clients
|
155 |
connected_clients -= 1
|
156 |
# キャンセル処理:接続が切断された場合、そのクライアントに関連するタスクをキャンセル。ただし、1番目で処理中のタスクはキャンセルしない
|
157 |
client_ip = get_remote_address()
|
158 |
+
for task_id, task in active_tasks.items():
|
159 |
if task.client_ip == client_ip and not task.is_processing:
|
160 |
task.cancel_flag = True
|
161 |
if task_id in task_futures:
|
|
|
172 |
|
173 |
# クライアントIPアドレスを取得
|
174 |
client_ip = get_remote_address()
|
|
|
|
|
175 |
# 同一IPからの同時タスク数を制限
|
176 |
if tasks_per_client.get(client_ip, 0) >= 2:
|
177 |
return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
|
|
|
190 |
# ファイルデータをバイト列として保存
|
191 |
file_data = file.read()
|
192 |
|
193 |
+
client_id = request.headers.get('X-Client-ID')
|
194 |
+
task = Task(task_id, mode, weight1, weight2, file_data, client_ip, client_id)
|
195 |
task_queue.put(task)
|
196 |
active_tasks[task_id] = task
|
197 |
|
198 |
# 同一IPからのタスク数をインクリメント
|
199 |
tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) + 1
|
200 |
|
201 |
+
update_queue_status('Task submitted') # すべてに通知
|
202 |
|
203 |
queue_size = task_queue.qsize()
|
204 |
+
task_order = get_active_task_order(task_id)
|
205 |
+
return jsonify({'task_id': task_id, 'task_order': task_order, 'queue_size': queue_size})
|
206 |
|
207 |
@app.route('/cancel_task/<task_id>', methods=['POST'])
|
208 |
def cancel_task(task_id):
|
209 |
# クライアントIPアドレスを取得
|
210 |
client_ip = get_remote_address()
|
211 |
|
212 |
+
if task_id in active_tasks.keys():
|
213 |
task = active_tasks[task_id]
|
214 |
# タスクの所有者を確認(IPアドレスで簡易的に判断)
|
215 |
if task.client_ip != client_ip:
|
216 |
return jsonify({'error': 'Unauthorized to cancel this task'}), 403
|
217 |
task.cancel_flag = True
|
218 |
+
if task_id in task_futures.keys():
|
219 |
task_futures[task_id].cancel()
|
220 |
del task_futures[task_id]
|
221 |
del active_tasks[task_id]
|
|
|
234 |
|
235 |
@app.route('/task_status/<task_id>', methods=['GET'])
|
236 |
def task_status(task_id):
|
237 |
+
try:
|
238 |
+
if task_id in active_tasks.keys():
|
239 |
+
task = active_tasks[task_id]
|
240 |
+
return jsonify({'task_id': task_id, 'is_processing': task.is_processing})
|
241 |
+
else:
|
242 |
+
return jsonify({'task_id': task_id, 'is_processing': False})
|
243 |
+
except Exception as e:
|
244 |
+
return jsonify({'error': str(e)}), 500
|
245 |
|
246 |
def get_active_task_order(task_id):
|
247 |
+
non_processing_task_ids = [tid for tid, task in active_tasks.items() if not task.is_processing]
|
248 |
+
return non_processing_task_ids.index(task_id) if task_id in non_processing_task_ids else 0
|
249 |
|
250 |
# get_task_orderイベントハンドラー
|
251 |
@app.route('/get_task_order/<task_id>', methods=['GET'])
|
252 |
def handle_get_task_order(task_id):
|
253 |
+
print(f'Active tasks order: {task_id}, Active tasks: {active_tasks.keys()}')
|
254 |
+
if task_id in active_tasks.keys():
|
255 |
+
return jsonify({'task_order': get_active_task_order(task_id)})
|
256 |
+
else:
|
257 |
+
return jsonify({'task_order': 0})
|
258 |
|
259 |
# Flaskルート
|
260 |
# ルートパスのGETリクエストに対するハンドラ
|