File size: 5,052 Bytes
d66422e
 
 
68e425a
d66422e
d0ee1ff
d66422e
d0ee1ff
 
 
 
 
d66422e
 
 
68e425a
d66422e
 
68e425a
d66422e
 
68e425a
d66422e
 
68e425a
 
d66422e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68e425a
d66422e
68e425a
d66422e
68e425a
 
 
d66422e
68e425a
 
 
 
 
 
d66422e
68e425a
 
 
d66422e
68e425a
 
 
d66422e
68e425a
 
 
d66422e
68e425a
d66422e
68e425a
 
d66422e
68e425a
d66422e
68e425a
 
d66422e
68e425a
d66422e
68e425a
d66422e
68e425a
 
 
d66422e
68e425a
 
d66422e
68e425a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import base64
import html
import io
import re

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Hugging Face checkpoint loaded once at import time.
# NOTE(review): this checkpoint is loaded as a zero-shot object detector, while
# generate_model_response() in this file calls `tokenizer(...)` and
# `model.generate(...)`, and no `tokenizer` is defined in the visible file.
# Confirm which model this app is actually meant to use.
model_id = "IDEA-Research/grounding-dino-base"

# Run on the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(model_id)

# Load the weights, then move the model to the selected device.
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)

def input_image_setup(uploaded_file):
    """
    Encodes the uploaded image into a Base64 string of JPEG data.

    Parameters:
    - uploaded_file: Image uploaded via Gradio. The interface in this file uses
      `gr.Image(type="pil")`, so this is normally a PIL image. Objects that have
      no `save` method but provide `tobytes()` are encoded from their raw bytes.

    Returns:
    - encoded_image (str): Base64 encoded string of the image data.

    Raises:
    - FileNotFoundError: If no image was uploaded (`uploaded_file` is None).
    """
    if uploaded_file is None:
        raise FileNotFoundError("No file uploaded")

    if hasattr(uploaded_file, "save"):
        # `tobytes()` on a PIL image yields the raw pixel buffer, not an image
        # file. The caller embeds the result in a `data:image/jpeg` URI, so
        # serialize a real JPEG instead.
        image = uploaded_file
        if getattr(image, "mode", "RGB") not in ("RGB", "L"):
            # JPEG cannot store alpha or palette data; flatten to RGB first.
            image = image.convert("RGB")
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        bytes_data = buffer.getvalue()
    else:
        # Fallback for byte-oriented inputs that are not PIL images.
        bytes_data = uploaded_file.tobytes()

    return base64.b64encode(bytes_data).decode("utf-8")

def format_response(response_text):
    """
    Converts the model's Markdown-like response into an HTML fragment.

    Steps, in order:
    - HTML-escapes `&`, `<` and `>` so the text cannot inject markup.
    - Turns `**bold**` spans into `<p><strong>...</strong></p>`.
    - Turns lines starting with `* ` into `<li>` items and wraps each run of
      consecutive items in a single `<ul>`.
    - Inserts `<br>` between directly adjacent paragraphs.
    - Replaces runs of newlines (real or the literal two characters `\\n`)
      with a single `<br>`.

    Parameters:
    - response_text (str): Raw text produced by the model.

    Returns:
    - str: The HTML-formatted text.
    """

    def wrap_list_run(match):
        # Drop the whitespace between neighbouring items so the later
        # newline -> <br> step does not put <br> tags inside the list.
        items = re.sub(r"</li>\s+<li>", "</li><li>", match.group(0))
        return f"<ul>{items}</ul>"

    # Escape first: the tags added below must stay intact, while anything in
    # the model text that looks like markup must be neutralized.
    response_text = html.escape(response_text, quote=False)
    response_text = re.sub(r"\*\*(.*?)\*\*", r"<p><strong>\1</strong></p>", response_text)
    response_text = re.sub(r"(?m)^\s*\*\s(.*)", r"<li>\1</li>", response_text)
    # Bullet lines are separated by newlines, so allow whitespace between
    # items; otherwise every item would end up in its own <ul>.
    response_text = re.sub(
        r"<li>.*?</li>(?:\s*<li>.*?</li>)*",
        wrap_list_run,
        response_text,
        flags=re.DOTALL,
    )
    response_text = re.sub(r"</p>(?=<p>)", r"</p><br>", response_text)
    response_text = re.sub(r"(\n|\\n)+", r"<br>", response_text)
    return response_text

