Holy-fox commited on
Commit
8fd92ee
·
verified ·
1 Parent(s): 56e60b7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +49 -141
README.md CHANGED
@@ -24,182 +24,90 @@ Gemma 3ファミリーと同様に、テキスト入力と画像入力の両方
24
  pip install -U transformers accelerate Pillow requests torch
25
  ```
26
 
27
- また、vLLMを使用する場合は、vLLMをインストールしてください。
28
-
29
- ```bash
30
- pip install vllm
31
- ```
32
-
33
- ### vLLMを使用した推論
34
-
35
- vLLMを使用することで、高速な推論が可能です。
36
 
37
  ```python
38
- from vllm import LLM, SamplingParams
 
 
 
39
 
40
- # モデルID
41
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.2"
42
 
43
- # サンプリングパラメータ (必要に応じて調整)
44
- sampling_params = SamplingParams(temperature=0.1, top_p=0.9, max_tokens=512)
 
45
 
46
- # LLMインスタンスの作成
47
- llm = LLM(model=model_id, trust_remote_code=True) # Gemma 3にはリモートコード実行が必要な場合があります
48
 
49
- # プロンプトの準備 (Gemma 3のチャットテンプレート形式を推奨)
50
- # vLLMは通常、tokenizerからチャットテンプレートを自動適用します
51
- # 手動で適用する場合は tokenizer.apply_chat_template を使用します
52
  messages = [
53
- {"role": "system", "content": "あなたは親切なAIアシスタントです。"},
54
- {"role": "user", "content": "日本の首都はどこですか?その都市の有名な観光地を3つ教えてください。"}
 
 
 
 
 
 
 
 
 
55
  ]
56
 
57
- # Hugging Face tokenizerを使ってチャットテンプレートを適用
58
- from transformers import AutoTokenizer
59
- tokenizer = AutoTokenizer.from_pretrained(model_id)
60
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
 
62
- # 推論の実行
63
- outputs = llm.generate(prompt, sampling_params)
64
 
65
- # 結果の表示
66
- for output in outputs:
67
- prompt_disp = output.prompt
68
- generated_text = output.outputs[0].text
69
- print(f"Prompt: {prompt_disp!r}")
70
- print(f"Generated text: {generated_text!r}")
71
 
 
 
72
  ```
73
-
74
- ### Transformersを使用した推論 (テキストのみ)
75
-
76
- テキスト入力のみで推論を行う場合の基本的なコードです。System PromptとUser Promptを使用します。
77
 
78
  ```python
 
79
  import torch
80
- from transformers import AutoTokenizer, AutoModelForCausalLM
81
 
82
- # モデルID
83
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.2"
84
 
85
- # トークナイザーとモデルのロード
86
- tokenizer = AutoTokenizer.from_pretrained(model_id)
87
- # Gemma 3 4B はメモリ要求が高いため、bf16を使用し、可能であれば複数GPUに分散します
88
- model = AutoModelForCausalLM.from_pretrained(
89
- model_id,
90
- device_map="auto",
91
- torch_dtype=torch.bfloat16,
92
- trust_remote_code=True # Gemma 3にはリモートコード実行が必要な場合があります
93
- )
94
- model.eval()
95
-
96
- # チャットメッセージの準備
97
- messages = [
98
- {"role": "system", "content": "あなたは知識豊富で、質問に対して詳細に答えるAIアシスタントです。"},
99
- {"role": "user", "content": "機械学習とは何か、初心者にもわかるように簡単に説明してください。"}
100
- ]
101
-
102
- # チャットテンプレートを適用し、テンソルに変換
103
- inputs = tokenizer.apply_chat_template(
104
- messages,
105
- add_generation_prompt=True,
106
- tokenize=True,
107
- return_tensors="pt"
108
- ).to(model.device)
109
-
110
- # 入力トークン数の取得 (生成部分のみを後で抽出するため)
111
- input_len = inputs.shape[-1]
112
-
113
- # 推論の実行
114
- with torch.inference_mode():
115
- outputs = model.generate(
116
- inputs,
117
- max_new_tokens=512, # 最大生成トークン数
118
- do_sample=True, # サンプリングを使用する場合
119
- temperature=0.7, # 生成の多様性
120
- top_p=0.9 # Top-pサンプリング
121
- )
122
-
123
- # 生成されたトークンのみをデコード
124
- generated_tokens = outputs[0][input_len:]
125
- response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
126
-
127
- print("--- モデルの応答 ---")
128
- print(response)
129
- ```
130
-
131
- ### Transformersを使用した推論 (画像 + テキスト)
132
-
133
- 画像とテキストを組み合わせて入力し、推論を行う場合のコードです。
134
-
135
- ```python
136
- import torch
137
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration # または AutoModelForCausalLM
138
- from PIL import Image
139
- import requests
140
-
141
- # モデルID
142
- model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.2"
143
 
144
- # プロセッサーとモデルのロード
145
  processor = AutoProcessor.from_pretrained(model_id)
