import keras
import keras_hub

model_presets = [
    # Larger instruction-tuned presets (7B-9B parameters).
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # Smaller presets (1B-3B parameters).
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]

# Human-readable labels: strip the "hf://" scheme and the owner prefix.
model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]


def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    url = "https://huggingface.co/" + preset
    return url


def get_appropriate_chat_template(preset):
    # Vicuna presets get an explicit chat template name; every other preset
    # uses "auto".
    return "Vicuna" if "vicuna" in preset else "auto"


def get_default_layout_map(preset_name, device_mesh):
    # The Llama3 layout map is reused for the Mistral and Vicuna presets
    # as well.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Shard the reverse embeddings (the output projection) as well.
        layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return layout_map

    elif "gemma" in preset_name:
        layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
        if "gemma-2b-" in preset_name:
            # The 2B Gemma preset uses a different sharding for its attention
            # query/key/value kernels: remove the default entry and re-insert
            # it with the adjusted layout.
            patch_key = "decoder_block.*attention.*(query|key|value).kernel"
            layout_map.pop(patch_key)
            layout_map[patch_key] = (None, "model", "batch")
        return layout_map
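

# For reference, a minimal sketch of how the returned layout map is consumed
# (this mirrors load_model() below; the standalone snippet is illustrative,
# not part of the original script, and assumes a multi-accelerator backend):
#
#   devices = keras.distribution.list_devices()
#   mesh = keras.distribution.DeviceMesh(
#       shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
#   )
#   layout_map = get_default_layout_map("hf://google/gemma-2b-it-keras", mesh)
#   distribution = keras.distribution.ModelParallel(
#       layout_map=layout_map, batch_dim_name="batch"
#   )
#   keras.distribution.set_distribution(distribution)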


def log_applied_layout_map(model):
    print("Model class:", type(model).__name__)
    # Pick a representative transformer block for each supported architecture.
    if "Gemma" in type(model).__name__:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in type(model).__name__:
        transformer_decoder_block_name = "transformer_layer_1"
    elif "Mistral" in type(model).__name__:
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # Print path, shape, sharding spec, and dtype for every weight in the
    # embedding layer and in one decoder block.
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58}  "
            f"{str(variable.shape):<16}  "
            f"{str(variable.value.sharding.spec):<35}  "
            f"{str(variable.dtype)}"
        )


def load_model(preset):
    # Build a 1 x N device mesh: no data parallelism, model weights sharded
    # across all available accelerators.
    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )

    # Any model created inside this scope is sharded according to the
    # layout map above.
    with model_parallel.scope():
        # These two presets are assembled explicitly from their backbone and
        # preprocessor classes; the remaining presets load through the
        # high-level CausalLM API.
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(
                preset, dtype="bfloat16"
            )

    log_applied_layout_map(model)
    return model
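

# Usage sketch (illustrative, not part of the original script): assumes a
# multi-accelerator JAX runtime; the preset choice, prompt, and max_length
# below are placeholders.
#
#   model = load_model("hf://keras/mistral_instruct_7b_en")
#   print(model.generate("What is Keras?", max_length=64))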