import gradio as gr
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
from pathlib import Path
from PIL import Image
import base64
import spaces
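# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces:
# it allocates a GPU only while the decorated function runs.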

# Download and prepare the model
mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
mistral_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(repo_id="mistralai/Pixtral-12B-2409", 
                  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], 
                  local_dir=mistral_models_path)
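# (Only params.json, consolidated.safetensors and tekken.json are needed for
# local inference, so allow_patterns skips the rest of the repo.)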


# Load the tokenizer and the model
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)
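# Note: this loads the full 12B-parameter model; expect to need roughly 24 GB+
# of memory for the bf16 weights alone.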

# Helper: encode an image file as a base64 string
def image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return encoded_string

# Inference
@spaces.GPU
def mistral_inference(prompt, image_url=None, image_file=None):
    if image_file is not None:
        # An uploaded file: ImageChunk expects a PIL image, not a base64 string
        image_chunk = ImageChunk(image=Image.open(image_file))
    else:
        # A remote image URL
        image_chunk = ImageURLChunk(image_url=image_url)

    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content=[image_chunk, TextChunk(text=prompt)])]
    )
    
    encoded = tokenizer.encode_chat_completion(completion_request)
    images = encoded.images
    tokens = encoded.tokens

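    # Sample up to 1024 new tokens at a low temperature; eos_id stops generation
    # at the end-of-sequence token.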
    out_tokens, _ = generate([tokens], model, images=[images], max_tokens=1024, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
    result = tokenizer.decode(out_tokens[0])
    
    return result

# UI labels per language
def get_labels(language):
    labels = {
        'en': {
            'title': "Pixtral Model Image Description",
            'text_prompt': "Text Prompt",
            'image_url': "Image URL (or leave blank if uploading an image)",
            'image_upload': "Upload Image",
            'output': "Model Output",
            'image_display': "Input Image",
            'submit': "Run Inference"
        },
        'zh': {
            'title': "Pixtral模型图像描述",
            'text_prompt': "文本提示",
            'image_url': "图片网址 (如果上传图片,请留空)",
            'image_upload': "上传图片",
            'output': "模型输出",
            'image_display': "输入图片",
            'submit': "运行推理"
        },
        'jp': {
            'title': "Pixtralモデルによる画像説明生成",
            'text_prompt': "テキストプロンプト",
            'image_url': "画像URL(画像をアップロードする場合は空白)",
            'image_upload': "画像をアップロード",
            'output': "モデルの出力結果",
            'image_display': "入力された画像",
            'submit': "推論を実行"
        }
    }
    return labels[language]

# Gradio interface
def process_input(text, image_url, image_file):
    if image_file is not None:
        result = mistral_inference(text, image_file=image_file)
        image_display = f'<img src="data:image/png;base64,{image_to_base64(image_file)}" alt="Input Image" width="300">'
    elif image_url:
        result = mistral_inference(text, image_url=image_url)
        image_display = f'<img src="{image_url}" alt="Input Image" width="300">'
    else:
        # Neither an uploaded file nor a URL was provided
        return "Please provide an image URL or upload an image.", ""

    return result, image_display

def update_ui(language):
    labels = get_labels(language)
    # gr.update changes only the labels (and the title/button text), not the component values
    return (gr.update(value=f"## {labels['title']}"), gr.update(label=labels['text_prompt']),
            gr.update(label=labels['image_url']), gr.update(label=labels['image_upload']),
            gr.update(label=labels['output']), gr.update(label=labels['image_display']), gr.update(value=labels['submit']))

with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
    
    title = gr.Markdown("## Pixtral Model Image Description")
    with gr.Row():
        text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
    
    image_url_input = gr.Textbox(label="Image URL (or leave blank if uploading an image)", placeholder="e.g. https://example.com/image.png")
    image_file_input = gr.Image(label="Upload Image", type="filepath")  # leave empty to use the URL field instead

    result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20)  # roughly 500 px of output area
    image_output = gr.HTML(label="Input Image")  # renders the input image (URL or uploaded file)

    submit_button = gr.Button("Run Inference")

    submit_button.click(process_input, inputs=[text_input, image_url_input, image_file_input], outputs=[result_output, image_output])

    # Update the UI labels when the language selection changes
    language_choice.change(
        fn=update_ui, 
        inputs=[language_choice], 
        outputs=[title, text_input, image_url_input, image_file_input, result_output, image_output, submit_button]
    )

demo.launch()