Holy-fox commited on
Commit
4ca6aca
·
verified ·
1 Parent(s): 04bb044

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +50 -180
README.md CHANGED
@@ -13,227 +13,97 @@ license: gemma
13
 
14
  ## How to use
15
 
16
- **注意:** 以下のコードを実行する前に、必要なライブラリをインストールしてください。特に `transformers` ライブラリは Gemma 3 をサポートするバージョン (4.50.0 以降) が必要です。また、Unsloth を使用してファインチューニングされたモデルの場合、推論時にも Unsloth が必要になる場合があります。
17
 
18
  ```sh
19
  pip install -U transformers accelerate torch
20
- # vLLM を使用する場合
21
- pip install vllm
22
- # Unsloth が推論に必要となる場合
23
- pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" # 環境に合わせて調整
24
  ```
25
 
26
- ### vLLM での推論 (テキスト生成)
27
-
28
- vLLM を使用すると、高速なテキスト生成推論が可能です。(2025年3月現在、vLLMのGemma 3マルチモーダル対応は進行中の可能性があります。最新情報はvLLMのドキュメントをご確認ください。)
29
 
30
  ```python
31
- from vllm import LLM, SamplingParams
32
-
33
- # モデル名を指定
34
- model_name = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
35
- # またはローカルパスを指定
36
- # model_name = "/path/to/your/model"
37
-
38
- # LLMインスタンスを作成
39
- # tensor_parallel_size は利用可能なGPU数に合わせて調整してください
40
- llm = LLM(model=model_name, trust_remote_code=True) # Unslothモデルの場合など必要に応じて trust_remote_code=True
41
-
42
- # サンプリングパラメータを設定
43
- sampling_params = SamplingParams(temperature=0.1, top_p=0.95, max_tokens=200)
44
-
45
- prompt = "<start_of_turn>user\n日本の首都はどこですか?<end_of_turn>\n<start_of_turn>model\n"
46
-
47
- # 推論を実行
48
- outputs = llm.generate(prompt, sampling_params)
49
-
50
- # 結果を表示
51
- for output in outputs:
52
- prompt = output.prompt
53
- generated_text = output.outputs[0].text
54
- print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
55
-
56
- ```
57
 
58
- ### Transformers での推論 (テキストのみ)
59
 
60
- `transformers` ライブラリを使用して、テキストプロンプト(システムプロンプトとユーザープロンプトを含む)に基づいてテキストを生成します。
 
 
61
 
62
- ```python
63
- from transformers import pipeline, AutoTokenizer
64
- import torch
65
 
66
- # モデル名とトークナイザーを指定
67
- model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
68
- tokenizer = AutoTokenizer.from_pretrained(model_id)
69
-
70
- # パイプラインを作成
71
- pipe = pipeline(
72
- "text-generation", # Gemma 3 のテキスト生成には text-generation が適切
73
- model=model_id,
74
- tokenizer=tokenizer, # 明示的にトークナイザーを渡す
75
- device="cuda", # GPUが利用可能な場合
76
- torch_dtype=torch.bfloat16 # Gemma 3 推奨のデータ型
77
- )
78
-
79
- # チャット形式のメッセージを作成
80
  messages = [
81
  {
82
  "role": "system",
83
- "content": "あなたは親切なアシスタントです。" # システムプロンプト
84
  },
85
  {
86
  "role": "user",
87
- "content": "Unslothとは何ですか?簡単に説明してください。" # ユーザープロンプト
 
 
 
88
  }
89
  ]
90
 
91
- # チャットテンプレートを適用
92
- # apply_chat_template は内部で <start_of_turn>などを付与します
93
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
94
-
95
- # 推論を実行
96
- # max_new_tokens は生成する最大トークン数
97
- # do_sample=True にすると、多様な応答が生成されやすくなります
98
- outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.2, top_p=0.95)
99
-
100
- # 生成されたテキストのみを表示 (入力プロンプト部分を除く)
101
- generated_text = outputs[0]['generated_text'][len(prompt):]
102
- print(generated_text)
103
-
104
- # --- AutoModelForCausalLM を使う場合 ---
105
- # from transformers import AutoModelForCausalLM
106
-
107
- # model = AutoModelForCausalLM.from_pretrained(
108
- # model_id,
109
- # torch_dtype=torch.bfloat16,
110
- # device_map="auto", # GPUに自動で配置
111
- # # Unslothモデルの場合、追加の引数が必要な��合があります
112
- # )
113
- # model.eval()
114
-
115
- # inputs = tokenizer.apply_chat_template(
116
- # messages,
117
- # add_generation_prompt=True,
118
- # return_tensors="pt"
119
- # ).to(model.device)
120
-
121
- # input_len = inputs.shape[-1]
122
-
123
- # with torch.inference_mode():
124
- # generation_output = model.generate(
125
- # inputs,
126
- # max_new_tokens=256,
127
- # do_sample=True,
128
- # temperature=0.7,
129
- # top_p=0.95,
130
- # )
131
- # # 入力部分を除いた生成トークンを取得
132
- # generated_tokens = generation_output[0][input_len:]
133
- # decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
134
- # print(decoded)
135
- ```
136
 
