yeq6x commited on
Commit
b38a367
·
1 Parent(s): 0db4cac

bug naosita

Browse files
Files changed (1) hide show
  1. app.py +73 -49
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from flask import Flask, request, render_template, send_file, jsonify, send_from_directory, session
2
  from flask_socketio import SocketIO, join_room, leave_room, close_room, rooms, disconnect
3
  from flask_cors import CORS
4
  from flask_limiter import Limiter
@@ -23,9 +23,10 @@ from gevent import pywsgi
23
  from geventwebsocket.handler import WebSocketHandler
24
 
25
  app = Flask(__name__)
26
- app.secret_key = 'user'
27
  CORS(app)
28
  socketio = SocketIO(app, cors_allowed_origins="*")
 
 
29
 
30
  # レート制限の設定
31
  limiter = Limiter(
@@ -41,10 +42,12 @@ active_tasks = {}
41
  task_futures = {}
42
 
43
  # ThreadPoolExecutorの作成
44
- executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) # 8vCPUのインスタンスを使用
 
 
45
 
46
  class Task:
47
- def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip):
48
  self.task_id = task_id
49
  self.mode = mode
50
  self.weight1 = weight1
@@ -53,16 +56,14 @@ class Task:
53
  self.cancel_flag = False
54
  self.client_ip = client_ip
55
  self.is_processing = False
 
56
 
57
  # キューの状態を通知
58
  def update_queue_status(message):
59
- socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
60
 
61
  def process_task(task):
62
  try:
63
- update_queue_status("sid")
64
- update_queue_status(str("sid" in session))
65
- client_id = session.get("sid")
66
  task.is_processing = True
67
  # ファイルデータをPIL Imageに変換
68
  image = Image.open(io.BytesIO(task.file_data))
@@ -73,40 +74,49 @@ def process_task(task):
73
  return
74
 
75
  # 画像処理ロジックを呼び出す
76
- sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)
 
 
77
 
78
  # キャンセルチェック
79
  if task.cancel_flag:
80
  return
81
-
82
- socketio.emit('task_complete', {
83
- 'task_id': task.task_id,
84
- 'sotai_image': sotai_image,
85
- 'sketch_image': sketch_image
86
- }, to=client_id)
 
 
 
 
 
 
87
  except Exception as e:
88
  print(f"Task error: {str(e)}")
89
  if not task.cancel_flag:
90
- socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)}, to=client_id)
 
 
91
  finally:
92
- task.is_processing = False
93
- if task.task_id in active_tasks:
94
- del active_tasks[task.task_id]
95
- if task.task_id in task_futures:
96
- del task_futures[task.task_id]
97
-
98
  # タスク数をデクリメント
99
  client_ip = task.client_ip
100
  tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
101
-
 
 
 
 
 
 
102
  update_queue_status('Task completed or cancelled')
103
 
104
  def worker():
105
  while True:
106
  try:
107
  task = task_queue.get()
108
- if task.task_id in active_tasks:
109
- print(f"Processing task {task.task_id}")
110
  future = executor.submit(process_task, task)
111
  task_futures[task.task_id] = future
112
  except Exception as e:
@@ -121,23 +131,31 @@ threading.Thread(target=worker, daemon=True).start()
121
  # グローバル変数を使用して接続数とタスク数を管理
122
  connected_clients = 0
123
  tasks_per_client = {}
124
- @socketio.on('connect', namespace='/')
125
  def handle_connect(auth):
126
- session["sid"] = request.sid
127
- client_id = request.sid # クライアントIDを取得
128
- join_room(client_id) # クライアントを自身のルームに入れる
 
 
 
129
  global connected_clients
130
  connected_clients += 1
131
 
132
- @socketio.on('disconnect' , namespace='/')
133
  def handle_disconnect():
134
- client_id = request.sid # クライアントIDを取得
135
- leave_room(client_id) # クライアントをルームから出す
 
 
 
 
 
136
  global connected_clients
137
  connected_clients -= 1
138
  # キャンセル処理:接続が切断された場合、そのクライアントに関連するタスクをキャンセル。ただし、1番目で処理中のタスクはキャンセルしない
139
  client_ip = get_remote_address()
140
- for task_id, task in list(active_tasks.items()):
141
  if task.client_ip == client_ip and not task.is_processing:
142
  task.cancel_flag = True
143
  if task_id in task_futures:
@@ -154,8 +172,6 @@ def submit_task():
154
 
155
  # クライアントIPアドレスを取得
156
  client_ip = get_remote_address()
157
- # client_id = session.get("sid")
158
-
159
  # 同一IPからの同時タスク数を制限
160
  if tasks_per_client.get(client_ip, 0) >= 2:
161
  return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
@@ -174,30 +190,32 @@ def submit_task():
174
  # ファイルデータをバイト列として保存
175
  file_data = file.read()
176
 
177
- task = Task(task_id, mode, weight1, weight2, file_data, client_ip)
 
178
  task_queue.put(task)
