hiyouga commited on
Commit
f75fe7f
·
verified ·
1 Parent(s): 517540b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+ from typing import Dict
3
+
4
+ import gradio as gr
5
+ import spaces
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
9
+
10
+
11
+ TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"
12
+
13
+ DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"
14
+
15
+ CSS = """
16
+ .duplicate-button {
17
+ margin: auto !important;
18
+ color: white !important;
19
+ background: black !important;
20
+ border-radius: 100vh !important;
21
+ }
22
+ """
23
+
24
+
25
+ model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"
26
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
27
+ processor = AutoProcessor.from_pretrained(model_id)
28
+ model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
29
+
30
+
31
+ @spaces.GPU
32
+ def stream_chat(message: Dict[str, str], history: list):
33
+ # Turn 1:
34
+ # {'text': 'what is this', 'files': ['image-xxx.jpg']}
35
+ # []
36
+
37
+ # Turn 2:
38
+ # {'text': 'continue?', 'files': []}
39
+ # [[('image-xxx.jpg',), None], ['what is this', 'a image.']]
40
+
41
+ image_path = None
42
+ if len(message["files"]) != 0:
43
+ image_path = message["files"][0]
44
+
45
+ if len(history) != 0 and isinstance(history[0][0], tuple):
46
+ image_path = history[0][0][0]
47
+ history = history[1:]
48
+
49
+ if image_path is not None:
50
+ image = Image.open(image_path)
51
+ else:
52
+ image = Image.new("RGB", (100, 100), (255, 255, 255))
53
+
54
+ pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]
55
+
56
+ conversation = []
57
+ for prompt, answer in history:
58
+ conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
59
+
60
+ conversation.append({"role": "user", "content": message["text"]})
61
+
62
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
63
+ image_token_id = tokenizer.convert_tokens_to_ids("<image>")
64
+ image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
65
+ input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
66
+
67
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
68
+
69
+ generate_kwargs = dict(
70
+ input_ids=input_ids,
71
+ pixel_values=pixel_values,
72
+ streamer=streamer,
73
+ max_new_tokens=256,
74
+ )
75
+
76
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
77
+ t.start()
78
+
79
+ output = ""
80
+ for new_token in streamer:
81
+ output += new_token
82
+ yield output
83
+
84
+
85
+ chatbot = gr.Chatbot(height=450)
86
+
87
+ with gr.Blocks(css=CSS) as demo:
88
+ gr.HTML(TITLE)
89
+ gr.HTML(DESCRIPTION)
90
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
91
+ gr.ChatInterface(
92
+ fn=stream_chat,
93
+ multimodal=True,
94
+ chatbot=chatbot,
95
+ fill_height=True,
96
+ cache_examples=False,
97
+ )
98
+
99
+
100
+ if __name__ == "__main__":
101
+ demo.launch()