# zero-shot-seg / app.py
# danieaneta's picture
# Update app.py
# 2010bf2 verified
# raw
# history blame
# 1.98 kB
import gradio as gr
from PIL import Image
import base64
import io
import numpy as np
from typing import List
from main import segmenter # Import the segmenter instance
def _failure_response(message: str) -> dict:
    """Build the payload returned when an image cannot be processed."""
    return {
        "success": False,
        "message": message,
        "image": None,
        "results": []
    }


def process_image(image: Image.Image, objects_text: str) -> dict:
    """Segment the requested objects in an image and return a JSON-style dict.

    Args:
        image: PIL image to analyse. ``None`` (nothing uploaded) is reported
            as a failure instead of being forwarded to the segmenter.
        objects_text: Object names separated by dots, e.g. ``"cat. dog. chair"``.
            Surrounding whitespace and blank entries are ignored; ``None`` is
            treated like an empty string.

    Returns:
        A dict with the keys ``success`` (bool), ``message`` (str), ``image``
        (base64-encoded PNG of the input image, or ``None`` on failure) and
        ``results`` (one dict per segmenter result with ``label``,
        ``confidence`` and ``bounding_box``; empty on failure).

        Errors raised while processing are not propagated; they are reported
        through ``success=False`` with the exception text as ``message``.
    """
    # Fail fast with a readable message rather than letting the segmenter
    # raise an opaque error on a missing upload.
    if image is None:
        return _failure_response("No image provided. Please upload an image.")

    try:
        # Parse the dot-separated object names, dropping blanks.
        objects = [
            name.strip()
            for name in (objects_text or "").split('.')
            if name.strip()
        ]
        if not objects:
            return _failure_response(
                "No objects specified. Enter one or more object names "
                "separated by dots."
            )

        # Use the segmenter to process the image
        results = segmenter.segment_objects(image, objects)

        # No visualization is drawn yet: the original image is echoed back
        # as a base64-encoded PNG.
        with io.BytesIO() as buffered:
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

        # NOTE(review): bounding_box is passed through unchanged - confirm the
        # segmenter returns a JSON-serializable value (e.g. a list, not an array).
        return {
            "success": True,
            "message": f"Processed image with objects: {objects}",
            "image": img_str,
            "results": [
                {
                    "label": r.label,
                    "confidence": float(r.confidence),
                    "bounding_box": r.bounding_box
                }
                for r in results
            ]
        }
    except Exception as e:
        # UI/API boundary: surface the error to the caller in the response
        # instead of raising.
        return _failure_response(str(e))
# Build the Gradio UI: an image upload and a dot-separated object list go in,
# the dict returned by process_image comes back as JSON.
_image_input = gr.Image(type="pil", label="Input Image")
_objects_input = gr.Textbox(
    label="Objects (separate with dots)",
    placeholder="cat. dog. chair",
)
_response_output = gr.JSON(label="API Response")

demo = gr.Interface(
    fn=process_image,
    inputs=[_image_input, _objects_input],
    outputs=_response_output,
    title="Zero Shot Segmentation",
    description="Upload an image and specify objects to detect.",
    allow_flagging="never",
)
# Turn on Gradio's request queue for the interface.
demo.queue()

if __name__ == "__main__":
    # Launch settings when run as a script: listen on all interfaces at port
    # 7860, request a share link, and keep the API docs visible.
    _launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**_launch_options)