import os

# Path to the degradations.py file
degradations_path = '/usr/local/lib/python3.10/site-packages/basicsr/data/degradations.py'

# Check if the file exists
if os.path.exists(degradations_path):
    with open(degradations_path, 'r') as file:
        content = file.read()
    # Replace the problematic import
    content = content.replace(
        'from torchvision.transforms.functional_tensor import rgb_to_grayscale',
        'from torchvision.transforms import functional as F\nrgb_to_grayscale = F.rgb_to_grayscale'
    )
    # Write the modified content back
    with open(degradations_path, 'w') as file:
        file.write(content)
else:
    print("degradations.py not found!")
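# Why this patch is needed: torchvision 0.17+ removed the private
# torchvision.transforms.functional_tensor module, while basicsr (<= 1.4.2) still
# imports rgb_to_grayscale from it, which would break the realesrgan import below.
# An alternative to patching the installed file is pinning torchvision<0.17 in
# requirements.txt. The site-packages path above assumes the Python 3.10 image
# this Space runs on; adjust it if the environment differs.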
import glob
import time
import threading
import requests
import wikipedia

# Set environment variable for PyTorch MPS fallback before importing torch
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import cv2
import numpy as np
from io import BytesIO
from PIL import Image
import base64
import gradio as gr
from ultralytics import YOLO
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from diffusers import MarigoldDepthPipeline  # Marigold depth estimation pipeline
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet

# Initialize Models
def initialize_models():
    models = {}

    # Device detection
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    models['device'] = device
    print(f"Using device: {device}")

    # Initialize the RoBERTa model for question answering
    try:
        models['qa_pipeline'] = pipeline(
            "question-answering", model="deepset/roberta-base-squad2",
            device=0 if device == 'cuda' else -1)
        print("RoBERTa QA pipeline initialized.")
    except Exception as e:
        print(f"Error initializing the RoBERTa model: {e}")
        models['qa_pipeline'] = None

    # Initialize the Gemma model
    try:
        models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
        models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            device_map="auto",
            torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
        )
        print("Gemma model initialized.")
    except Exception as e:
        print(f"Error initializing the Gemma model: {e}")
        models['gemma_model'] = None

    # Initialize the depth estimation model using MarigoldDepthPipeline
    try:
        if device == 'cuda':
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                variant="fp16",
                torch_dtype=torch.float16
            ).to('cuda')
        else:
            # For CPU or MPS devices, keep the pipeline on 'cpu' to avoid unsupported operators
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                torch_dtype=torch.float32
            ).to('cpu')
        print("Depth estimation model initialized.")
    except Exception as e:
        error_message = f"Error initializing the depth estimation model: {e}"
        print(error_message)
        models['depth_pipe'] = None
        models['depth_init_error'] = error_message  # Store the error message

    # Initialize the upscaling model
    try:
        upscaler_model_path = 'weights/RealESRGAN_x4plus.pth'  # Ensure this path is correct
        if not os.path.exists(upscaler_model_path):
            print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
            models['upscaler'] = None
        else:
            # Define the model architecture
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)
            # Initialize RealESRGANer
            models['upscaler'] = RealESRGANer(
                scale=4,
                model_path=upscaler_model_path,
                model=model,
                pre_pad=0,
                half=(device == 'cuda'),
                device=device
            )
            print("Real-ESRGAN upscaler initialized.")
    except Exception as e:
        print(f"Error initializing the upscaling model: {e}")
        models['upscaler'] = None
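    # Note: RealESRGAN_x4plus.pth is not bundled with the realesrgan package; it is
    # published as a release asset in the Real-ESRGAN GitHub repository. A minimal
    # startup-download sketch (the URL is assumed from the Real-ESRGAN release
    # assets; verify it before relying on it) could look like:
    #
    #     weights_url = ("https://github.com/xinntao/Real-ESRGAN/releases/download/"
    #                    "v0.1.0/RealESRGAN_x4plus.pth")  # assumed asset URL
    #     os.makedirs(os.path.dirname(upscaler_model_path), exist_ok=True)
    #     if not os.path.exists(upscaler_model_path):
    #         with open(upscaler_model_path, 'wb') as f:
    #             f.write(requests.get(weights_url, timeout=60).content)
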
    # Initialize YOLO model
    try:
        source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
        if not os.path.exists(source_weights_path):
            print(f"YOLO weights not found at {source_weights_path}. Please download them.")
            models['yolo_model'] = None
        else:
            models['yolo_model'] = YOLO(source_weights_path)
            print("YOLO model initialized.")
    except Exception as e:
        print(f"Error initializing YOLO model: {e}")
        models['yolo_model'] = None

    return models


models = initialize_models()

# Utility Functions
def search_class_description(class_name):
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    description = ""
    try:
        page = wikipedia.page(class_name)
        if page:
            description = page.content[:5000]  # Keep up to the first 5000 characters
    except Exception as e:
        print(f"Error fetching description for {class_name}: {e}")
    return description

def search_class_image(class_name):
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    img_url = ""
    try:
        page = wikipedia.page(class_name)
        if page:
            for img in page.images:
                if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    img_url = img
                    break
    except Exception as e:
        print(f"Error fetching image for {class_name}: {e}")
    return img_url

def process_image(image):
    if models['yolo_model'] is None:
        return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None
    try:
        if image is None:
            return None, "No image uploaded.", "No image uploaded.", [], None

        # Convert the Gradio PIL image to OpenCV (BGR) format
        image_np = np.array(image)
        if image_np.dtype != np.uint8:
            image_np = image_np.astype(np.uint8)
        if len(image_np.shape) != 3 or image_np.shape[2] != 3:
            return None, "Invalid image format. Please upload an RGB image.", "Invalid image format. Please upload an RGB image.", [], None
        image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # Store the original image before drawing bounding boxes
        original_image_cv = image_cv.copy()
        original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB))

        # Perform YOLO prediction with a deliberately low confidence threshold
        results = models['yolo_model'].predict(source=image_cv, conf=0.075)[0]

        bounding_boxes = []
        image_processed = image_cv.copy()
        if results.boxes is not None:
            for box in results.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                class_name = models['yolo_model'].names[int(box.cls)]
                confidence = box.conf.item() * 100  # Convert to percentage
                bounding_boxes.append({
                    "coords": (x1, y1, x2, y2),
                    "class_name": class_name,
                    "confidence": confidence
                })
                cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(image_processed, f'{class_name} {confidence:.2f}%',
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
                            0.9, (0, 0, 255), 2)

        # Convert back to PIL Image
        processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB))

        # Prepare detection info
        if bounding_boxes:
            detection_info = "\n".join(
                [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes]
            )
        else:
            detection_info = "No detections found."

        # Prepare detection details as Markdown
        if bounding_boxes:
            details = ""
            for idx, box in enumerate(bounding_boxes):
                class_name = box['class_name']
                confidence = box['confidence']
                description = search_class_description(class_name)
                img_url = search_class_image(class_name)
                img_md = ""
                if img_url:
                    try:
                        headers = {
                            'User-Agent': 'MyApp/1.0 (https://example.com/contact; [email protected])'
                        }
                        response = requests.get(img_url, headers=headers, timeout=10)
                        img_data = response.content
                        img = Image.open(BytesIO(img_data)).convert("RGB")
                        img.thumbnail((400, 400))  # Resize for faster loading
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        # Embed the thumbnail inline as a base64 data URI
                        img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
                    except Exception as e:
                        print(f"Error fetching image for {class_name}: {e}")
                details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
                if description:
                    details += f"{description}\n\n"
                if img_md:
                    details += f"{img_md}\n\n"
            detection_details_md = details
        else:
            detection_details_md = "No detections to show."

        return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None

def ask_eurybia(question, state):
    if not question.strip():
        return "Please enter a valid question.", state
    if not state['bounding_boxes']:
        return "No detected objects to ask about.", state

    # Combine descriptions of all detected objects as context
    context = ""
    for box in state['bounding_boxes']:
        description = search_class_description(box['class_name'])
        if description:
            context += description + "\n"
    if not context.strip():
        return "Not enough context is available to answer the question.", state

    try:
        if models['qa_pipeline'] is None:
            return "QA pipeline is not initialized.", state
        answer = models['qa_pipeline'](question=question, context=context)
        answer_text = answer['answer'].strip()
        if not answer_text:
            return "I couldn't find an answer to that question based on the detected objects.", state
        return answer_text, state
    except Exception as e:
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", state

def enhance_image(cropped_image_pil):
    if models['upscaler'] is None:
        return None, "Upscaling model is not initialized."
    try:
        input_image = cropped_image_pil.convert("RGB")
        img = np.array(input_image)
        # Run the model to enhance the image
        output, _ = models['upscaler'].enhance(img, outscale=4)
        enhanced_image = Image.fromarray(output)
        return enhanced_image, "Image enhanced successfully."
    except Exception as e:
        print(f"Error during image enhancement: {e}")
        return None, f"Error during image enhancement: {e}"

def run_depth_prediction(original_image):
    if models['depth_pipe'] is None:
        error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
        return None, error_msg
    try:
        if original_image is None:
            return None, "No image uploaded for depth prediction."

        # Prepare the image
        input_image = original_image.convert("RGB")

        # Run the depth pipeline; Marigold exposes the depth map as `prediction`
        result = models['depth_pipe'](input_image)
        depth_prediction = result.prediction

        # Visualize the depth map
        vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
        # Ensure vis_depth is a list and extract the first image
        if isinstance(vis_depth, list) and len(vis_depth) > 0:
            vis_depth_image = vis_depth[0]
        else:
            vis_depth_image = vis_depth  # Fallback if not a list
        return vis_depth_image, "Depth prediction completed."
    except Exception as e:
        print(f"Error during depth prediction: {e}")
        return None, f"Error during depth prediction: {e}"

# Gradio Interface Components
with gr.Blocks() as demo:
    gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")

    with gr.Tab("Upload & Process"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                process_button = gr.Button("Process Image")
                clear_button = gr.Button("Clear")
            with gr.Column():
                processed_image = gr.Image(type="pil", label="Processed Image")
                detection_info = gr.Textbox(label="Detection Information", lines=10)

    with gr.Tab("Detection Details"):
        with gr.Accordion("Click to see detection details", open=False):
            detection_details_md = gr.Markdown("No detections to show.")

    with gr.Tab("Ask Eurybia"):
        with gr.Row():
            with gr.Column():
                question_input = gr.Textbox(label="Ask a question about the detected objects")
                ask_button = gr.Button("Ask Eurybia")
            with gr.Column():
                answer_output = gr.Markdown(label="Eurybia's Answer")

    with gr.Tab("Depth Estimation"):
        with gr.Row():
            with gr.Column():
                depth_button = gr.Button("Run Depth Prediction")
            with gr.Column():
                depth_output = gr.Image(type="pil", label="Depth Map")
                depth_status = gr.Textbox(label="Status", lines=2)
        # Display error message if depth estimation model failed to initialize
        if models.get('depth_init_error'):
            gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")
    with gr.Tab("Enhance Detected Objects"):
        if models['yolo_model'] is not None and models['upscaler'] is not None:
            with gr.Row():
                detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True)
                enhance_btn = gr.Button("Enhance Image")
            with gr.Column():
                enhanced_image = gr.Image(type="pil", label="Enhanced Image")
                enhance_status = gr.Textbox(label="Status", lines=2)
        else:
            gr.Markdown("**Warning:** YOLO model or Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
            # Create a hidden placeholder dropdown so the event handlers that list
            # `detected_objects` among their outputs do not raise a NameError when
            # the models are unavailable.
            detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", visible=False)
    with gr.Tab("Credits"):
        gr.Markdown("""
# Credits and Licensing Information

This project utilizes various open-source libraries, tools, pretrained models, and datasets. Below is the list of components used and their respective credits/licenses:

## Libraries
- **Python** - Python Software Foundation License (PSF License)
- **Gradio** - Licensed under the Apache License 2.0
- **Torch (PyTorch)** - Licensed under the BSD 3-Clause License
- **OpenCV (cv2)** - Licensed under the Apache License 2.0
- **NumPy** - Licensed under the BSD License
- **Pillow (PIL)** - Licensed under the HPND License
- **Requests** - Licensed under the Apache License 2.0
- **Wikipedia API** - Licensed under the MIT License
- **Transformers** - Licensed under the Apache License 2.0
- **Diffusers** - Licensed under the Apache License 2.0
- **Real-ESRGAN** - Licensed under the MIT License
- **BasicSR** - Licensed under the Apache License 2.0
- **Ultralytics YOLO** - Licensed under the AGPL-3.0 License

## Pretrained Models
- **deepset/roberta-base-squad2 (RoBERTa)** - Model provided by Hugging Face under the Apache License 2.0.
- **google/gemma-2-2b-it** - Model provided by Google, distributed under the Gemma Terms of Use.
- **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0.
- **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the MIT License.
- **FathomNet MBARI 315K YOLOv8 Model**:
  - **Dataset**: Sourced from [FathomNet](https://fathomnet.org).
  - **Model**: Derived from MBARI's curated dataset of 315,000 marine annotations.
  - **License**: Dataset and models adhere to MBARI's Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

## Datasets
- **FathomNet MBARI Dataset**:
  - A large-scale dataset for marine biodiversity image annotations.
  - All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).

## Acknowledgments
- **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection.
- **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery.
- **Gradio**: For providing an intuitive interface for machine learning applications.
- **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers).
- **Real-ESRGAN**: For image enhancement and upscaling models.
- **Wikipedia API**: For fetching object descriptions and images.
""")

    # Hidden state to store bounding boxes, original and processed images
    state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})
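    # gr.State is per-session: each browser session gets its own copy of this
    # dictionary, so detections from concurrent users are kept separate.
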
    # Event Handlers
    def on_process_image(image, state):
        processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
        if processed_img is not None:
            # Update the state with new bounding boxes and images
            state['bounding_boxes'] = bounding_boxes
            state['last_image'] = processed_img
            state['original_image'] = original_image_pil
            # Update the dropdown choices for detected objects
            choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)" for idx, box in enumerate(bounding_boxes)]
        else:
            choices = []
        return processed_img, info, details, gr.update(choices=choices), state

    process_button.click(
        on_process_image,
        inputs=[image_input, state],
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_clear(state):
        state = {"bounding_boxes": [], "last_image": None, "original_image": None}
        return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state

    clear_button.click(
        on_clear,
        inputs=state,
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_ask_eurybia(question, state):
        answer, state = ask_eurybia(question, state)
        return answer, state

    ask_button.click(
        on_ask_eurybia,
        inputs=[question_input, state],
        outputs=[answer_output, state]
    )

    def on_depth_prediction(state):
        original_image = state.get('original_image')
        depth_img, status = run_depth_prediction(original_image)
        return depth_img, status

    depth_button.click(
        on_depth_prediction,
        inputs=state,
        outputs=[depth_output, depth_status]
    )

    def on_enhance_image(selected_object, state):
        if not selected_object:
            return None, "No object selected.", state
        try:
            idx = int(selected_object.split('.')[0]) - 1
            box = state['bounding_boxes'][idx]
            class_name = box['class_name']
            x1, y1, x2, y2 = box['coords']
            if not state.get('last_image'):
                return None, "Processed image is not available.", state
            # Ensure processed_image is stored in state
            processed_img_pil = state['last_image']
            if not isinstance(processed_img_pil, Image.Image):
                return None, "Processed image is in an unsupported format.", state
            # Convert processed_image to OpenCV format with checks
            processed_img_cv = np.array(processed_img_pil)
            if processed_img_cv.dtype != np.uint8:
                processed_img_cv = processed_img_cv.astype(np.uint8)
            if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
                return None, "Invalid processed image format.", state
            processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)
            # Crop the detected object from the processed image
            cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
            if cropped_img_cv.size == 0:
                return None, "Cropped image is empty.", state
            cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))
            # Enhance the cropped image
            enhanced_img, status = enhance_image(cropped_img_pil)
            return enhanced_img, status, state
        except Exception as e:
            return None, f"Error: {e}", state

    if models['yolo_model'] is not None and models['upscaler'] is not None:
        enhance_btn.click(
            on_enhance_image,
            inputs=[detected_objects, state],
            outputs=[enhanced_image, enhance_status, state]
        )

    # Optional: Add a note if the depth model isn't initialized
    if models['depth_pipe'] is None and not models.get('depth_init_error'):
        gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")

    # Optional: Add a note if the upscaler isn't initialized
    if models['upscaler'] is None:
        gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")

# Launch the Gradio app
demo.launch()