146
- # Gemma 3 4B はメモリ要求が高いため、bf16を使用し、可能であれば複数GPUに分散します
147
- model = Gemma3ForConditionalGeneration.from_pretrained( # Gemma 3の推奨クラス
148
- model_id,
149
- device_map="auto",
150
- torch_dtype=torch.bfloat16,
151
- trust_remote_code=True # Gemma 3にはリモートコード実行が必要な場合があります
152
- )
153
- model.eval()
154
-
155
- # 画像の準備 (例: URLからロード)
156
- image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
157
- image = Image.open(requests.get(image_url, stream=True).raw)
158
-
159
- # チャットメッセージの準備 (画像とテキストを含む)
160
  messages = [
161
  {
162
  "role": "system",
163
- "content": [{"type": "text", "text": "あなたは画像を詳細に説明するAIアシスタントです。"}]
164
  },
165
  {
166
  "role": "user",
167
  "content": [
168
- {"type": "image", "image": image}, # PILイメージオブジェクトを渡す
169
- {"type": "text", "text": "この画像に写っている昆虫は何ですか?花についても説明してください。"}
170
  ]
171
  }
172
  ]
173
 
174
- # チャットテンプレートを適用し、テンソルに変換
175
- # apply_chat_templateは画像も処理できます
176
  inputs = processor.apply_chat_template(
177
- messages,
178
- add_generation_prompt=True,
179
- tokenize=True,
180
- return_dict=True, # 画像処理のために辞書形式で返すのが確実
181
- return_tensors="pt"
182
- ).to(model.device)
183
-
184
- # 入力トークン数の取得 (生成部分のみを後で抽出するため)
185
- # inputsが辞書の場合、'input_ids'キーを使用
186
- input_len = inputs['input_ids'].shape[-1]
187
-
188
- # 推論の実行
189
  with torch.inference_mode():
190
- outputs = model.generate(
191
- **inputs, # 辞書を展開して渡す
192
- max_new_tokens=512, # 最大生成トークン数
193
- do_sample=False # 画像説明などではFalseの方が安定することがあります
194
- )
195
-
196
- # 生成されたトークンのみをデコード
197
- # outputsはテンソルで返ってくる
198
- generated_tokens = outputs[0][input_len:]
199
- response = processor.decode(generated_tokens, skip_special_tokens=True)
200
-
201
- print("--- モデルの応答 ---")
202
- print(response)
203
  ```
204
 
205
  **注意点:**
 
24
  pip install -U transformers accelerate Pillow requests torch
25
  ```
26
 
27
+ ### 画像付き推論
 
 
 
 
 
 
 
 
28
 
29
  ```python
30
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
31
+ from PIL import Image
32
+ import requests
33
+ import torch
34
 
 
35
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.2"
36
 
37
+ model = Gemma3ForConditionalGeneration.from_pretrained(
38
+ model_id, device_map="auto"
39
+ ).eval()
40
 
41
+ processor = AutoProcessor.from_pretrained(model_id)
 
42
 
 
 
 
43
  messages = [
44
+ {
45
+ "role": "system",
46
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
47
+ },
48
+ {
49
+ "role": "user",
50
+ "content": [
51
+ {"type": "image", "image": "https://cs.stanford.edu/people/rak248/VG_100K_2/2399540.jpg"},
52
+ {"type": "text", "text": "この画像を説明してください。"}
53
+ ]
54
+ }
55
  ]
56
 
57
+ inputs = processor.apply_chat_template(
58
+ messages, add_generation_prompt=True, tokenize=True,
59
+ return_dict=True, return_tensors="pt"
60
+ ).to(model.device, dtype=torch.bfloat16)
61
 
62
+ input_len = inputs["input_ids"].shape[-1]
 
63
 
64
+ with torch.inference_mode():
65
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
66
+ generation = generation[0][input_len:]
 
 
 
67
 
68
+ decoded = processor.decode(generation, skip_special_tokens=True)
69
+ print(decoded)
70
  ```
71
+ ### 画像無し推論
 
 
 
72
 
73
  ```python
74
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
75
  import torch
 
76
 
 
77
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.2"
78
 
79
+ model = Gemma3ForConditionalGeneration.from_pretrained(
80
+ model_id, device_map="auto"
81
+ ).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
 
83
  processor = AutoProcessor.from_pretrained(model_id)
84
+
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  messages = [
86
  {
87
  "role": "system",
88
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
89
  },
90
  {
91
  "role": "user",
92
  "content": [
93
+ {"type": "text", "text": "AI言語モデルであるLaMDAが意識があることを主張して弁護士を呼んだとのことです。LaMDAには意識があると思いますか?"}
 
94
  ]
95
  }
96
  ]
97
 
 
 
98
  inputs = processor.apply_chat_template(
99
+ messages, add_generation_prompt=True, tokenize=True,
100
+ return_dict=True, return_tensors="pt"
101
+ ).to(model.device, dtype=torch.bfloat16)
102
+
103
+ input_len = inputs["input_ids"].shape[-1]
104
+
 
 
 
 
 
 
105
  with torch.inference_mode():
106
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
107
+ generation = generation[0][input_len:]
108
+
109
+ decoded = processor.decode(generation, skip_special_tokens=True)
110
+ print(decoded)
 
 
 
 
 
 
 
 
111
  ```
112
 
113
  **注意点:**