WebashalarForML commited on
Commit
45c7975
·
verified ·
1 Parent(s): 5008817

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -172
app.py CHANGED
@@ -1,173 +1,224 @@
1
- from flask import Flask, render_template, request, session, redirect, url_for, flash, send_from_directory
2
- import os
3
- import secrets
4
- from werkzeug.utils import secure_filename
5
- import sys
6
- import shutil
7
-
8
- sys.path.append(os.path.dirname(__file__))
9
- import inference2 # Import your refactored inference script
10
-
11
- app = Flask(__name__)
12
- app.secret_key = os.urandom(24)
13
- app.config['UPLOAD_FOLDER'] = 'uploads'
14
- app.config['RESULTS_FOLDER'] = 'results' # This directory is NOT inside static
15
- app.config['CHECKPOINTS_FOLDER'] = 'checkpoints'
16
- app.config['TEMP_FOLDER'] = 'temp'
17
-
18
- ALLOWED_FACE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'mp4', 'avi', 'mov'}
19
- ALLOWED_AUDIO_EXTENSIONS = {'wav', 'mp3', 'aac', 'flac'}
20
- ALLOWED_MODEL_EXTENSIONS = {'pth', 'pt'}
21
-
22
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
23
- os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)
24
- os.makedirs(app.config['CHECKPOINTS_FOLDER'], exist_ok=True)
25
- os.makedirs(app.config['TEMP_FOLDER'], exist_ok=True)
26
-
27
-
28
- def allowed_file(filename, allowed_extensions):
29
- return '.' in filename and \
30
- filename.rsplit('.', 1)[1].lower() in allowed_extensions
31
-
32
- @app.route('/')
33
- def index():
34
- theme = session.get('theme', 'dark')
35
- available_models = []
36
- try:
37
- model_files = [f for f in os.listdir(app.config['CHECKPOINTS_FOLDER'])
38
- if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
39
- available_models = sorted(model_files)
40
- except FileNotFoundError:
41
- # flash("Checkpoints folder not found. Please create a 'checkpoints' directory.", 'error') # Messages removed
42
- pass
43
- except Exception as e:
44
- # flash(f"Error loading models: {e}", 'error') # Messages removed
45
- pass
46
- return render_template('index.html', theme=theme, models=available_models)
47
-
48
- @app.route('/toggle_theme')
49
- def toggle_theme():
50
- current_theme = session.get('theme', 'dark')
51
- if current_theme == 'dark':
52
- session['theme'] = 'light'
53
- else:
54
- session['theme'] = 'dark'
55
- return redirect(request.referrer or url_for('index'))
56
-
57
- @app.route('/infer', methods=['POST'])
58
- def infer():
59
- if request.method == 'POST':
60
- if 'face_file' not in request.files or 'audio_file' not in request.files:
61
- # flash('Both face and audio files are required.', 'error') # Messages removed
62
- return redirect(url_for('index'))
63
-
64
- face_file = request.files['face_file']
65
- audio_file = request.files['audio_file']
66
- selected_model = request.form.get('model_select')
67
-
68
- if face_file.filename == '' or audio_file.filename == '':
69
- # flash('No selected file for face or audio.', 'error') # Messages removed
70
- return redirect(url_for('index'))
71
-
72
- if not selected_model:
73
- # flash('No model selected.', 'error') # Messages removed
74
- return redirect(url_for('index'))
75
-
76
- if not allowed_file(face_file.filename, ALLOWED_FACE_EXTENSIONS):
77
- # flash('Invalid face file type. Allowed: png, jpg, jpeg, mp4, avi, mov', 'error') # Messages removed
78
- return redirect(url_for('index'))
79
- if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
80
- # flash('Invalid audio file type. Allowed: wav, mp3, aac, flac', 'error') # Messages removed
81
- return redirect(url_for('index'))
82
-
83
- face_filename = secure_filename(face_file.filename)
84
- audio_filename = secure_filename(audio_file.filename)
85
-
86
- face_uuid = secrets.token_hex(8)
87
- audio_uuid = secrets.token_hex(8)
88
-
89
- face_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{face_uuid}_{face_filename}")
90
- audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{audio_uuid}_{audio_filename}")
91
-
92
- try:
93
- face_file.save(face_path)
94
- audio_file.save(audio_path)
95
- except Exception as e:
96
- # flash(f"Error saving uploaded files: {e}", 'error') # Messages removed
97
- return redirect(url_for('index'))
98
-
99
- checkpoint_path = os.path.join(app.config['CHECKPOINTS_FOLDER'], selected_model)
100
- output_video_name = f"result_{face_uuid}.mp4"
101
-
102
- try:
103
- # flash('Starting inference... This may take a while.', 'info') # Messages removed
104
- generated_video_path = inference2.run_inference(
105
- checkpoint_path=checkpoint_path,
106
- face_path=face_path,
107
- audio_path=audio_path,
108
- output_filename=output_video_name,
109
- static=request.form.get('static_input') == 'on',
110
- fps=float(request.form.get('fps', 25.0)),
111
- resize_factor=int(request.form.get('resize_factor', 1)),
112
- rotate=request.form.get('rotate') == 'on',
113
- nosmooth=request.form.get('nosmooth') == 'on',
114
- pads=[0, 10, 0, 0],
115
- crop=[0, -1, 0, -1],
116
- box=[-1, -1, -1, -1],
117
- face_det_batch_size=16,
118
- wav2lip_batch_size=128,
119
- img_size=96
120
- )
121
- # flash('Inference completed successfully!', 'success') # Messages removed
122
-
123
- # Redirect to the page that renders result.html
124
- return redirect(url_for('render_result_page', filename=os.path.basename(generated_video_path)))
125
-
126
- except ValueError as e:
127
- # flash(f"Inference Error: {e}", 'error') # Messages removed
128
- pass
129
- except RuntimeError as e:
130
- # flash(f"Runtime Error during inference: {e}", 'error') # Messages removed
131
- pass
132
- except Exception as e:
133
- # flash(f"An unexpected error occurred: {e}", 'error') # Messages removed
134
- pass
135
- finally:
136
- if os.path.exists(face_path):
137
- os.remove(face_path)
138
- if os.path.exists(audio_path):
139
- os.remove(audio_path)
140
-
141
- return redirect(url_for('index'))
142
-
143
- # Route to render the result.html template
144
- @app.route('/result_page/<filename>')
145
- def render_result_page(filename):
146
- theme = session.get('theme', 'dark')
147
- # Check if the file actually exists before rendering
148
- if not os.path.exists(os.path.join(app.config['RESULTS_FOLDER'], filename)):
149
- # If the video isn't found, redirect or show an error
150
- # Consider a dedicated error page or a message within index.html if no flashes are used
151
- return redirect(url_for('index'))
152
- return render_template('result.html', theme=theme, video_filename=filename)
153
-
154
-
155
- # Route to serve the video file itself (used by <video src="...">)
156
- @app.route('/results/<path:filename>') # Use <path:filename> to handle potential subdirectories in filename (though not needed here)
157
- def serve_result_video(filename):
158
- # This route is solely for serving the video file
159
- return send_from_directory(app.config['RESULTS_FOLDER'], filename)
160
-
161
- # Route to download the video file
162
- @app.route('/download/<filename>') # Changed to /download/ for clarity
163
- def download_result(filename):
164
- return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)
165
-
166
-
167
- if __name__ == '__main__':
168
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
169
- os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)
170
- os.makedirs(app.config['CHECKPOINTS_FOLDER'], exist_ok=True)
171
- os.makedirs(app.config['TEMP_FOLDER'], exist_ok=True)
172
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  app.run(debug=True)
 
