# Pixtral-12B / app.py
# Hugging Face Space file view: aixsatoshi - "Update app.py" - commit aaa844d (verified) - 5.33 kB
import base64
import html
import mimetypes
from pathlib import Path

import gradio as gr
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
import spaces
# Download and prepare the model: only the three files named in
# allow_patterns are fetched into the local model directory.
mistral_models_path = Path.home() / 'mistral_models' / 'Pixtral'
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistralai/Pixtral-12B-2409",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)

# Load the tokenizer and the model from the downloaded files.
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)
# Helper that converts an image file to base64
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    Args:
        image_path: Path (``str`` or path-like) of the file to read.

    Returns:
        The file's bytes, base64-encoded and decoded to a ``str``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    # Base64 output is pure ASCII, so the decoded text equals the UTF-8 decode.
    return base64.b64encode(raw_bytes).decode("ascii")
# Inference
@spaces.GPU
def mistral_inference(prompt, image_url=None, image_file=None):
    """Generate a text answer from the model for a prompt about one image.

    Args:
        prompt: Instruction or question text sent together with the image.
        image_url: URL of the image; used only when ``image_file`` is None.
        image_file: Path of a local image file; takes precedence over
            ``image_url`` when given.

    Returns:
        The decoded model output as a string.

    Raises:
        ValueError: If neither ``image_file`` nor a non-blank ``image_url``
            is supplied.
    """
    if isinstance(image_url, str):
        image_url = image_url.strip()

    if image_file is not None:
        # An uploaded file is sent as a base64 data URL through ImageURLChunk.
        # NOTE(review): mistral_common's ImageChunk appears to take its payload
        # via the ``image`` field, not ``image_base64`` - confirm against the
        # installed mistral_common version.
        mime_type = mimetypes.guess_type(str(image_file))[0] or ""
        if not mime_type.startswith("image/"):
            mime_type = "image/png"
        data_url = f"data:{mime_type};base64,{image_to_base64(image_file)}"
        image_chunk = ImageURLChunk(image_url=data_url)
    elif image_url:
        # An image URL was supplied
        image_chunk = ImageURLChunk(image_url=image_url)
    else:
        raise ValueError("Provide either an image file or an image URL.")

    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content=[image_chunk, TextChunk(text=prompt)])]
    )
    encoded = tokenizer.encode_chat_completion(completion_request)
    images = encoded.images
    tokens = encoded.tokens

    # A single request is wrapped in one-element lists for generate().
    out_tokens, _ = generate(
        [tokens],
        model,
        images=[images],
        max_tokens=1024,
        temperature=0.35,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.decode(out_tokens[0])
    return result
# UI labels per language, keyed by the language codes offered in the UI
_UI_LABELS = {
    'en': {
        'title': "Pixtral Model Image Description",
        'text_prompt': "Text Prompt",
        'image_url': "Image URL (or leave blank if uploading an image)",
        'image_upload': "Upload Image",
        'output': "Model Output",
        'image_display': "Input Image",
        'submit': "Run Inference",
    },
    'zh': {
        'title': "Pixtral模型图像描述",
        'text_prompt': "文本提示",
        'image_url': "图片网址 (如果上传图片,请留空)",
        'image_upload': "上传图片",
        'output': "模型输出",
        'image_display': "输入图片",
        'submit': "运行推理",
    },
    'jp': {
        'title': "Pixtralモデルによる画像説明生成",
        'text_prompt': "テキストプロンプト",
        'image_url': "画像URL(画像をアップロードする場合は空白)",
        'image_upload': "画像をアップロード",
        'output': "モデルの出力結果",
        'image_display': "入力された画像",
        'submit': "推論を実行",
    },
}

# Language used when the requested one has no entry in _UI_LABELS
_DEFAULT_LANGUAGE = 'en'


def get_labels(language):
    """Return the UI label strings for ``language``.

    Args:
        language: Language code; ``'en'``, ``'zh'`` and ``'jp'`` have entries.
            Any other value (including ``None``) falls back to English
            instead of raising ``KeyError``.

    Returns:
        A new dict with the keys ``title``, ``text_prompt``, ``image_url``,
        ``image_upload``, ``output``, ``image_display`` and ``submit``.
    """
    selected = _UI_LABELS.get(language, _UI_LABELS[_DEFAULT_LANGUAGE])
    # Return a copy so callers cannot mutate the shared table.
    return dict(selected)
# Gradio interface
def process_input(text, image_url, image_file):
    """Run inference for the submitted form and build an image preview.

    Args:
        text: Prompt text from the UI.
        image_url: Image URL from the UI; used only when no file is uploaded.
        image_file: Path of the uploaded image, or None when nothing was
            uploaded. Takes precedence over ``image_url``.

    Returns:
        A ``(result, image_display)`` tuple: the model output text and an
        HTML ``<img>`` snippet showing the input image.

    Raises:
        gr.Error: If neither an uploaded image nor a non-blank URL is given.
    """
    if image_file is not None:
        result = mistral_inference(text, image_file=image_file)
        # Label the data URL with the file's real type instead of always png.
        mime_type = mimetypes.guess_type(str(image_file))[0] or "image/png"
        image_src = f"data:{mime_type};base64,{image_to_base64(image_file)}"
    elif image_url and image_url.strip():
        image_url = image_url.strip()
        result = mistral_inference(text, image_url=image_url)
        # The URL is user input: escape it so it cannot break out of the
        # src attribute and inject markup into the page.
        image_src = html.escape(image_url, quote=True)
    else:
        raise gr.Error("Please upload an image or enter an image URL.")

    image_display = f'<img src="{image_src}" alt="Input Image" width="300">'
    return result, image_display
def update_ui(language):
    """Build the component updates that relabel the UI for ``language``.

    Returning bare strings would be applied as component *values*
    (overwriting the user's text and feeding a label to the image input),
    so each entry is a ``gr.update`` that targets the intended property.

    Args:
        language: Language code passed to ``get_labels``.

    Returns:
        A 7-tuple of updates, in order: title markdown, prompt textbox,
        image-URL textbox, image upload, output textbox, image preview and
        submit button.
    """
    labels = get_labels(language)
    return (
        # Markdown has no label; keep the heading level used at start-up.
        gr.update(value=f"## {labels['title']}"),
        gr.update(label=labels['text_prompt']),
        gr.update(label=labels['image_url']),
        gr.update(label=labels['image_upload']),
        gr.update(label=labels['output']),
        gr.update(label=labels['image_display']),
        # A button's caption is its value.
        gr.update(value=labels['submit']),
    )
# Layout and event wiring for the demo UI
with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')

    title = gr.Markdown("## Pixtral Model Image Description")

    with gr.Row():
        text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
        image_url_input = gr.Textbox(label="Image URL (or leave blank if uploading an image)", placeholder="e.g. https://example.com/image.png")
        # The removed ``optional=True`` keyword is dropped: current Gradio
        # rejects it, and the upload may be left empty without it.
        image_file_input = gr.Image(label="Upload Image", type="filepath")

    # Line counts chosen to approximate a height of about 500 pixels
    result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20)
    # Area that shows the input image
    image_output = gr.HTML(label="Input Image")

    submit_button = gr.Button("Run Inference")

    submit_button.click(process_input, inputs=[text_input, image_url_input, image_file_input], outputs=[result_output, image_output])

    # Update the UI labels when the language changes
    language_choice.change(
        fn=update_ui,
        inputs=[language_choice],
        outputs=[title, text_input, image_url_input, image_file_input, result_output, image_output, submit_button]
    )

demo.launch()