File size: 4,022 Bytes
9b25f0e
b4fa047
 
 
 
 
 
 
 
 
 
 
 
 
953563e
b4fa047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d3fee
b4fa047
 
 
 
13669f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4fa047
 
929aae5
b4fa047
13669f6
 
 
 
b4fa047
13669f6
b4fa047
13669f6
b4fa047
13669f6
 
b4fa047
13669f6
929aae5
3f86c47
13669f6
3f86c47
 
b4fa047
13669f6
 
 
 
 
 
 
b4fa047
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
from pathlib import Path

# Download and prepare the model.
# Local directory for the checkpoint files: ~/mistral_models/Pixtral
mistral_models_path = Path.home() / 'mistral_models' / 'Pixtral'
mistral_models_path.mkdir(exist_ok=True, parents=True)

# Restrict the download to the listed files (params, weights, and the
# tekken tokenizer file loaded below) and place them in the local directory.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    local_dir=mistral_models_path,
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
)

# Load the tokenizer and the model from the downloaded files.
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)

# Inference. `spaces.GPU` (Hugging Face ZeroGPU) attaches a GPU while the
# decorated function runs.
@spaces.GPU
def mistral_inference(prompt, image_url, *, max_tokens=1024, temperature=0.35):
    """Generate a Pixtral completion for one image and one text prompt.

    Builds a single-turn chat request whose user message contains the image
    (referenced by URL) followed by the prompt text, encodes it with the
    module-level ``tokenizer``, runs ``generate`` with the module-level
    ``model`` on a batch of one request, and decodes the output.

    Args:
        prompt: Text instruction placed after the image in the message.
        image_url: Image URL, passed to ``ImageURLChunk``.
        max_tokens: Upper bound on the number of generated tokens
            (default 1024).
        temperature: Sampling temperature passed to ``generate``
            (default 0.35).

    Returns:
        The decoded completion text.
    """
    request = ChatCompletionRequest(
        messages=[UserMessage(content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)])]
    )
    encoded = tokenizer.encode_chat_completion(request)

    # generate() takes batches: wrap this request's tokens and images in
    # one-element lists, and pass the tokenizer's EOS id so generation can
    # end before max_tokens is reached.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])

# UI label strings per display language.
def get_labels(language):
    """Return the UI label strings for *language*.

    Supported codes are ``'en'``, ``'zh'`` and ``'jp'``; any other code
    raises ``KeyError``. A new dict is built on every call, with the keys
    ``title``, ``text_prompt``, ``image_url``, ``output``,
    ``image_display`` and ``submit`` (in that order).
    """
    fields = ('title', 'text_prompt', 'image_url', 'output', 'image_display', 'submit')
    translations = {
        'en': (
            "Pixtral Model Image Description",
            "Text Prompt",
            "Image URL",
            "Model Output",
            "Input Image",
            "Run Inference",
        ),
        'zh': (
            "Pixtral模型图像描述",
            "文本提示",
            "图片网址",
            "模型输出",
            "输入图片",
            "运行推理",
        ),
        'jp': (
            "Pixtralモデルによる画像説明生成",
            "テキストプロンプト",
            "画像URL",
            "モデルの出力結果",
            "入力された画像",
            "推論を実行",
        ),
    }
    # Pair each field name with the selected language's string.
    return dict(zip(fields, translations[language]))

# Gradio interface
def process_input(text, image_url):
    """Handle one "Run Inference" click.

    Args:
        text: Prompt text from the prompt textbox.
        image_url: Image URL from the URL textbox; surrounding whitespace is
            stripped before use.

    Returns:
        A ``(result, html)`` tuple: the model's decoded output and an
        ``<img>`` tag previewing the input image.

    Raises:
        gr.Error: If no image URL was provided. This is checked before the
            ``@spaces.GPU``-decorated inference call, so an empty request
            never reaches the model and the user sees a readable message.
    """
    import html

    image_url = (image_url or "").strip()
    if not image_url:
        raise gr.Error("Please provide an image URL.")

    result = mistral_inference(text, image_url)
    # image_url is user input: escape it (quotes included) so it cannot break
    # out of the src attribute and inject markup or script into the page.
    safe_url = html.escape(image_url, quote=True)
    return result, f'<img src="{safe_url}" alt="Input Image" width="300">'

def update_ui(language):
    """Return Gradio updates that relabel the UI in *language*.

    The six updates map, in order, onto the outputs wired to
    ``language_choice.change``: title, text_input, image_input,
    result_output, image_output, submit_button.

    Returning a plain string to a Textbox/HTML output replaces its *value*
    (wiping the user's prompt, URL and model output), so those components
    are updated through ``label=`` instead. Only the Markdown title and the
    Button, whose visible text is their value, receive a new value.
    """
    labels = get_labels(language)
    return (
        # Keep the "## " heading markup the title is created with.
        gr.update(value=f"## {labels['title']}"),
        gr.update(label=labels['text_prompt']),
        gr.update(label=labels['image_url']),
        gr.update(label=labels['output']),
        gr.update(label=labels['image_display']),
        gr.update(value=labels['submit']),
    )

# Initial language of the UI. Its labels come from the same table that
# get_labels uses for later language switches, so the strings live in one place.
initial_language = 'en'
initial_labels = get_labels(initial_language)

with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value=initial_language)

    title = gr.Markdown(f"## {initial_labels['title']}")
    with gr.Row():
        text_input = gr.Textbox(label=initial_labels['text_prompt'], placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label=initial_labels['image_url'], placeholder="e.g. https://example.com/image.png")

    # lines/max_lines chosen to roughly match a 500px-tall output box.
    result_output = gr.Textbox(label=initial_labels['output'], lines=8, max_lines=20)
    # Shows the input image; receives the <img> HTML returned by process_input.
    image_output = gr.HTML(label=initial_labels['image_display'])

    submit_button = gr.Button(initial_labels['submit'])

    submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])

    # Update the UI labels whenever the selected language changes.
    language_choice.change(
        fn=update_ui,
        inputs=[language_choice],
        outputs=[title, text_input, image_input, result_output, image_output, submit_button],
    )

demo.launch()