Kilos1 committed on
Commit
ec744ec
·
verified ·
1 Parent(s): 1512baa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -66
app.py CHANGED
@@ -1,49 +1,24 @@
1
  import re
2
  import base64
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import gradio as gr
5
  from PIL import Image
6
- import io
7
- from transformers import Owlv2Processor, Owlv2ForObjectDetection
8
-
9
- processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14-finetuned")
10
- model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-large-patch14-finetuned")
11
-
12
- def input_image_setup(image_file):
13
- """
14
- Encodes the uploaded image file into a base64 string.
15
-
16
- Parameters:
17
- - image_file: Image file uploaded via Gradio
18
-
19
- Returns:
20
- - encoded_image (str): Base64 encoded string of the image data
21
- """
22
- if image_file is not None:
23
- # Convert the PIL Image object to bytes and encode in Base64
24
- buffered = io.BytesIO()
25
- image_file.save(buffered, format="JPEG")
26
- img_bytes = buffered.getvalue()
27
- encoded_image = base64.b64encode(img_bytes).decode("utf-8")
28
- return encoded_image
29
- else:
30
- raise FileNotFoundError("No file uploaded")
31
-
32
- def format_response(response_text):
33
- """
34
- Formats the model response to display each item as HTML elements.
35
- """
36
- response_text = re.sub(r"\*\*(.*?)\*\*", r"<p><strong>\1</strong></p>", response_text)
37
- response_text = re.sub(r"(?m)^\s*\*\s(.*)", r"<li>\1</li>", response_text)
38
- response_text = re.sub(r"(<li>.*?</li>)+", lambda match: f"<ul>{match.group(0)}</ul>", response_text, flags=re.DOTALL)
39
- response_text = re.sub(r"</p>(?=<p>)", r"</p><br>", response_text)
40
- response_text = re.sub(r"(\n|\\n)+", r"<br>", response_text)
41
- return response_text
42
 
43
  def generate_model_response(image_file, user_query):
44
  """
45
  Processes the uploaded image and user query to generate a response from the model.
46
-
47
  Parameters:
48
  - image_file: The uploaded image file.
49
  - user_query: The user's question about the image.
@@ -52,32 +27,34 @@ def generate_model_response(image_file, user_query):
52
  - str: The generated response from the model, formatted as HTML.
53
  """
54
  try:
55
- encoded_image = input_image_setup(image_file)
56
- except FileNotFoundError as e:
57
- return f"<p>{str(e)}</p>"
58
-
59
- assistant_prompt = """
60
- You are an expert nutritionist. Analyze the food items in the image and provide a detailed nutritional assessment:
61
-
62
- 1. **Identification**: List each food item.
63
- 2. **Portion & Calories**: Specify portion size and calories for each item.
64
- 3. **Total Calories**: Provide the total.
65
- 4. **Nutrient Breakdown**: Detail key nutrients.
66
- 5. **Health Evaluation**: Evaluate meal healthiness.
67
- 6. **Disclaimer**: "Nutritional info is approximate. Consult a nutritionist for precise advice."
68
-
69
- Format your response accordingly.
70
- """
71
-
72
- input_text = assistant_prompt + "\n\n" + user_query + "\n![Image](data:image/jpeg;base64," + encoded_image + ")"
73
-
74
- inputs = AutoTokenizer(input_text)#, return_tensors="pt")
75
-
76
- try:
77
  outputs = model.generate(**inputs)
78
- raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
- formatted_response = format_response(raw_response)
80
- return formatted_response
 
 
 
81
  except Exception as e:
82
  print(f"Error in generating response: {e}")
83
  return f"<p>An error occurred: {str(e)}</p>"
@@ -86,10 +63,10 @@ def generate_model_response(image_file, user_query):
86
  iface = gr.Interface(
87
  fn=generate_model_response,
88
  inputs=[
89
- gr.Image(type="pil", label="Upload Image"),
90
  gr.Textbox(label="Enter your question", placeholder="How many calories are in this food?")
91
  ],
92
- outputs=gr.HTML(label="Nutritional Assessment")
93
  )
94
 
95
- iface.launch(share=True)
 
1
  import re
2
  import base64
3
+ import io
4
+ import torch
5
  import gradio as gr
6
  from PIL import Image
7
+ from transformers import MllamaForConditionalGeneration, AutoProcessor
8
+
9
+ # Load the model and processor
10
+ model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
11
+ model = MllamaForConditionalGeneration.from_pretrained(
12
+ model_id,
13
+ torch_dtype=torch.bfloat16,
14
+ device_map="auto",
15
+ )
16
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
def generate_model_response(image_file, user_query):
    """
    Answer a question about an uploaded image with the vision-language model.

    Parameters:
    - image_file: The uploaded image. May be a filesystem path, an open binary
      file object, or an already-loaded PIL image.
    - user_query: The user's question about the image.

    Returns:
    - str: The generated response from the model, formatted as HTML. When no
      image is supplied or generation fails, an HTML paragraph describing the
      problem is returned instead (no exception is propagated).
    """
    import html

    # Without a cap, generate() falls back to a very short default length and
    # the answer is cut off after a few words.
    max_new_tokens = 512

    if image_file is None:
        return "<p>Please upload an image.</p>"

    try:
        # Accept whatever the UI component hands over: a PIL image is used
        # directly, anything else (path / file object) is opened first.
        if isinstance(image_file, Image.Image):
            raw_image = image_file.convert("RGB")
        else:
            with Image.open(image_file) as opened_image:
                raw_image = opened_image.convert("RGB")

        # The image entry is only a placeholder; the chat template inserts the
        # image token itself, and the pixels are passed to the processor below.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_query or ""},
                ],
            }
        ]
        prompt = processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )

        # Keyword arguments avoid mixing up image and text. The template
        # already contains the special tokens, so do not add them again.
        inputs = processor(
            images=raw_image,
            text=prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)

        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # generate() returns prompt + completion; keep only the completion.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = processor.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # The result is rendered by an HTML component: escape the model text
        # and keep its line breaks visible.
        safe_text = html.escape(generated_text.strip()).replace("\n", "<br>")
        return f"<p>{safe_text}</p>"

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        print(f"Error in generating response: {e}")
        return f"<p>An error occurred: {html.escape(str(e))}</p>"
 
63
# Gradio UI: one image upload plus a free-text question, answered as HTML.
iface = gr.Interface(
    fn=generate_model_response,
    inputs=[
        # "filepath" hands the handler a path to the uploaded file, which it
        # opens with PIL; "file" is not a supported value for `type`.
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(label="Enter your question", placeholder="How many calories are in this food?"),
    ],
    outputs=gr.HTML(label="Response from Model"),
)

iface.launch(share=True)