# app.py -- "Create app.py" (commit 88d793f, verified), author: ethux, 2.24 kB.
# (File-viewer link labels "raw" / "history blame" folded into this header.)
import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from pathlib import Path
import base64
import spaces
import os
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Local directory (under the user's home) that receives the checkpoint files.
models_path = Path.home() / 'pixtral' / 'Pixtral'
models_path.mkdir(parents=True, exist_ok=True)

# Fetch only the files matched by allow_patterns from the Hub repo.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

# Module-level tokenizer and model, created once at import time and reused
# by every call to run_inference.
tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)
def image_to_base64(image_path):
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type is guessed from the file extension so that PNG, WebP, etc.
    are labelled correctly. When the extension is unknown or does not map to
    an ``image/*`` type, ``image/jpeg`` is used (the value that was previously
    hard-coded for every file).

    Args:
        image_path: Path of the image file to encode.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes  # function-scope import keeps this block self-contained

    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_string}"
@spaces.GPU
def run_inference(image_url, prompt):
    """Answer ``prompt`` about the uploaded image with the loaded model.

    Args:
        image_url: Filesystem path of the uploaded image (the UI passes the
            value of a ``gr.Image(type="filepath")`` component). Falsy when
            nothing has been uploaded.
        prompt: The user's text prompt.

    Returns:
        A one-element list ``[[prompt, result]]`` where ``result`` is the
        decoded model output; this is the value sent to the chatbot output.

    Raises:
        gr.Error: If no image was provided. Previously this surfaced as an
            opaque ``TypeError`` from ``open(None)``.
    """
    if not image_url:
        raise gr.Error("Please upload an image before submitting.")

    # Named so the local does not shadow the imported ``base64`` module.
    image_data_url = image_to_base64(image_url)

    user_message = UserMessage(
        content=[ImageURLChunk(image_url=image_data_url), TextChunk(text=prompt)]
    )
    completion_request = ChatCompletionRequest(messages=[user_message])
    encoded = tokenizer.encode_chat_completion(completion_request)

    # generate() takes batches: one token list and one image list per sample.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=512,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.decode(out_tokens[0])
    return [[prompt, result]]
# Build the UI: an image input and a chatbot output, a prompt textbox and a
# submit button, all wired to run_inference.
with gr.Blocks() as demo:
    # NOTE(review): the original indentation was lost; the image and chatbot
    # are assumed to share this Row (the chatbot's scale=2 suggests a
    # side-by-side layout) -- confirm against the intended layout.
    with gr.Row():
        image_box = gr.Image(type="filepath")
        chatbot = gr.Chatbot(
            scale=2,
            height=750,
        )
    text_box = gr.Textbox(
        placeholder="Enter text and press enter, or upload an image",
        container=False,
    )
    btn = gr.Button("Submit")
    clicked = btn.click(
        run_inference,
        [image_box, text_box],
        chatbot,
    )
    # The placeholder tells users to "press enter", so pressing Enter in the
    # textbox must trigger the same handler as the Submit button.
    text_box.submit(
        run_inference,
        [image_box, text_box],
        chatbot,
    )

demo.queue().launch()