179
  active_tasks[task_id] = task
180
 
181
  # 同一IPからのタスク数をインクリメント
182
  tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) + 1
183
 
184
- update_queue_status('Task submitted')
185
 
186
  queue_size = task_queue.qsize()
187
- return jsonify({'task_id': task_id, 'queue_size': queue_size})
 
188
 
189
  @app.route('/cancel_task/<task_id>', methods=['POST'])
190
  def cancel_task(task_id):
191
  # クライアントIPアドレスを取得
192
  client_ip = get_remote_address()
193
 
194
- if task_id in active_tasks:
195
  task = active_tasks[task_id]
196
  # タスクの所有者を確認(IPアドレスで簡易的に判断)
197
  if task.client_ip != client_ip:
198
  return jsonify({'error': 'Unauthorized to cancel this task'}), 403
199
  task.cancel_flag = True
200
- if task_id in task_futures:
201
  task_futures[task_id].cancel()
202
  del task_futures[task_id]
203
  del active_tasks[task_id]
@@ -216,21 +234,27 @@ def cancel_task(task_id):
216
 
217
  @app.route('/task_status/<task_id>', methods=['GET'])
218
  def task_status(task_id):
219
- task = active_tasks.get(task_id, None)
220
- if task:
221
- return jsonify({'task_id': task_id, 'is_processing': task.is_processing})
222
- return jsonify({'error': 'Task not found'}), 404
 
 
 
 
223
 
224
  def get_active_task_order(task_id):
225
- non_processing_tasks = [tid for tid, task in active_tasks.items() if not task.is_processing]
226
- return non_processing_tasks.index(task_id) if task_id in non_processing_tasks else 0
227
 
228
  # get_task_orderイベントハンドラー
229
  @app.route('/get_task_order/<task_id>', methods=['GET'])
230
  def handle_get_task_order(task_id):
231
- if task_id not in active_tasks:
232
- return jsonify({'error': 'Task not found'}), 404
233
- return jsonify({'task_order': get_active_task_order(task_id)})
 
 
234
 
235
  # Flaskルート
236
  # ルートパスのGETリクエストに対するハンドラ
 
1
+ from flask import Flask, request, render_template, send_file, jsonify, send_from_directory, session, copy_current_request_context
2
  from flask_socketio import SocketIO, join_room, leave_room, close_room, rooms, disconnect
3
  from flask_cors import CORS
4
  from flask_limiter import Limiter
 
23
  from geventwebsocket.handler import WebSocketHandler
24
 
25
  app = Flask(__name__)
 
26
  CORS(app)
27
  socketio = SocketIO(app, cors_allowed_origins="*")
28
+ # クライアントIDとルーム情報を保存するグローバル辞書
29
+ client_rooms = {}
30
 
31
  # レート制限の設定
32
  limiter = Limiter(
 
42
  task_futures = {}
43
 
44
  # ThreadPoolExecutorの作成
45
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) # 8vCPUのインスタンスを使用
46
+
47
+ gpu_lock = threading.Lock()
48
 
49
  class Task:
50
+ def __init__(self, task_id, mode, weight1, weight2, file_data, client_ip, client_id):
51
  self.task_id = task_id
52
  self.mode = mode
53
  self.weight1 = weight1
 
56
  self.cancel_flag = False
57
  self.client_ip = client_ip
58
  self.is_processing = False
59
+ self.client_id = client_id
60
 
61
  # キューの状態を通知
62
  def update_queue_status(message):
63
+ socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message}, namespace='/demo')
64
 
65
  def process_task(task):
66
  try:
 
 
 
67
  task.is_processing = True
68
  # ファイルデータをPIL Imageに変換
69
  image = Image.open(io.BytesIO(task.file_data))
 
74
  return
75
 
76
  # 画像処理ロジックを呼び出す
77
+ # GPU処理部分
78
+ with gpu_lock:
79
+ sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)
80
 
81
  # キャンセルチェック
82
  if task.cancel_flag:
83
  return
84
+
85
+ # クライアントIDをリクエストヘッダーから取得(クライアント側で設定する必要があります)
86
+ client_id = task.client_id
87
+ if client_id and client_id in client_rooms:
88
+ room = client_rooms[client_id]
89
+ # ルームにメッセージをemit
90
+ socketio.emit('task_complete', {
91
+ 'task_id': task.task_id,
92
+ 'sotai_image': sotai_image,
93
+ 'sketch_image': sketch_image
94
+ }, to=room, namespace='/demo')
95
+
96
  except Exception as e:
97
  print(f"Task error: {str(e)}")
98
  if not task.cancel_flag:
99
+ client_id = task.client_id
100
+ room = client_rooms[client_id]
101
+ socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)}, to=room, namespace='/demo')
102
  finally:
 
 
 
 
 
 
103
  # タスク数をデクリメント
104
  client_ip = task.client_ip
105
  tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) - 1
