joaogante committed
Commit 4b1ae14 · verified · 1 Parent(s): 35c2ab9

Update app.py

Files changed (1): app.py (+5 −5)
app.py CHANGED
```diff
@@ -8,8 +8,8 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
 
-model_id = "google/gemma-2-27b-it"
-assistant_id = "google/gemma-2-2b-it"
+model_id = "meta-llama/Llama-3.1-8B"
+assistant_id = "meta-llama/Llama-3.2-1B"
 
 model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
 assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.float16)
@@ -61,9 +61,9 @@ def reset_textbox():
 with gr.Blocks() as demo:
     gr.Markdown(
         "# 🤗 Assisted Generation Demo\n"
-        f"- Model: {model_id} (4-bit quant, 14B params, GPU memory = ~7GB)\n"
-        f"- Assistant Model: {assistant_id} (FP16, 0.5B params, GPU memory = ~1GB)\n"
-        "- Recipe for speedup: a) >10x model size difference in parameters; b) assistant trained similarly; c) CPU is not a bottleneck"
+        f"- Model: {model_id}\n"
+        f"- Assistant Model: {assistant_id}\n"
+        "- Recipe for good speedup: a) >10x model size difference in parameters; b) assistant trained similarly; c) CPU is not a bottleneck"
     )
 
     with gr.Row():
```
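For context, the sketch below shows how the two checkpoints swapped in by this commit are typically wired together for assisted generation in transformers. The model loading mirrors the diff; the prompt text and `max_new_tokens` value are illustrative assumptions, not part of this commit's app.py.

```python
# Minimal sketch of assisted generation with the checkpoints from this commit.
# The prompt and generation settings below are illustrative, not from app.py.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-3.1-8B"
assistant_id = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.float16)

inputs = tokenizer("The theory behind speculative decoding is", return_tensors="pt").to(model.device)

# Passing `assistant_model` to generate() enables assisted generation: the
# small model drafts candidate tokens and the large model verifies them in a
# single forward pass, trading one big-model call for several drafted tokens.
outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

This pairing follows the recipe in the demo's own Markdown: the 8B main model and 1B assistant differ by roughly 8x in parameter count and come from the same model family, so the assistant's drafts are accepted often enough to yield a net speedup.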