137
- ### Transformers での推論 (画像とテキスト)
 
 
138
 
139
- `transformers` ライブラリを使用して、画像とテキストプロンプトに基づいてテキストを生成します。
 
 
 
140
 
141
  ```python
142
- from transformers import pipeline, AutoProcessor
143
  import torch
144
- from PIL import Image
145
- import requests
146
 
147
- # モデル名、プロセッサーを指定
148
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
 
 
 
 
 
149
  processor = AutoProcessor.from_pretrained(model_id)
150
 
151
- # パイプラインを作成 (image-text-to-textタスク)
152
- pipe = pipeline(
153
- "image-text-to-text",
154
- model=model_id,
155
- processor=processor, # 明示的にプロセッサーを渡す
156
- device="cuda", # GPUが利用可能な場合
157
- torch_dtype=torch.bfloat16 # Gemma 3 推奨のデータ型
158
- # Unslothモデルの場合、追加の引数が必要な場合があります
159
- )
160
-
161
- # 画像のURL
162
- image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
163
- # 画像を読み込む (ローカルファイルの場合は Image.open("path/to/image.jpg") )
164
- image = Image.open(requests.get(image_url, stream=True).raw)
165
-
166
- # チャット形式のメッセージを作成 (画像とテキストを含む)
167
  messages = [
168
  {
169
  "role": "system",
170
- "content": [{"type": "text", "text": "You are a helpful assistant."}]
171
  },
172
  {
173
  "role": "user",
174
  "content": [
175
- {"type": "image"}, # 画像のプレースホルダー
176
- {"type": "text", "text": "この画像について詳しく説明してください。"} # テキストプロンプト
177
  ]
178
  }
179
  ]
180
 
181
- # 推論を実行 (images引数で画像を渡す)
182
- # max_new_tokens は生成する最大トークン数
183
- outputs = pipe(messages, images=image, max_new_tokens=200)
184
-
185
- # 生成されたテキストを表示
186
- # パイプラインの出力形式に合わせて調整が必要な場合があります
187
- # Gemma 3の場合、最後のメッセージのcontentを取り出すことが多いです
188
- print(outputs[0]["generated_text"][-1]["content"])
189
-
190
- # --- Gemma3ForConditionalGeneration を使う場合 ---
191
- # from transformers import Gemma3ForConditionalGeneration
192
-
193
- # model = Gemma3ForConditionalGeneration.from_pretrained(
194
- # model_id,
195
- # torch_dtype=torch.bfloat16,
196
- # device_map="auto" # GPUに自動で配置
197
- # # Unslothモデルの場合、追加の引数が必要な場合があります
198
- # ).eval()
199
-
200
- # # 画像を含むメッセージを作成 (Imageオブジェクトを直接渡す)
201
- # messages_for_processor = [
202
- # {
203
- # "role": "system",
204
- # "content": [{"type": "text", "text": "You are a helpful assistant."}]
205
- # },
206
- # {
207
- # "role": "user",
208
- # "content": [
209
- # {"type": "image", "image": image}, # PIL Image オブジェクト
210
- # {"type": "text", "text": "この画像について詳しく説明してください。"}
211
- # ]
212
- # }
213
- # ]
214
-
215
- # # プロセッサーで入力を作成
216
- # inputs = processor.apply_chat_template(
217
- # messages_for_processor,
218
- # add_generation_prompt=True,
219
- # tokenize=True, # トークン化を有効に
220
- # return_dict=True,
221
- # return_tensors="pt"
222
- # ).to(model.device) # モデルと同じデバイスに移動
223
-
224
- # input_len = inputs["input_ids"].shape[-1]
225
-
226
- # # 推論実行
227
- # with torch.inference_mode():
228
- # generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
229
- # # 入力部分を除いた生成トークンを取得
230
- # generation = generation[0][input_len:]
231
-
232
- # # デコードして表示
233
- # decoded = processor.decode(generation, skip_special_tokens=True)
234
- # print(decoded)
235
- ```
236
 
 
 
 
 
 
 
 
 
 
237
  ## License