106
+ print(f'Task {task.task_id} completed')
107
+ task.is_processing = False
108
+ if task.task_id in active_tasks.keys():
109
+ del active_tasks[task.task_id]
110
+ if task.task_id in task_futures.keys():
111
+ del task_futures[task.task_id]
112
+
113
  update_queue_status('Task completed or cancelled')
114
 
115
  def worker():
116
  while True:
117
  try:
118
  task = task_queue.get()
119
+ if task.task_id in active_tasks.keys():
 
120
  future = executor.submit(process_task, task)
121
  task_futures[task.task_id] = future
122
  except Exception as e:
 
131
  # グローバル変数を使用して接続数とタスク数を管理
132
  connected_clients = 0
133
  tasks_per_client = {}
134
+ @socketio.on('connect', namespace='/demo')
135
  def handle_connect(auth):
136
+ client_id = request.sid
137
+ room = f"room_{client_id}" # クライアントごとに一意のルーム名を生成
138
+ join_room(room)
139
+ client_rooms[client_id] = room
140
+ print(f"Client {client_id} connected and joined room {room}")
141
+
142
  global connected_clients
143
  connected_clients += 1
144
 
145
+ @socketio.on('disconnect' )
146
  def handle_disconnect():
147
+ client_id = request.sid
148
+ if client_id in client_rooms:
149
+ room = client_rooms[client_id]
150
+ leave_room(room)
151
+ del client_rooms[client_id]
152
+ print(f"Client {client_id} disconnected and removed from room {room}")
153
+
154
  global connected_clients
155
  connected_clients -= 1
156
  # キャンセル処理:接続が切断された場合、そのクライアントに関連するタスクをキャンセル。ただし、1番目で処理中のタスクはキャンセルしない
157
  client_ip = get_remote_address()
158
+ for task_id, task in active_tasks.items():
159
  if task.client_ip == client_ip and not task.is_processing:
160
  task.cancel_flag = True
161
  if task_id in task_futures:
 
172
 
173
  # クライアントIPアドレスを取得
174
  client_ip = get_remote_address()
 
 
175
  # 同一IPからの同時タスク数を制限
176
  if tasks_per_client.get(client_ip, 0) >= 2:
177
  return jsonify({'error': 'Maximum number of concurrent tasks reached'}), 429
 
190
  # ファイルデータをバイト列として保存
191
  file_data = file.read()
192
 
193
+ client_id = request.headers.get('X-Client-ID')
194
+ task = Task(task_id, mode, weight1, weight2, file_data, client_ip, client_id)
195
  task_queue.put(task)
196
  active_tasks[task_id] = task
197
 
198
  # 同一IPからのタスク数をインクリメント
199
  tasks_per_client[client_ip] = tasks_per_client.get(client_ip, 0) + 1
200
 
201
+ update_queue_status('Task submitted') # すべてに通知
202
 
203
  queue_size = task_queue.qsize()
204
+ task_order = get_active_task_order(task_id)
205
+ return jsonify({'task_id': task_id, 'task_order': task_order, 'queue_size': queue_size})
206
 
207
  @app.route('/cancel_task/<task_id>', methods=['POST'])
208
  def cancel_task(task_id):
209
  # クライアントIPアドレスを取得
210
  client_ip = get_remote_address()
211
 
212
+ if task_id in active_tasks.keys():
213
  task = active_tasks[task_id]
214
  # タスクの所有者を確認(IPアドレスで簡易的に判断)
215
  if task.client_ip != client_ip:
216
  return jsonify({'error': 'Unauthorized to cancel this task'}), 403
217
  task.cancel_flag = True
218
+ if task_id in task_futures.keys():
219
  task_futures[task_id].cancel()
220
  del task_futures[task_id]
221
  del active_tasks[task_id]
 
234
 
235
  @app.route('/task_status/<task_id>', methods=['GET'])
236
  def task_status(task_id):
237
+ try:
238
+ if task_id in active_tasks.keys():
239
+ task = active_tasks[task_id]
240
+ return jsonify({'task_id': task_id, 'is_processing': task.is_processing})
241
+ else:
242
+ return jsonify({'task_id': task_id, 'is_processing': False})
243
+ except Exception as e:
244
+ return jsonify({'error': str(e)}), 500
245
 
246
  def get_active_task_order(task_id):
247
+ non_processing_task_ids = [tid for tid, task in active_tasks.items() if not task.is_processing]
248
+ return non_processing_task_ids.index(task_id) if task_id in non_processing_task_ids else 0
249
 
250
  # get_task_orderイベントハンドラー
251
  @app.route('/get_task_order/<task_id>', methods=['GET'])
252
  def handle_get_task_order(task_id):
253
+ print(f'Active tasks order: {task_id}, Active tasks: {active_tasks.keys()}')
254
+ if task_id in active_tasks.keys():
255
+ return jsonify({'task_order': get_active_task_order(task_id)})
256
+ else:
257
+ return jsonify({'task_order': 0})
258
 
259
  # Flaskルート
260
  # ルートパスのGETリクエストに対するハンドラ