1
+ from flask import Flask, render_template, request, session, redirect, url_for, flash, send_from_directory
2
+ import os
3
+ import secrets
4
+ from werkzeug.utils import secure_filename
5
+ import sys
6
+ import shutil
7
+ import logging # Import the logging module
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO, # Set the minimum logging level
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
12
+ handlers=[
13
+ logging.FileHandler("app.log"), # Log to a file
14
+ logging.StreamHandler() # Log to console
15
+ ])
16
+ logger = logging.getLogger(__name__)
17
+
18
+ sys.path.append(os.path.dirname(__file__))
19
+ import inference2 # Import your refactored inference script
20
+
21
+ app = Flask(__name__)
22
+ app.secret_key = os.urandom(24)
23
+ app.config['UPLOAD_FOLDER'] = 'uploads'
24
+ app.config['RESULTS_FOLDER'] = 'results'
25
+ app.config['CHECKPOINTS_FOLDER'] = 'checkpoints'
26
+ app.config['TEMP_FOLDER'] = 'temp'
27
+
28
+ ALLOWED_FACE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'mp4', 'avi', 'mov'}
29
+ ALLOWED_AUDIO_EXTENSIONS = {'wav', 'mp3', 'aac', 'flac'}
30
+ ALLOWED_MODEL_EXTENSIONS = {'pth', 'pt'}
31
+
32
+ # Ensure directories exist
33
+ try:
34
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
35
+ os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)
36
+ os.makedirs(app.config['CHECKPOINTS_FOLDER'], exist_ok=True)
37
+ os.makedirs(app.config['TEMP_FOLDER'], exist_ok=True)
38
+ logger.info("All necessary directories ensured to exist.")
39
+ except OSError as e:
40
+ logger.critical(f"Error creating essential directories: {e}")
41
+ # Depending on the severity, you might want to exit or disable functionality
42
+ sys.exit(1) # Exit if essential directories cannot be created
43
+
44
+ def allowed_file(filename, allowed_extensions):
45
+ """Checks if a file's extension is allowed."""
46
+ return '.' in filename and \
47
+ filename.rsplit('.', 1)[1].lower() in allowed_extensions
48
+
49
+ @app.route('/')
50
+ def index():
51
+ """Renders the main page, displaying available models."""
52
+ theme = session.get('theme', 'dark')
53
+ available_models = []
54
+ try:
55
+ model_files = [f for f in os.listdir(app.config['CHECKPOINTS_FOLDER'])
56
+ if allowed_file(f, ALLOWED_MODEL_EXTENSIONS)]
57
+ available_models = sorted(model_files)
58
+ logger.info(f"Successfully loaded {len(available_models)} models.")
59
+ except FileNotFoundError:
60
+ logger.warning(f"Checkpoints folder '{app.config['CHECKPOINTS_FOLDER']}' not found. Please create it.")
61
+ except Exception as e:
62
+ logger.error(f"Error loading models from '{app.config['CHECKPOINTS_FOLDER']}': {e}")
63
+ return render_template('index.html', theme=theme, models=available_models)
64
+
65
+ @app.route('/toggle_theme')
66
+ def toggle_theme():
67
+ """Toggles the session theme between dark and light."""
68
+ current_theme = session.get('theme', 'dark')
69
+ if current_theme == 'dark':
70
+ session['theme'] = 'light'
71
+ logger.info("Theme toggled to light.")
72
+ else:
73
+ session['theme'] = 'dark'
74
+ logger.info("Theme toggled to dark.")
75
+ return redirect(request.referrer or url_for('index'))
76
+
77
+ @app.route('/infer', methods=['POST'])
78
+ def infer():
79
+ """Handles the inference request, processing uploaded files and running the model."""
80
+ if request.method == 'POST':
81
+ logger.info("Inference request received.")
82
+
83
+ # Check for file presence
84
+ if 'face_file' not in request.files or 'audio_file' not in request.files:
85
+ logger.warning("Both face and audio files are required for inference.")
86
+ return redirect(url_for('index'))
87
+
88
+ face_file = request.files['face_file']
89
+ audio_file = request.files['audio_file']
90
+ selected_model = request.form.get('model_select')
91
+
92
+ if face_file.filename == '' or audio_file.filename == '':
93
+ logger.warning("No selected file for face or audio provided.")
94
+ return redirect(url_for('index'))
95
+
96
+ if not selected_model:
97
+ logger.warning("No model selected for inference.")
98
+ return redirect(url_for('index'))
99
+
100
+ # Validate file types
101
+ if not allowed_file(face_file.filename, ALLOWED_FACE_EXTENSIONS):
102
+ logger.warning(f"Invalid face file type: {face_file.filename}. Allowed: {', '.join(ALLOWED_FACE_EXTENSIONS)}")
103
+ return redirect(url_for('index'))
104
+ if not allowed_file(audio_file.filename, ALLOWED_AUDIO_EXTENSIONS):
105
+ logger.warning(f"Invalid audio file type: {audio_file.filename}. Allowed: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}")
106
+ return redirect(url_for('index'))
107
+
108
+ face_filename = secure_filename(face_file.filename)
109
+ audio_filename = secure_filename(audio_file.filename)
110
+
111
+ face_uuid = secrets.token_hex(8)
112
+ audio_uuid = secrets.token_hex(8)
113
+
114
+ face_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{face_uuid}_{face_filename}")
115
+ audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{audio_uuid}_{audio_filename}")
116
+
117
+ try:
118
+ face_file.save(face_path)
119
+ audio_file.save(audio_path)
120
+ logger.info(f"Successfully saved uploaded files: {face_path}, {audio_path}")
121
+ except Exception as e:
122
+ logger.error(f"Error saving uploaded files: {e}")
123
+ return redirect(url_for('index'))
124
+
125
+ checkpoint_path = os.path.join(app.config['CHECKPOINTS_FOLDER'], selected_model)
126
+ output_video_name = f"result_{face_uuid}.mp4"
127
+
128
+ try:
129
+ logger.info(f"Starting inference with model: {selected_model}, face: {face_filename}, audio: {audio_filename}")
130
+ generated_video_path = inference2.run_inference(
131
+ checkpoint_path=checkpoint_path,
132
+ face_path=face_path,
133
+ audio_path=audio_path,
134
+ output_filename=output_video_name,
135
+ static=request.form.get('static_input') == 'on',
136
+ fps=float(request.form.get('fps', 25.0)),
137
+ resize_factor=int(request.form.get('resize_factor', 1)),
138
+ rotate=request.form.get('rotate') == 'on',
139
+ nosmooth=request.form.get('nosmooth') == 'on',
140
+ pads=[0, 10, 0, 0],
141
+ crop=[0, -1, 0, -1],
142
+ box=[-1, -1, -1, -1],
143
+ face_det_batch_size=16,
144
+ wav2lip_batch_size=128,
145
+ img_size=96
146
+ )
147
+ logger.info(f"Inference completed successfully. Generated video: {generated_video_path}")
148
+ return redirect(url_for('render_result_page', filename=os.path.basename(generated_video_path)))
149
+
150
+ except ValueError as e:
151
+ logger.error(f"Inference ValueError: {e}")
152
+ except RuntimeError as e:
153
+ logger.error(f"Runtime Error during inference: {e}")
154
+ except Exception as e:
155
+ logger.critical(f"An unexpected error occurred during inference: {e}", exc_info=True) # exc_info=True to log traceback
156
+ finally:
157
+ # Clean up uploaded files regardless of inference success or failure
158
+ if os.path.exists(face_path):
159
+ try:
160
+ os.remove(face_path)
161
+ logger.info(f"Cleaned up uploaded face file: {face_path}")
162
+ except OSError as e:
163
+ logger.error(f"Error removing face file {face_path}: {e}")
164
+ if os.path.exists(audio_path):
165
+ try:
166
+ os.remove(audio_path)
167
+ logger.info(f"Cleaned up uploaded audio file: {audio_path}")
168
+ except OSError as e:
169
+ logger.error(f"Error removing audio file {audio_path}: {e}")
170
+
171
+ return redirect(url_for('index'))
172
+
173
+
174
+ ## Result Handling and File Serving
175
+
176
+ @app.route('/result_page/<filename>')
177
+ def render_result_page(filename):
178
+ """Renders the result page with the generated video."""
179
+ theme = session.get('theme', 'dark')
180
+ result_video_path = os.path.join(app.config['RESULTS_FOLDER'], filename)
181
+ if not os.path.exists(result_video_path):
182
+ logger.warning(f"Attempted to access non-existent result video: {result_video_path}")
183
+ return redirect(url_for('index'))
184
+ logger.info(f"Rendering result page for video: {filename}")
185
+ return render_template('result.html', theme=theme, video_filename=filename)
186
+
187
+ @app.route('/results/<path:filename>')
188
+ def serve_result_video(filename):
189
+ """Serves the generated video file from the results folder."""
190
+ logger.debug(f"Serving result video: {filename}")
191
+ try:
192
+ return send_from_directory(app.config['RESULTS_FOLDER'], filename)
193
+ except Exception as e:
194
+ logger.error(f"Error serving result video {filename}: {e}")
195
+ return "Error serving file", 500
196
+
197
+ @app.route('/download/<filename>')
198
+ def download_result(filename):
199
+ """Allows downloading the generated video file."""
200
+ logger.info(f"Download request for video: {filename}")
201
+ try:
202
+ return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)
203
+ except Exception as e:
204
+ logger.error(f"Error downloading video {filename}: {e}")
205
+ return "Error downloading file", 500
206
+
207
+ ## Application Initialization
208
+
209
+ if __name__ == '__main__':
210
+ # Directories are already ensured at the top level of the script,
211
+ # but re-checking here for robustness in case app.run is called differently.
212
+ try:
213
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
214
+ os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)
215
+ os.makedirs(app.config['CHECKPOINTS_FOLDER'], exist_ok=True)
216
+ os.makedirs(app.config['TEMP_FOLDER'], exist_ok=True)
217
+ logger.info("Application directories confirmed at startup.")
218
+ except OSError as e:
219
+ logger.critical(f"Critical error during startup: Could not create necessary directories: {e}")
220
+ sys.exit(1) # Exit if directories cannot be created
221
+
222
+ logger.info("Starting Flask application...")
223
+ # In a production environment, debug=True should be False
224
  app.run(debug=True)