def generate_model_response(uploaded_file, user_query):
    """
    Processes the uploaded image and user query to generate a response from the model.

    The image is Base64-encoded and appended, together with the user's query,
    to a fixed nutritionist prompt. The combined text is tokenized, passed to
    `model.generate`, decoded and converted to HTML with `format_response`.

    Parameters:
    - uploaded_file: The uploaded image file.
    - user_query: The user's question about the image. None is treated as an
      empty string.

    Returns:
    - str: The generated response from the model as HTML, or a fixed error
      message if tokenization, generation or formatting fails.

    Raises:
    - FileNotFoundError: Propagated from `input_image_setup` when no image
      was uploaded.
    """

    # Encode the uploaded image into Base64 format
    encoded_image = input_image_setup(uploaded_file)

    # Guard against a missing query so the concatenation below cannot fail on None
    user_query = user_query or ""

    # Define the assistant prompt
    assistant_prompt = """
        You are an expert nutritionist. Your task is to analyze the food items displayed in the image and provide a detailed nutritional assessment using the following format:

        1. **Identification**: List each identified food item clearly, one per line.
        2. **Portion Size & Calorie Estimation**: For each identified food item, specify the portion size and provide an estimated number of calories. Use bullet points with the following structure:
        - **[Food Item]**: [Portion Size], [Number of Calories] calories

        Example:
        *   **Salmon**: 6 ounces, 210 calories
        *   **Asparagus**: 3 spears, 25 calories

        3. **Total Calories**: Provide the total number of calories for all food items.

        Example:
        Total Calories: [Number of Calories]

        4. **Nutrient Breakdown**: Include a breakdown of key nutrients such as **Protein**, **Carbohydrates**, **Fats**, **Vitamins**, and **Minerals**. Use bullet points, and for each nutrient provide details about the contribution of each food item.

        Example:
        *   **Protein**: Salmon (35g), Asparagus (3g), Tomatoes (1g) = [Total Protein]

        5. **Health Evaluation**: Evaluate the healthiness of the meal in one paragraph.

        6. **Disclaimer**: Include the following exact text as a disclaimer:

        The nutritional information and calorie estimates provided are approximate and are based on general food data. 
        Actual values may vary depending on factors such as portion size, specific ingredients, preparation methods, and individual variations. 
        For precise dietary advice or medical guidance, consult a qualified nutritionist or healthcare provider.

        Format your response exactly like the template above to ensure consistency.
    """

    # Prepare input for the model
    input_text = assistant_prompt + "\n\n" + user_query + "\n![Image](data:image/jpeg;base64," + encoded_image + ")"

    try:
        # Tokenize inside the try block so tokenizer failures also produce the
        # friendly error message instead of an unhandled exception.
        # NOTE(review): `tokenizer` is not defined anywhere in the visible
        # file, and the module-level `model` is loaded as a zero-shot object
        # detector. Confirm the intended model/tokenizer pair.
        inputs = tokenizer(input_text, return_tensors="pt")

        # The model was moved to `device` at load time; the input tensors must
        # be on the same device (assumes the tokenizer output supports `.to`).
        inputs = inputs.to(device)

        # Generate response from the model
        outputs = model.generate(**inputs)

        # Decode and format the model's raw response
        raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return format_response(raw_response)

    except Exception as e:
        # UI boundary: log the cause and show a generic message to the user.
        print(f"Error in generating response: {e}")
        return "An error occurred while generating the response."

# Input widgets: the image is handed to the callback as a PIL image,
# the question as plain text.
_image_input = gr.Image(type="pil", label="Upload Image")
_query_input = gr.Textbox(
    label="User Query",
    placeholder="Enter your question about the image...",
)

# Wire the widgets to the response generator; the result is rendered as HTML.
iface = gr.Interface(
    fn=generate_model_response,
    inputs=[_image_input, _query_input],
    outputs="html",
)

# Start the Gradio app.
iface.launch()