erndgn committed (verified)
Commit 14deb22 · 1 Parent(s): 07f33e0

Create app.py

Files changed (1):
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
+ import spaces
+
+ import os
+ import re
+ from io import BytesIO
+
+ import gradio as gr
+ import requests
+ import torch
+ from PIL import Image
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IM_END_TOKEN,
+     IMAGE_PLACEHOLDER,
+ )
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+ )
+ from conversation import Conversation, SeparatorStyle
+
+ model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
+
+ # Load tokenizer, model and image processor once at import time.
+ disable_torch_init()
+ model_name = get_model_name_from_path(model_id)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_id, None, model_name
+ )
+
+ def load_image(image_file):
+     """Load an image from an http(s) URL or a local path and convert it to RGB."""
+     if image_file.startswith("http://") or image_file.startswith("https://"):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+     elif os.path.exists(image_file):
+         image = Image.open(image_file).convert("RGB")
+     else:
+         raise FileNotFoundError(f"Image file {image_file} not found.")
+     return image
+
+ def infer_single_image(model_id, image_file, prompt):
+     # Splice the image token(s) into the user prompt where the model expects them.
+     image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+     if IMAGE_PLACEHOLDER in prompt:
+         if model.config.mm_use_im_start_end:
+             prompt = re.sub(IMAGE_PLACEHOLDER, image_token_se, prompt)
+         else:
+             prompt = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, prompt)
+     else:
+         if model.config.mm_use_im_start_end:
+             prompt = image_token_se + "\n" + prompt
+         else:
+             prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+
+     # Llama-3 chat template with a Turkish system prompt, in English:
+     # "You are an AI assistant. The user will give you a task. Your goal is to
+     # complete the task as faithfully as possible. Think step by step while
+     # performing the task and justify your steps."
+     conv = Conversation(
+         system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
+         roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+         version="llama3",
+         messages=[],
+         offset=0,
+         sep_style=SeparatorStyle.MPT,
+         sep="<|eot_id|>",
+     )
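+     # With SeparatorStyle.MPT, get_prompt() should render approximately
+     # (a sketch based on LLaVA's conversation code; exact whitespace may differ):
+     #   {system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n
+     #   <image>\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n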
+     conv.append_message(conv.roles[0], prompt)
+     conv.append_message(conv.roles[1], None)
+     full_prompt = conv.get_prompt()
+
+     print("full prompt: ", full_prompt)
+
+     image = load_image(image_file)
+     image_tensor = process_images(
+         [image],
+         image_processor,
+         model.config
+     ).to(model.device, dtype=torch.float16)
+
+     # Tokenize the prompt, mapping image tokens to IMAGE_TOKEN_INDEX.
+     input_ids = (
+         tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+         .unsqueeze(0)
+         .cuda()
+     )
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             input_ids,
+             images=image_tensor,
+             image_sizes=[image.size],
+             do_sample=False,  # greedy decoding
+             max_new_tokens=512,
+             use_cache=True,
+         )
+
+     output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+     return output
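+
+ # A minimal direct call, e.g. as a smoke test outside the Gradio UI
+ # (hypothetical local image path; the prompt is Turkish for "What is on the flower?"):
+ #     infer_single_image(model_id, "./bee.jpg", "Çiçeğin üzerinde ne var?")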
+
+ @spaces.GPU
+ def bot_streaming(message, history):
+     print(message)
+     # Take the most recently uploaded image, falling back to one from history.
+     image = None
+     if message["files"]:
+         if isinstance(message["files"][-1], dict):
+             image = message["files"][-1]["path"]
+         else:
+             image = message["files"][-1]
+     else:
+         for hist in history:
+             if isinstance(hist[0], tuple):
+                 image = hist[0][0]
+     if image is None:
+         raise gr.Error("You need to upload an image for LLaVA to work.")
+
+     prompt = message['text']
+
+     # Not token-by-token streaming: the full answer is yielded in one chunk.
+     result = infer_single_image(model_id, image, prompt)
+     yield result
+
+ chatbot = gr.Chatbot(scale=1)
+ chat_input = gr.MultimodalTextbox(
+     interactive=True,
+     file_types=["image"],
+     placeholder="Enter message or upload file...",
+     show_label=False,
+ )
+
+ with gr.Blocks(fill_height=True) as demo:
+     gr.ChatInterface(
+         fn=bot_streaming,
+         title="LLaVA Llama-3-8B",
+         # Example prompts (Turkish): "What is on the flower?" / "How is this dessert made?"
+         examples=[{"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
+                   {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]}],
+         description="",
+         stop_btn="Stop Generation",
+         multimodal=True,
+         textbox=chat_input,
+         chatbot=chatbot,
+     )
+
+ demo.queue(api_open=False)
+ demo.launch(show_api=False, share=False)
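
To try the app outside the Space, a local run along these lines should work, assuming the llava package (with its conversation module), requests, and a CUDA GPU are available, and that the spaces package is installed (its GPU decorator is designed to be a no-op off-Spaces): run "python app.py", then open the local URL Gradio prints (http://127.0.0.1:7860 by default).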