ddriscoll committed on
Commit
32d10ec
·
verified ·
1 Parent(s): 10625de

Upload 2 files

Files changed (2)
  1. app.py +536 -0
  2. requirements.txt +31 -0
app.py ADDED
@@ -0,0 +1,536 @@
+ import os
+ import glob
+ import time
+ import threading
+ import requests
+ import wikipedia
+ import torch
+ import cv2
+ import numpy as np
+ from io import BytesIO
+ from PIL import Image
+ import base64 # Added import
+
+ import gradio as gr
+ from ultralytics import YOLO
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from diffusers import MarigoldDepthPipeline # Updated import for depth model
+ from realesrgan import RealESRGANer
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+
+ # Enable CPU fallback for PyTorch operators that are not implemented on MPS
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+
+ # Initialize Models
+ def initialize_models():
+     models = {}
+
+     # Device detection
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif torch.backends.mps.is_available():
+         device = 'mps'
+     else:
+         device = 'cpu'
+     models['device'] = device
+
+     print(f"Using device: {device}")
+
+     # Initialize the RoBERTa model for question answering
+     try:
+         models['qa_pipeline'] = pipeline(
+             "question-answering", model="deepset/roberta-base-squad2", device=0 if device == 'cuda' else -1)
+         print("RoBERTa QA pipeline initialized.")
+     except Exception as e:
+         print(f"Error initializing the RoBERTa model: {e}")
+         models['qa_pipeline'] = None
+
+     # Initialize the Gemma model
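+     # Note: the Gemma tokenizer and model are loaded here but are not referenced elsewhere in this file.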
+     try:
+         models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
+         models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
+             "google/gemma-2-2b-it",
+             device_map="auto",
+             torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
+         )
+         print("Gemma model initialized.")
+     except Exception as e:
+         print(f"Error initializing the Gemma model: {e}")
+         models['gemma_model'] = None
+
+     # Initialize the depth estimation model using MarigoldDepthPipeline
+     try:
+         if device == 'cuda':
+             models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
+                 "prs-eth/marigold-depth-lcm-v1-0",
+                 variant="fp16",
+                 torch_dtype=torch.float16
+             ).to('cuda')
+         else:
+             # For CPU or MPS devices, keep on 'cpu' to avoid unsupported operators
+             models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
+                 "prs-eth/marigold-depth-lcm-v1-0",
+                 torch_dtype=torch.float32
+             ).to('cpu')
+         print("Depth estimation model initialized.")
+     except Exception as e:
+         error_message = f"Error initializing the depth estimation model: {e}"
+         print(error_message)
+         models['depth_pipe'] = None
+         models['depth_init_error'] = error_message # Store the error message
+
+     # Initialize the upscaling model
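+     # RealESRGAN_x4plus.pth is not bundled here; it is typically downloaded from the Real-ESRGAN GitHub releases into weights/.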
+     try:
+         upscaler_model_path = 'weights/RealESRGAN_x4plus.pth' # Ensure this path is correct
+         if not os.path.exists(upscaler_model_path):
+             print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
+             models['upscaler'] = None
+         else:
+             # Define the model architecture
+             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
+                             num_block=23, num_grow_ch=32, scale=4)
+
+             # Initialize RealESRGANer
+             models['upscaler'] = RealESRGANer(
+                 scale=4,
+                 model_path=upscaler_model_path,
+                 model=model,
+                 pre_pad=0,
+                 half=(device == 'cuda'),
+                 device=device
+             )
+             print("Real-ESRGAN upscaler initialized.")
+     except Exception as e:
+         print(f"Error initializing the upscaling model: {e}")
+         models['upscaler'] = None
+
+     # Initialize YOLO model
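+     # Note: the weights path below is an absolute path from the author's local machine; adjust it to wherever mbari_315k_yolov8.pt lives in your environment.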
+     try:
+         source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
+         if not os.path.exists(source_weights_path):
+             print(f"YOLO weights not found at {source_weights_path}. Please download them.")
+             models['yolo_model'] = None
+         else:
+             models['yolo_model'] = YOLO(source_weights_path)
+             print("YOLO model initialized.")
+     except Exception as e:
+         print(f"Error initializing YOLO model: {e}")
+         models['yolo_model'] = None
+
+     return models
+
+ models = initialize_models()
+
+ # Utility Functions
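+ # The two helpers below look a class name up on Wikipedia; wikipedia.page() can raise DisambiguationError or PageError, which the broad except clauses swallow.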
+ def search_class_description(class_name):
+     wikipedia.set_lang("en")
+     wikipedia.set_rate_limiting(True)
+     description = ""
+
+     try:
+         page = wikipedia.page(class_name)
+         if page:
+             description = page.content[:5000] # Get more content
+     except Exception as e:
+         print(f"Error fetching description for {class_name}: {e}")
+
+     return description
+
+ def search_class_image(class_name):
+     wikipedia.set_lang("en")
+     wikipedia.set_rate_limiting(True)
+     img_url = ""
+
+     try:
+         page = wikipedia.page(class_name)
+         if page:
+             for img in page.images:
+                 if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
+                     img_url = img
+                     break
+     except Exception as e:
+         print(f"Error fetching image for {class_name}: {e}")
+
+     return img_url
+
+ def process_image(image):
+     if models['yolo_model'] is None:
+         return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None
+
+     try:
+         if image is None:
+             return None, "No image uploaded.", "No image uploaded.", [], None
+
+         # Convert Gradio Image to OpenCV format
+         image_np = np.array(image)
+         if image_np.dtype != np.uint8:
+             image_np = image_np.astype(np.uint8)
+
+         if len(image_np.shape) != 3 or image_np.shape[2] != 3:
+             return None, "Invalid image format. Please upload an RGB image.", "Invalid image format. Please upload an RGB image.", [], None
+
+         image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+
+         # Store the original image before drawing bounding boxes
+         original_image_cv = image_cv.copy()
+         original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB))
+
+         # Perform YOLO prediction
+         results = models['yolo_model'].predict(
+             source=image_cv, conf=0.075)[0] # Lowered the threshold
+
+         bounding_boxes = []
+         image_processed = image_cv.copy()
+
+         if results.boxes is not None:
+             for box in results.boxes:
+                 x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
+                 class_name = models['yolo_model'].names[int(box.cls)]
+                 confidence = box.conf.item() * 100 # Convert to percentage
+
+                 bounding_boxes.append({
+                     "coords": (x1, y1, x2, y2),
+                     "class_name": class_name,
+                     "confidence": confidence
+                 })
+
+                 cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2)
+                 cv2.putText(image_processed, f'{class_name} {confidence:.2f}%',
+                             (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
+                             0.9, (0, 0, 255), 2)
+
+         # Convert back to PIL Image
+         processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB))
+
+         # Prepare detection info
+         if bounding_boxes:
+             detection_info = "\n".join(
+                 [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes]
+             )
+         else:
+             detection_info = "No detections found."
+
+         # Prepare detection details as Markdown
+         if bounding_boxes:
+             details = ""
+             for idx, box in enumerate(bounding_boxes):
+                 class_name = box['class_name']
+                 confidence = box['confidence']
+                 description = search_class_description(class_name)
+                 img_url = search_class_image(class_name)
+                 img_md = ""
+                 if img_url:
+                     try:
+                         headers = {
+                             'User-Agent': 'MyApp/1.0 (https://example.com/contact; myemail@example.com)'
+                         }
+                         response = requests.get(img_url, headers=headers, timeout=10)
+                         img_data = response.content
+                         img = Image.open(BytesIO(img_data)).convert("RGB")
+                         img.thumbnail((400, 400)) # Resize for faster loading
+                         buffered = BytesIO()
+                         img.save(buffered, format="PNG")
+                         img_str = base64.b64encode(buffered.getvalue()).decode()
+                         img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
+                     except Exception as e:
+                         print(f"Error fetching image for {class_name}: {e}")
+                 details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
+                 if description:
+                     details += f"{description}\n\n"
+                 if img_md:
+                     details += f"{img_md}\n\n"
+             detection_details_md = details
+         else:
+             detection_details_md = "No detections to show."
+
+         return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None
+
+ def ask_eurybia(question, state):
+     if not question.strip():
+         return "Please enter a valid question.", state
+
+     if not state['bounding_boxes']:
+         return "No detected objects to ask about.", state
+
+     # Combine descriptions of all detected objects as context
+     context = ""
+     for box in state['bounding_boxes']:
+         description = search_class_description(box['class_name'])
+         if description:
+             context += description + "\n"
+
+     if not context.strip():
+         return "No sufficient context available to answer the question.", state
+
+     try:
+         if models['qa_pipeline'] is None:
+             return "QA pipeline is not initialized.", state
+
+         answer = models['qa_pipeline'](question=question, context=context)
+         answer_text = answer['answer'].strip()
+         if not answer_text:
+             return "I couldn't find an answer to that question based on the detected objects.", state
+         return answer_text, state
+     except Exception as e:
+         print(f"Error during question answering: {e}")
+         return f"Error during question answering: {e}", state
+
+ def enhance_image(cropped_image_pil):
+     if models['upscaler'] is None:
+         return None, "Upscaling model is not initialized."
+
+     try:
+         input_image = cropped_image_pil.convert("RGB")
+         img = np.array(input_image)
+
+         # Run the model to enhance the image
+         output, _ = models['upscaler'].enhance(img, outscale=4)
+
+         enhanced_image = Image.fromarray(output)
+
+         return enhanced_image, "Image enhanced successfully."
+     except Exception as e:
+         print(f"Error during image enhancement: {e}")
+         return None, f"Error during image enhancement: {e}"
+
+ def run_depth_prediction(original_image):
+     if models['depth_pipe'] is None:
+         error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
+         return None, error_msg
+
+     try:
+         if original_image is None:
+             return None, "No image uploaded for depth prediction."
+
+         # Prepare the image
+         input_image = original_image.convert("RGB")
+
+         # Run the depth pipeline
+         result = models['depth_pipe'](input_image)
+
+         # Access the depth prediction from the pipeline output
+         depth_prediction = result.prediction
+
+         # Visualize the depth map
+         vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
+
+         # Ensure vis_depth is a list and extract the first image
+         if isinstance(vis_depth, list) and len(vis_depth) > 0:
+             vis_depth_image = vis_depth[0]
+         else:
+             vis_depth_image = vis_depth # Fallback if not a list
+
+         return vis_depth_image, "Depth prediction completed."
+     except Exception as e:
+         print(f"Error during depth prediction: {e}")
+         return None, f"Error during depth prediction: {e}"
+
+ # Gradio Interface Components
+ with gr.Blocks() as demo:
+     gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")
+
+     with gr.Tab("Upload & Process"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="pil", label="Upload Image")
+                 process_button = gr.Button("Process Image")
+                 clear_button = gr.Button("Clear")
+             with gr.Column():
+                 processed_image = gr.Image(type="pil", label="Processed Image")
+                 detection_info = gr.Textbox(label="Detection Information", lines=10)
+
+     with gr.Tab("Detection Details"):
+         with gr.Accordion("Click to see detection details", open=False):
+             detection_details_md = gr.Markdown("No detections to show.")
+
+     with gr.Tab("Ask Eurybia"):
+         with gr.Row():
+             with gr.Column():
+                 question_input = gr.Textbox(label="Ask a question about the detected objects")
+                 ask_button = gr.Button("Ask Eurybia")
+             with gr.Column():
+                 answer_output = gr.Markdown(label="Eurybia's Answer")
+
+     with gr.Tab("Depth Estimation"):
+         with gr.Row():
+             with gr.Column():
+                 depth_button = gr.Button("Run Depth Prediction")
+             with gr.Column():
+                 depth_output = gr.Image(type="pil", label="Depth Map")
+                 depth_status = gr.Textbox(label="Status", lines=2)
+
+         # Display error message if depth estimation model failed to initialize
+         if models.get('depth_init_error'):
+             gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")
+
+     with gr.Tab("Enhance Detected Objects"):
+         if models['yolo_model'] is not None and models['upscaler'] is not None:
+             with gr.Row():
+                 detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True)
+                 enhance_btn = gr.Button("Enhance Image")
+             with gr.Column():
+                 enhanced_image = gr.Image(type="pil", label="Enhanced Image")
+                 enhance_status = gr.Textbox(label="Status", lines=2)
+         else:
+             gr.Markdown("**Warning:** YOLO model or Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
+
+     with gr.Tab("Credits"):
+         gr.Markdown("""
+ # Credits and Licensing Information
+
+ This project utilizes various open-source libraries, tools, pretrained models, and datasets. Below is the list of components used and their respective credits/licenses:
+
+ ## Libraries
+ - **Python** - Python Software Foundation License (PSF License)
+ - **Gradio** - Licensed under the Apache License 2.0
+ - **Torch (PyTorch)** - Licensed under the BSD 3-Clause License
+ - **OpenCV (cv2)** - Licensed under the Apache License 2.0
+ - **NumPy** - Licensed under the BSD License
+ - **Pillow (PIL)** - Licensed under the HPND License
+ - **Requests** - Licensed under the Apache License 2.0
+ - **Wikipedia API** - Licensed under the MIT License
+ - **Transformers** - Licensed under the Apache License 2.0
+ - **Diffusers** - Licensed under the Apache License 2.0
+ - **Real-ESRGAN** - Licensed under the BSD 3-Clause License
+ - **BasicSR** - Licensed under the Apache License 2.0
+ - **Ultralytics YOLO** - Licensed under the AGPL-3.0 License
+
+ ## Pretrained Models
+ - **deepset/roberta-base-squad2 (RoBERTa)** - Model provided by deepset on Hugging Face under the CC BY 4.0 License.
+ - **google/gemma-2-2b-it** - Model provided by Google on Hugging Face under the Gemma Terms of Use.
+ - **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0.
+ - **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the BSD 3-Clause License.
+ - **FathomNet MBARI 315K YOLOv8 Model**:
+   - **Dataset**: Sourced from [FathomNet](https://fathomnet.org).
+   - **Model**: Derived from MBARI’s curated dataset of 315,000 marine annotations.
+   - **License**: Dataset and models adhere to MBARI’s Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).
+
+ ## Datasets
+ - **FathomNet MBARI Dataset**:
+   - A large-scale dataset for marine biodiversity image annotations.
+   - All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).
+
+ ## Acknowledgments
+ - **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection.
+ - **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery.
+ - **Gradio**: For providing an intuitive interface for machine learning applications.
+ - **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers).
+ - **Real-ESRGAN**: For image enhancement and upscaling models.
+ - **Wikipedia API**: For fetching object descriptions and images.
+ """)
+
+     # Hidden state to store bounding boxes, original and processed images
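+     # (gr.State keeps a separate value per user session, so one visitor's detections do not leak into another's.)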
+     state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})
+
+     # Event Handlers
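+     # Handlers receive the session state as an input and return the updated state where needed so Gradio persists changes between clicks.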
+     def on_process_image(image, state):
+         processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
+         if processed_img is not None:
+             # Update the state with new bounding boxes and images
+             state['bounding_boxes'] = bounding_boxes
+             state['last_image'] = processed_img
+             state['original_image'] = original_image_pil
+             # Update the dropdown choices for detected objects
+             choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)" for idx, box in enumerate(bounding_boxes)]
+         else:
+             choices = []
+         return processed_img, info, details, gr.update(choices=choices), state
+
+     process_button.click(
+         on_process_image,
+         inputs=[image_input, state],
+         outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
+     )
+
+     def on_clear(state):
+         state = {"bounding_boxes": [], "last_image": None, "original_image": None}
+         return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state
+
+     clear_button.click(
+         on_clear,
+         inputs=state,
+         outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
+     )
+
+     def on_ask_eurybia(question, state):
+         answer, state = ask_eurybia(question, state)
+         return answer, state
+
+     ask_button.click(
+         on_ask_eurybia,
+         inputs=[question_input, state],
+         outputs=[answer_output, state]
+     )
+
+     def on_depth_prediction(state):
+         original_image = state.get('original_image')
+         depth_img, status = run_depth_prediction(original_image)
+         return depth_img, status
+
+     depth_button.click(
+         on_depth_prediction,
+         inputs=state,
+         outputs=[depth_output, depth_status]
+     )
+
+     def on_enhance_image(selected_object, state):
+         if not selected_object:
+             return None, "No object selected.", state
+
+         try:
+             idx = int(selected_object.split('.')[0]) - 1
+             box = state['bounding_boxes'][idx]
+             class_name = box['class_name']
+             x1, y1, x2, y2 = box['coords']
+
+             if not state.get('last_image'):
+                 return None, "Processed image is not available.", state
+
+             # Ensure processed_image is stored in state
+             processed_img_pil = state['last_image']
+             if not isinstance(processed_img_pil, Image.Image):
+                 return None, "Processed image is in an unsupported format.", state
+
+             # Convert processed_image to OpenCV format with checks
+             processed_img_cv = np.array(processed_img_pil)
+             if processed_img_cv.dtype != np.uint8:
+                 processed_img_cv = processed_img_cv.astype(np.uint8)
+
+             if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
+                 return None, "Invalid processed image format.", state
+
+             processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)
+
+             # Crop the detected object from the processed image
+             cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
+             if cropped_img_cv.size == 0:
+                 return None, "Cropped image is empty.", state
+
+             cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))
+
+             # Enhance the cropped image
+             enhanced_img, status = enhance_image(cropped_img_pil)
+             return enhanced_img, status, state
+         except Exception as e:
+             return None, f"Error: {e}", state
+
+     if models['yolo_model'] is not None and models['upscaler'] is not None:
+         enhance_btn.click(
+             on_enhance_image,
+             inputs=[detected_objects, state],
+             outputs=[enhanced_image, enhance_status, state]
+         )
+
+     # Optional: Add a note if the depth model isn't initialized
+     if models['depth_pipe'] is None and not models.get('depth_init_error'):
+         gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")
+
+     # Optional: Add a note if the upscaler isn't initialized
+     if models['upscaler'] is None:
+         gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
+
+ # Launch the Gradio app
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ # Core Libraries
+ numpy==1.23.5
+ opencv-python==4.8.0.74
+ Pillow==10.0.0
+ requests==2.31.0
+ wikipedia==1.4.0
+
+ # PyTorch
+ torch==2.0.1
+
+ # Hugging Face Ecosystem
+ transformers==4.31.0
+ huggingface-hub==0.14.1
+ diffusers==0.19.3
+ accelerate==0.20.3
+
+ # Real-ESRGAN and Dependencies
+ realesrgan==0.3.0
+ basicsr==1.4.2
+
+ # Ultralytics (YOLO)
+ ultralytics==8.0.120
+
+ # Gradio
+ gradio==3.40.0
+
+ # Additional Packages (Ensure Compatibility)
+ datasets==2.8.0 # Example version; adjust as needed
+ protobuf==3.20.3 # Compatible with <4
+ click==8.0.4 # Compatible with <8.1
+ pydantic==1.10.7 # Compatible with ~=1.0
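+ # Note: MarigoldDepthPipeline and google/gemma-2-2b-it generally require newer diffusers/transformers than the versions pinned above; bump these pins if the imports or model loads fail.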