import re
import base64
import io

import torch
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# The generation code below performs plain text generation, so it needs a
# causal language model. The detection checkpoints this file previously
# referenced (IDEA-Research/grounding-dino-base and OWLv2) have no
# text-generation head and cannot back model.generate(); the instruction-tuned
# checkpoint below is an assumption, and any causal LM can be substituted.
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
def input_image_setup(uploaded_file):
    """
    Encodes the uploaded image file into a base64 string.

    Parameters:
    - uploaded_file: PIL image object uploaded via Gradio.

    Returns:
    - encoded_image (str): Base64-encoded string of the PNG image data.
    """
    if uploaded_file is not None:
        # PIL's Image.tobytes() returns raw pixel data without any file header,
        # which cannot be decoded as an image; serialize to an in-memory PNG
        # first, then Base64-encode the resulting file bytes.
        buffer = io.BytesIO()
        uploaded_file.save(buffer, format="PNG")
        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return encoded_image
    else:
        raise FileNotFoundError("No file uploaded")
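# Illustrative round trip (not part of the original app): the encoded string
# decodes back to a valid PNG.
#   img = Image.new("RGB", (8, 8), "green")
#   b64 = input_image_setup(img)
#   Image.open(io.BytesIO(base64.b64decode(b64)))  # yields an 8x8 PIL image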
def format_response(response_text):
    """
    Formats the model response to display each item on a new line as a list.
    Converts markdown-style items into HTML `<ul>`/`<li>` format.
    """
    # Bold headings (**Heading**) become standalone paragraphs
    response_text = re.sub(r"\*\*(.*?)\*\*", r"<p><strong>\1</strong></p>", response_text)
    # Markdown bullets (* item) become list items
    response_text = re.sub(r"(?m)^\s*\*\s(.*)", r"<li>\1</li>", response_text)
    # Wrap each run of consecutive <li> items in a single <ul>; the original
    # pattern could not cross the newlines between items, so every item ended
    # up inside its own one-element list
    response_text = re.sub(
        r"(?:<li>.*?</li>\s*)+",
        lambda match: "<ul>" + re.sub(r"</li>\s*<li>", "</li><li>", match.group(0).strip()) + "</ul>",
        response_text,
        flags=re.DOTALL,
    )
    # Separate consecutive paragraphs, then convert remaining newlines
    response_text = re.sub(r"</p>(?=<p>)", r"</p><br>", response_text)
    response_text = re.sub(r"(\n|\\n)+", r"<br>", response_text)
    return response_text
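# Example of the conversion above (illustrative, not from the original app):
#   format_response("**Identification**\n* Salmon: 6 ounces\n* Asparagus: 3 spears")
# returns
#   "<p><strong>Identification</strong></p><br><ul><li>Salmon: 6 ounces</li><li>Asparagus: 3 spears</li></ul>"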
def generate_model_response(uploaded_file, user_query):
    """
    Processes the uploaded image and user query to generate a response from the model.

    Parameters:
    - uploaded_file: The uploaded image file.
    - user_query: The user's question about the image.

    Returns:
    - str: The generated response from the model.
    """
    # Encode the uploaded image into Base64 format. A text-only causal LM
    # cannot actually consume the image; the encoding is kept so the upload is
    # validated and can be forwarded to a multimodal endpoint if one is wired in.
    encoded_image = input_image_setup(uploaded_file)

    # Define the assistant prompt
assistant_prompt = """ | |
You are an expert nutritionist. Your task is to analyze the food items displayed in the image and provide a detailed nutritional assessment using the following format: | |
1. **Identification**: List each identified food item clearly, one per line. | |
2. **Portion Size & Calorie Estimation**: For each identified food item, specify the portion size and provide an estimated number of calories. Use bullet points with the following structure: | |
- **[Food Item]**: [Portion Size], [Number of Calories] calories | |
Example: | |
* **Salmon**: 6 ounces, 210 calories | |
* **Asparagus**: 3 spears, 25 calories | |
3. **Total Calories**: Provide the total number of calories for all food items. | |
Example: | |
Total Calories: [Number of Calories] | |
4. **Nutrient Breakdown**: Include a breakdown of key nutrients such as **Protein**, **Carbohydrates**, **Fats**, **Vitamins**, and **Minerals**. Use bullet points, and for each nutrient provide details about the contribution of each food item. | |
Example: | |
* **Protein**: Salmon (35g), Asparagus (3g), Tomatoes (1g) = [Total Protein] | |
5. **Health Evaluation**: Evaluate the healthiness of the meal in one paragraph. | |
6. **Disclaimer**: Include the following exact text as a disclaimer: | |
The nutritional information and calorie estimates provided are approximate and are based on general food data. | |
Actual values may vary depending on factors such as portion size, specific ingredients, preparation methods, and individual variations. | |
For precise dietary advice or medical guidance, consult a qualified nutritionist or healthcare provider. | |
Format your response exactly like the template above to ensure consistency. | |
""" | |
    # Combine the system prompt with the user's question
    input_text = assistant_prompt + "\n\n" + user_query + "\n"

    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    try:
        # Generate the response; without max_new_tokens the default generation
        # budget is far too small for a full nutritional breakdown
        outputs = model.generate(**inputs, max_new_tokens=512)

        # Decode only the newly generated tokens so the echoed prompt is not
        # included in the answer shown to the user
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        formatted_response = format_response(raw_response)
        return formatted_response
    except Exception as e:
        print(f"Error in generating response: {e}")
        return "An error occurred while generating the response."
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_model_response,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),  # Image upload component
        gr.Textbox(label="User Query", placeholder="Enter your question about the image..."),
    ],
    outputs="html",  # Display formatted HTML output
)
# Launch the Gradio app
iface.launch()
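# On Hugging Face Spaces the server is started automatically; for local runs,
# a temporary public URL can be requested instead (optional):
#   iface.launch(share=True)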