import spaces  # must be imported before torch for ZeroGPU Spaces

import os
import re
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from conversation import Conversation, SeparatorStyle

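# Load the Turkish LLaVA checkpoint once at startup so every request reuses
# the same weights.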
model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"

disable_torch_init()
model_name = get_model_name_from_path(model_id)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_id, None, model_name
)

def load_image(image_file):
    # Accept either a URL or a local path; always return an RGB PIL image.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        response.raise_for_status()  # fail early instead of decoding an error page
        image = Image.open(BytesIO(response.content)).convert("RGB")
    elif os.path.exists(image_file):
        image = Image.open(image_file).convert("RGB")
    else:
        raise FileNotFoundError(f"Image file {image_file} not found.")
    return image

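# Assemble a Llama-3-style prompt around the image token, run one greedy
# generation pass, and return the decoded answer. Note that the model_id
# argument is unused; the function relies on the module-level tokenizer/model.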
def infer_single_image(model_id, image_file, prompt):
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        # The prompt already marks where the image goes; swap in the real token.
        if model.config.mm_use_im_start_end:
            prompt = re.sub(IMAGE_PLACEHOLDER, image_token_se, prompt)
        else:
            prompt = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, prompt)
    else:
        # No placeholder: prepend the image token so the model sees the image first.
        if model.config.mm_use_im_start_end:
            prompt = image_token_se + "\n" + prompt
        else:
            prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt

    conv = Conversation(
        system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
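        # English gloss of the Turkish system prompt: "You are an AI assistant.
        # The user will give you a task. Your goal is to complete the task as
        # faithfully as possible. Think step by step while carrying it out and
        # justify your steps."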
        roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    full_prompt = conv.get_prompt()

    print("full prompt: ", full_prompt)

    image = load_image(image_file)
    image_tensor = process_images(
        [image],
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

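    # Greedy decoding (do_sample=False) keeps answers deterministic.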
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
        )

    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return output

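# Chat handler: @spaces.GPU requests a ZeroGPU slice for the duration of the call.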
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    # Prefer an image attached to the current message; otherwise fall back to
    # the most recent image found in the chat history.
    image = None
    if message["files"]:
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # gr.Error must be raised, not just constructed, to surface in the UI.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    prompt = message["text"]

    result = infer_single_image(model_id, image, prompt)
    yield result

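# Gradio UI: a multimodal chat box that accepts text plus an image upload.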
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="Turkish-LLaVA-v0.1",
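        # Turkish example prompts: "What is on the flower?" / "How is this dessert made?"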
        examples=[{"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
                  {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]}],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)