aixsatoshi committed on
Commit
5abe1dc
·
verified ·
1 Parent(s): d5e19a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -16
app.py CHANGED
@@ -1,18 +1,19 @@
1
  import gradio as gr
2
- import spaces
3
  from mistral_inference.transformer import Transformer
4
  from mistral_inference.generate import generate
5
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
- from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
7
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
8
  from huggingface_hub import snapshot_download
9
  from pathlib import Path
 
 
10
 
11
  # モデルのダウンロードと準備
12
  mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
13
  mistral_models_path.mkdir(parents=True, exist_ok=True)
14
 
15
- snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
16
  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
17
  local_dir=mistral_models_path)
18
 
@@ -20,11 +21,24 @@ snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
20
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
21
  model = Transformer.from_folder(mistral_models_path)
22
 
 
 
 
 
 
 
23
  # 推論処理
24
  @spaces.GPU
25
- def mistral_inference(prompt, image_url):
 
 
 
 
 
 
 
26
  completion_request = ChatCompletionRequest(
27
- messages=[UserMessage(content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)])]
28
  )
29
 
30
  encoded = tokenizer.encode_chat_completion(completion_request)
@@ -42,7 +56,8 @@ def get_labels(language):
42
  'en': {
43
  'title': "Pixtral Model Image Description",
44
  'text_prompt': "Text Prompt",
45
- 'image_url': "Image URL",
 
46
  'output': "Model Output",
47
  'image_display': "Input Image",
48
  'submit': "Run Inference"
@@ -50,7 +65,8 @@ def get_labels(language):
50
  'zh': {
51
  'title': "Pixtral模型图像描述",
52
  'text_prompt': "文本提示",
53
- 'image_url': "图片网址",
 
54
  'output': "模型输出",
55
  'image_display': "输入图片",
56
  'submit': "运行推理"
@@ -58,7 +74,8 @@ def get_labels(language):
58
  'jp': {
59
  'title': "Pixtralモデルによる画像説明生成",
60
  'text_prompt': "テキストプロンプト",
61
- 'image_url': "画像URL",
 
62
  'output': "モデルの出力結果",
63
  'image_display': "入力された画像",
64
  'submit': "推論を実行"
@@ -67,13 +84,19 @@ def get_labels(language):
67
  return labels[language]
68
 
69
  # Gradioインターフェース
70
- def process_input(text, image_url):
71
- result = mistral_inference(text, image_url)
72
- return result, f'<img src="{image_url}" alt="Input Image" width="300">'
 
 
 
 
 
 
73
 
74
  def update_ui(language):
75
  labels = get_labels(language)
76
- return labels['title'], labels['text_prompt'], labels['image_url'], labels['output'], labels['image_display'], labels['submit']
77
 
78
  with gr.Blocks() as demo:
79
  language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
@@ -81,20 +104,22 @@ with gr.Blocks() as demo:
81
  title = gr.Markdown("## Pixtral Model Image Description")
82
  with gr.Row():
83
  text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
84
- image_input = gr.Textbox(label="Image URL", placeholder="e.g. https://example.com/image.png")
85
 
 
 
 
86
  result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20) # 高さ500ピクセルに相当するように調整
87
  image_output = gr.HTML(label="Input Image") # 入力画像URLを表示するための場所
88
 
89
  submit_button = gr.Button("Run Inference")
90
 
91
- submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])
92
 
93
  # 言語変更時にUIラベルを更新
94
  language_choice.change(
95
  fn=update_ui,
96
  inputs=[language_choice],
97
- outputs=[title, text_input, image_input, result_output, image_output, submit_button]
98
  )
99
 
100
- demo.launch()
 
1
  import gradio as gr
 
2
  from mistral_inference.transformer import Transformer
3
  from mistral_inference.generate import generate
4
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
5
+ from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, ImageChunk
6
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
7
  from huggingface_hub import snapshot_download
8
  from pathlib import Path
9
+ import base64
10
+ import spaces
11
 
12
# Model download and preparation: fetch the Pixtral checkpoint into a local
# cache directory, then load the tokenizer and transformer from it.
mistral_models_path = Path.home() / 'mistral_models' / 'Pixtral'
mistral_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
                  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
                  local_dir=mistral_models_path)

tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)
23
 
24
# Helper: encode an image file as a base64 string.
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as UTF-8 text."""
    raw = Path(image_path).read_bytes()
    return base64.b64encode(raw).decode('utf-8')
29
+
30
  # 推論処理
31
  @spaces.GPU
32
+ def mistral_inference(prompt, image_url=None, image_file=None):
33
+ if image_file is not None:
34
+ # 画像ファイルがアップロードされた場合
35
+ image_chunk = ImageChunk(image_base64=image_to_base64(image_file))
36
+ else:
37
+ # 画像URLが指定された場合
38
+ image_chunk = ImageURLChunk(image_url=image_url)
39
+
40
  completion_request = ChatCompletionRequest(
41
+ messages=[UserMessage(content=[image_chunk, TextChunk(text=prompt)])]
42
  )
43
 
44
  encoded = tokenizer.encode_chat_completion(completion_request)
 
56
  'en': {
57
  'title': "Pixtral Model Image Description",
58
  'text_prompt': "Text Prompt",
59
+ 'image_url': "Image URL (or leave blank if uploading an image)",
60
+ 'image_upload': "Upload Image",
61
  'output': "Model Output",
62
  'image_display': "Input Image",
63
  'submit': "Run Inference"
 
65
  'zh': {
66
  'title': "Pixtral模型图像描述",
67
  'text_prompt': "文本提示",
68
+ 'image_url': "图片网址 (如果上传图片,请留空)",
69
+ 'image_upload': "上传图片",
70
  'output': "模型输出",
71
  'image_display': "输入图片",
72
  'submit': "运行推理"
 
74
  'jp': {
75
  'title': "Pixtralモデルによる画像説明生成",
76
  'text_prompt': "テキストプロンプト",
77
+ 'image_url': "画像URL(画像をアップロードする場合は空白)",
78
+ 'image_upload': "画像をアップロード",
79
  'output': "モデルの出力結果",
80
  'image_display': "入力された画像",
81
  'submit': "推論を実行"
 
84
  return labels[language]
85
 
86
  # Gradioインターフェース
87
def process_input(text, image_url, image_file):
    """Run Pixtral inference on either an uploaded image file or an image URL.

    Returns the model output text plus an HTML <img> tag for display.
    """
    if image_file is None:
        # No upload: fall back to the URL the user typed in.
        inference_result = mistral_inference(text, image_url=image_url)
        html_src = image_url
    else:
        # An uploaded file takes priority; embed it as a data URI so the
        # browser can render it without a round-trip.
        # NOTE(review): the data URI hard-codes image/png — confirm uploads
        # are always PNG, or derive the MIME type from the file.
        inference_result = mistral_inference(text, image_file=image_file)
        html_src = f"data:image/png;base64,{image_to_base64(image_file)}"
    return inference_result, f'<img src="{html_src}" alt="Input Image" width="300">'
96
 
97
  def update_ui(language):
98
  labels = get_labels(language)
99
+ return labels['title'], labels['text_prompt'], labels['image_url'], labels['image_upload'], labels['output'], labels['image_display'], labels['submit']
100
 
101
  with gr.Blocks() as demo:
102
  language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
 
104
  title = gr.Markdown("## Pixtral Model Image Description")
105
  with gr.Row():
106
  text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
 
107
 
108
+ image_url_input = gr.Textbox(label="Image URL (or leave blank if uploading an image)", placeholder="e.g. https://example.com/image.png")
109
+ image_file_input = gr.Image(label="Upload Image", type="filepath", optional=True)
110
+
111
  result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20) # 高さ500ピクセルに相当するように調整
112
  image_output = gr.HTML(label="Input Image") # 入力画像URLを表示するための場所
113
 
114
  submit_button = gr.Button("Run Inference")
115
 
116
+ submit_button.click(process_input, inputs=[text_input, image_url_input, image_file_input], outputs=[result_output, image_output])
117
 
118
  # 言語変更時にUIラベルを更新
119
  language_choice.change(
120
  fn=update_ui,
121
  inputs=[language_choice],
122
+ outputs=[title, text_input, image_url_input, image_file_input, result_output, image_output, submit_button]
123
  )
124
 
125
+ demo.launch()