shb777 committed · Commit 4371bd7 · 1 parent: dc4de1a

Fix chat history and sending image with every message

Files changed (1): app.py (+82 -20)
app.py CHANGED

@@ -1,6 +1,7 @@
 import spaces
 import random
 import torch
+import hashlib
 import gradio as gr
 from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
 
@@ -8,27 +9,28 @@ model_id = "ibm-granite/granite-vision-3.1-2b-preview"
 processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
 model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
-def get_text_from_content(content):
-    texts = []
-    for item in content:
-        if item["type"] == "text":
-            texts.append(item["text"])
-        elif item["type"] == "image":
-            texts.append("<image>")
-    return " ".join(texts)
+SYSTEM_PROMPT = (
+    "A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions."
+)
 
 @spaces.GPU
 def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
-    if conversation is None:
-        conversation = []
-
+    if conversation is None or conversation == []:
+        conversation = [{
+            "role": "system",
+            "content": [{"type": "text", "text": SYSTEM_PROMPT}]
+        }]
+
     user_content = []
+
     if image is not None:
         if image.width > 512 or image.height > 512:
             image.thumbnail((512, 512))
         user_content.append({"type": "image", "image": image})
     if text and text.strip():
         user_content.append({"type": "text", "text": text.strip()})
+
     if not user_content:
         return conversation_display(conversation), conversation
 
@@ -37,6 +39,9 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
         "content": user_content
     })
 
+    conversation = preprocess_conversation(conversation)
+
+    # Generate input prompt using the chat template.
     inputs = processor.apply_chat_template(
         conversation,
         add_generation_prompt=True,
@@ -59,29 +64,87 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
         generation_kwargs["do_sample"] = True
 
     output = model.generate(**inputs, **generation_kwargs)
-    assistant_response = processor.decode(output[0], skip_special_tokens=True)
+    raw_response = processor.decode(output[0], skip_special_tokens=True)
+    assistant_text = extract_answer(raw_response)
 
+    # Append the assistant's answer.
     conversation.append({
         "role": "assistant",
-        "content": [{"type": "text", "text": assistant_response.strip()}]
+        "content": [{"type": "text", "text": assistant_text}]
     })
 
     return conversation_display(conversation), conversation
 
+def extract_answer(response):
+    if "<|assistant|>" in response:
+        return response.split("<|assistant|>")[-1].strip()
+    return response.strip()
+
+def compute_image_hash(image):
+    image = image.convert("RGB")
+    image_bytes = image.tobytes()
+    return hashlib.md5(image_bytes).hexdigest()
+
+def preprocess_conversation(conversation):
+    # Find the last sent image in previous user messages (excluding the latest message).
+    last_image_hash = None
+    for msg in reversed(conversation[:-1]):
+        if msg.get("role") == "user":
+            for item in msg.get("content", []):
+                if item.get("type") == "image" and item.get("image") is not None:
+                    try:
+                        last_image_hash = compute_image_hash(item["image"])
+                        break
+                    except Exception:
+                        continue
+            if last_image_hash is not None:
+                break
+
+    # Process the latest user message.
+    latest_msg = conversation[-1]
+    if latest_msg.get("role") == "user":
+        new_content = []
+        for item in latest_msg.get("content", []):
+            if item.get("type") == "image" and item.get("image") is not None:
+                try:
+                    current_hash = compute_image_hash(item["image"])
+                except Exception:
+                    current_hash = None
+                # Remove the image if it matches the last sent image.
+                if last_image_hash is not None and current_hash is not None and current_hash == last_image_hash:
+                    continue
+                else:
+                    new_content.append(item)
+            else:
+                new_content.append(item)
+        latest_msg["content"] = new_content
+
+    return conversation
+
 def conversation_display(conversation):
     chat_history = []
     for msg in conversation:
         if msg["role"] == "user":
-            user_text = get_text_from_content(msg["content"])
-        elif msg["role"] == "assistant":
-            assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
-            chat_history.append({"role": "user", "content": user_text})
-            chat_history.append({"role": "assistant", "content": assistant_text})
+            texts = []
+            for item in msg["content"]:
+                if item["type"] == "image":
+                    texts.append("<image>")
+                elif item["type"] == "text":
+                    texts.append(item["text"])
+            chat_history.append({
+                "role": "user",
+                "content": "\n".join(texts)
+            })
+        else:
+            chat_history.append({
+                "role": msg["role"],
+                "content": msg["content"][0]["text"]
+            })
     return chat_history
 
 def clear_chat():
     return [], [], "", None
-
+
 with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
     gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
 
@@ -101,7 +164,6 @@ with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
     send_button = gr.Button("Chat")
     clear_button = gr.Button("Clear Chat")
 
-
    state = gr.State([])
 
     send_button.click(
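Two things drive the history fix: a fresh conversation is now seeded with an explicit system turn, and the full conversation list rides along in `gr.State`, so each call sees every prior turn. For orientation, here is a minimal sketch of the message structure `processor.apply_chat_template` consumes, reusing the app's own names; the `tokenize`, `return_dict`, and `return_tensors` kwargs are assumptions, since the diff truncates the actual call:

```python
# Sketch of the conversation structure fed to the chat template.
# SYSTEM_PROMPT, processor, and model come from app.py; the kwargs
# below are assumed, as the diff cuts the call off after
# add_generation_prompt=True.
conversation = [
    {"role": "system",
     "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
    {"role": "user",
     "content": [{"type": "image", "image": pil_image},  # PIL image, already thumbnailed
                 {"type": "text", "text": "What is in this image?"}]},
]
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,  # append the assistant cue the model completes
    tokenize=True,               # assumption: tensors, not a prompt string
    return_dict=True,
    return_tensors="pt",
).to(model.device)
```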
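The image dedup works on pixel content, not object identity: Gradio hands `chat_inference` a fresh PIL object on every send, so an `is` or `==` check could not detect a re-sent image. A self-contained check of the property `compute_image_hash` relies on (the function body is copied from the diff; the test images are illustrative):

```python
import hashlib

from PIL import Image

def compute_image_hash(image):
    # Copied from the commit: fingerprint the raw RGB pixel bytes.
    image = image.convert("RGB")
    image_bytes = image.tobytes()
    return hashlib.md5(image_bytes).hexdigest()

a = Image.new("RGB", (64, 64), "red")
b = Image.new("RGB", (64, 64), "red")   # same pixels, distinct object
c = Image.new("RGB", (64, 64), "blue")

assert a is not b
assert compute_image_hash(a) == compute_image_hash(b)
assert compute_image_hash(a) != compute_image_hash(c)
```

MD5 serves as a content fingerprint here, not a security measure. Because `chat_inference` thumbnails images to at most 512 pixels before storing them, both sides of the comparison hash the resized pixels, which keeps the check consistent across turns.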
114
+ if last_image_hash is not None and current_hash is not None and current_hash == last_image_hash:
115
+ continue
116
+ else:
117
+ new_content.append(item)
118
+ else:
119
+ new_content.append(item)
120
+ latest_msg["content"] = new_content
121
+
122
+ return conversation
123
+
124
  def conversation_display(conversation):
125
  chat_history = []
126
  for msg in conversation:
127
  if msg["role"] == "user":
128
+ texts = []
129
+ for item in msg["content"]:
130
+ if item["type"] == "image":
131
+ texts.append("<image>")
132
+ elif item["type"] == "text":
133
+ texts.append(item["text"])
134
+ chat_history.append({
135
+ "role": "user",
136
+ "content": "\n".join(texts)
137
+ })
138
+ else:
139
+ chat_history.append({
140
+ "role": msg["role"],
141
+ "content": msg["content"][0]["text"]
142
+ })
143
  return chat_history
144
 
145
  def clear_chat():
146
  return [], [], "", None
147
+
148
  with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
149
  gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
150
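The old `conversation_display` only emitted history when it saw an assistant turn, pairing it with the last user text it had seen; that positional pairing is what broke the chat history. The rewrite walks every message and flattens each one independently. A dummy-conversation run of the same flattening logic:

```python
# Sketch: flattening a stored conversation into Gradio "messages"
# dicts, using the same logic as the new conversation_display.
conversation = [
    {"role": "system",
     "content": [{"type": "text", "text": "You are helpful."}]},
    {"role": "user",
     "content": [{"type": "image", "image": None},
                 {"type": "text", "text": "Describe this."}]},
    {"role": "assistant",
     "content": [{"type": "text", "text": "A red square."}]},
]

chat_history = []
for msg in conversation:
    if msg["role"] == "user":
        texts = ["<image>" if item["type"] == "image" else item["text"]
                 for item in msg["content"]]
        chat_history.append({"role": "user", "content": "\n".join(texts)})
    else:
        chat_history.append({"role": msg["role"],
                             "content": msg["content"][0]["text"]})

print(chat_history)
# [{'role': 'system', 'content': 'You are helpful.'},
#  {'role': 'user', 'content': '<image>\nDescribe this.'},
#  {'role': 'assistant', 'content': 'A red square.'}]
```

Note that the seeded system turn also takes the `else` branch and lands in the history; if the app's Chatbot is a messages-format component that accepts only user and assistant roles, filtering out the system entry may be necessary (the component definition is outside the shown hunks).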
 
 
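The diff cuts off at `send_button.click(`. For orientation, here is a plausible wiring consistent with the function signatures above; every component name (`image_input`, `text_input`, `chatbot`, and the sampling sliders) is a hypothetical stand-in, since the definitions are outside the shown hunks:

```python
# Hypothetical wiring sketch; component names are assumed, not taken
# from the diff. chat_inference returns (chat_history, conversation),
# and clear_chat returns ([], [], "", None), i.e. four outputs.
send_button.click(
    chat_inference,
    inputs=[image_input, text_input, temperature, top_p, top_k,
            max_tokens, state],
    outputs=[chatbot, state],
)
clear_button.click(
    clear_chat,
    inputs=None,
    outputs=[chatbot, state, text_input, image_input],
)
```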