238
 
239
  このモデルは、ベースモデルである `google/gemma-3-4b-it` のライセンス条件に従います。詳細については、以下のリンクをご参照ください。
 
13
 
14
  ## How to use
15
 
16
+ **注意:** 以下のコードを実行する前に、必要なライブラリをインストールしてください。特に `transformers` ライブラリは Gemma 3 をサポートするバージョン (4.50.0 以降) が必要です。
17
 
18
  ```sh
19
  pip install -U transformers accelerate torch
 
 
 
 
20
  ```
21
 
22
+ ### 画像付き推論
 
 
23
 
24
  ```python
25
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
26
+ from PIL import Image
27
+ import requests
28
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
31
 
32
+ model = Gemma3ForConditionalGeneration.from_pretrained(
33
+ model_id, device_map="auto"
34
+ ).eval()
35
 
36
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  messages = [
39
  {
40
  "role": "system",
41
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
42
  },
43
  {
44
  "role": "user",
45
+ "content": [
46
+ {"type": "image", "image": "https://cs.stanford.edu/people/rak248/VG_100K_2/2399540.jpg"},
47
+ {"type": "text", "text": "この画像を説明してください。"}
48
+ ]
49
  }
50
  ]
51
 
52
+ inputs = processor.apply_chat_template(
53
+ messages, add_generation_prompt=True, tokenize=True,
54
+ return_dict=True, return_tensors="pt"
55
+ ).to(model.device, dtype=torch.bfloat16)
56
+
57
+ input_len = inputs["input_ids"].shape[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ with torch.inference_mode():
60
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
61
+ generation = generation[0][input_len:]
62
 
63
+ decoded = processor.decode(generation, skip_special_tokens=True)
64
+ print(decoded)
65
+ ```
66
+ ### 画像無し推論
67
 
68
  ```python
69
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
70
  import torch
 
 
71
 
 
72
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-instruct-v0.1"
73
+
74
+ model = Gemma3ForConditionalGeneration.from_pretrained(
75
+ model_id, device_map="auto"
76
+ ).eval()
77
+
78
  processor = AutoProcessor.from_pretrained(model_id)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  messages = [
81
  {
82
  "role": "system",
83
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
84
  },
85
  {
86
  "role": "user",
87
  "content": [
88
+ {"type": "text", "text": "GPT3やGPT3.5などと比べてGPT4はどこがすごいのでしょうか?"}
 
89
  ]
90
  }
91
  ]
92
 
93
+ inputs = processor.apply_chat_template(
94
+ messages, add_generation_prompt=True, tokenize=True,
95
+ return_dict=True, return_tensors="pt"
96
+ ).to(model.device, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ input_len = inputs["input_ids"].shape[-1]
99
+
100
+ with torch.inference_mode():
101
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
102
+ generation = generation[0][input_len:]
103
+
104
+ decoded = processor.decode(generation, skip_special_tokens=True)
105
+ print(decoded)
106
+ ```
107
  ## License
108
 
109
  このモデルは、ベースモデルである `google/gemma-3-4b-it` のライセンス条件に従います。詳細については、以下のリンクをご参照ください。