# Hugging Face Space: food-image nutritional-assessment demo (Gradio + transformers).
import base64
import io
import re

import gradio as gr
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Owlv2ForObjectDetection,
    Owlv2Processor,
)

# Zero-shot object-detection pipeline.  NOTE(review): nothing in this file
# calls `processor`/`detection_model` yet — kept so the detector can later be
# wired in for food-item localisation.
processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14-finetuned")
detection_model = Owlv2ForObjectDetection.from_pretrained(
    "google/owlv2-large-patch14-finetuned"
)

# Text model used by generate_model_response.  The original file never defined
# `tokenizer` and pointed `model` at the detection model (which cannot
# generate text), so every request crashed with a NameError.
# TODO(review): replace with the intended instruction-tuned checkpoint.
TEXT_MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
def input_image_setup(image_file):
    """
    Encode an uploaded image as a base64 JPEG string.

    Parameters:
    - image_file: PIL Image uploaded via Gradio, or None.

    Returns:
    - str: Base64-encoded JPEG bytes of the image.

    Raises:
    - FileNotFoundError: if no image was uploaded (image_file is None).
    """
    if image_file is None:
        raise FileNotFoundError("No file uploaded")
    # JPEG cannot store an alpha channel / palette; RGBA and P images would
    # make .save(format="JPEG") raise, so normalise to RGB first.
    if getattr(image_file, "mode", "RGB") not in ("RGB", "L"):
        image_file = image_file.convert("RGB")
    buffered = io.BytesIO()
    image_file.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def format_response(response_text):
    """
    Format the model's markdown-ish response as simple HTML.

    Transforms, in order: **bold** spans -> <p><strong>..</strong></p>,
    "* item" lines -> <li> elements, runs of <li> -> one <ul>, <br> between
    adjacent paragraphs, and remaining newlines -> <br>.
    """
    response_text = re.sub(r"\*\*(.*?)\*\*", r"<p><strong>\1</strong></p>", response_text)
    response_text = re.sub(r"(?m)^\s*\*\s(.*)", r"<li>\1</li>", response_text)
    # BUGFIX: the old pattern "(<li>.*?</li>)+" could never span the newline
    # left between list lines, so every item was wrapped in its own <ul>.
    # Allow whitespace between items so one run becomes one list.
    response_text = re.sub(
        r"(?:<li>.*?</li>\s*)+",
        lambda match: f"<ul>{match.group(0)}</ul>",
        response_text,
        flags=re.DOTALL,
    )
    response_text = re.sub(r"</p>(?=<p>)", r"</p><br>", response_text)
    response_text = re.sub(r"(\n|\\n)+", r"<br>", response_text)
    return response_text
def generate_model_response(image_file, user_query):
    """
    Process the uploaded image and user query and generate a model response.

    Parameters:
    - image_file: The uploaded image (PIL Image) or None.
    - user_query: The user's question about the image.

    Returns:
    - str: The generated response, formatted as HTML (error text on failure).
    """
    try:
        encoded_image = input_image_setup(image_file)
    except FileNotFoundError as e:
        return f"<p>{str(e)}</p>"
    assistant_prompt = """
You are an expert nutritionist. Analyze the food items in the image and provide a detailed nutritional assessment:
1. **Identification**: List each food item.
2. **Portion & Calories**: Specify portion size and calories for each item.
3. **Total Calories**: Provide the total.
4. **Nutrient Breakdown**: Detail key nutrients.
5. **Health Evaluation**: Evaluate meal healthiness.
6. **Disclaimer**: "Nutritional info is approximate. Consult a nutritionist for precise advice."
Format your response accordingly.
"""
    input_text = assistant_prompt + "\n\n" + user_query + "\n"
    try:
        # BUGFIX: tokenization used to sit outside the try, so any failure
        # there (e.g. the missing `tokenizer` global) crashed the request
        # instead of returning an error message.
        inputs = tokenizer(input_text, return_tensors="pt")
        # NOTE(review): `encoded_image` is computed but never passed to the
        # model — a text-only LM cannot see the photo.  TODO: feed the image
        # through a vision-language model.
        # max_new_tokens: the generate() default (~20 tokens) truncates the
        # assessment; allow a full answer.
        outputs = model.generate(**inputs, max_new_tokens=512)
        raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return format_response(raw_response)
    except Exception as e:
        # Top-level boundary for the Gradio callback: log and report the
        # error in the UI instead of crashing the request.
        print(f"Error in generating response: {e}")
        return f"<p>An error occurred: {str(e)}</p>"
# Gradio interface: image + free-text question in, HTML assessment out.
iface = gr.Interface(
    fn=generate_model_response,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Enter your question", placeholder="How many calories are in this food?"),
    ],
    outputs=gr.HTML(label="Nutritional Assessment"),
)

# BUGFIX: the original called iface.launch(true) — `true` is not a Python
# name and raised NameError at startup.
iface.launch()