File size: 2,244 Bytes
88d793f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from pathlib import Path
import base64
import spaces
import os

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Directory under the user's home that receives the Pixtral checkpoint files.
models_path = Path.home() / "pixtral" / "Pixtral"
models_path.mkdir(parents=True, exist_ok=True)

# Download only the three files the app needs (model params, weights and the
# Tekken tokenizer) rather than the whole repository snapshot.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

# Build the tokenizer and the model once at import time; every request handled
# by run_inference reuses these module-level objects.
tokenizer_file = f"{models_path}/tekken.json"
tokenizer = MistralTokenizer.from_file(tokenizer_file)
model = Transformer.from_folder(models_path)

def image_to_base64(image_path):
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type in the URL is guessed from the file extension (e.g. ``.png``
    becomes ``image/png``). When no image type can be guessed, ``image/jpeg``
    is used, which matches the value this function previously hard-coded.

    Args:
        image_path: Path (``str`` or path-like) of the image file to encode.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes  # local import keeps the file's import block untouched

    mime_type, _ = mimetypes.guess_type(str(image_path))
    # Only trust the guess if it is actually an image type; otherwise fall
    # back to the historical default rather than emitting e.g. text/plain.
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(image_path, "rb") as img:
        encoded_string = base64.b64encode(img.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"

@spaces.GPU
def run_inference(image_url, prompt):
    """Generate a Pixtral reply for an optional image plus a text prompt.

    Args:
        image_url: Local file path of the uploaded image (the Gradio ``Image``
            component is configured with ``type="filepath"``), or ``None`` /
            empty when the user did not upload an image.
        prompt: The user's text prompt.

    Returns:
        A single-turn chat history ``[[prompt, reply]]`` for ``gr.Chatbot``.
    """
    content = []
    if image_url:
        # The tokenizer takes image URLs, so the local file is inlined as a
        # base64 data URL.
        content.append(ImageURLChunk(image_url=image_to_base64(image_url)))
    # Without an image, the request is plain text instead of crashing on
    # open(None).
    content.append(TextChunk(text=prompt))

    completion_request = ChatCompletionRequest(messages=[UserMessage(content=content)])
    encoded = tokenizer.encode_chat_completion(completion_request)

    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        # One image list per prompt; pass none at all for text-only requests.
        images=[encoded.images] if encoded.images else [],
        max_tokens=512,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.decode(out_tokens[0])
    return [[prompt, result]]

# UI: image upload and chat transcript side by side, prompt box and a submit
# button below; both submit paths run the same inference handler.
with gr.Blocks() as demo:
    with gr.Row():
        image_box = gr.Image(type="filepath")
        chatbot = gr.Chatbot(
            scale=2,
            height=750,
        )
    text_box = gr.Textbox(
        placeholder="Enter text and press enter, or upload an image",
        container=False,
    )

    btn = gr.Button("Submit")
    clicked = btn.click(run_inference,
                        [image_box, text_box],
                        chatbot,
                        )
    # The placeholder promises "press enter", so pressing Enter in the textbox
    # triggers the same handler, inputs and output as the button.
    text_box.submit(run_inference,
                    [image_box, text_box],
                    chatbot,
                    )

demo.queue().launch()