joaogante HF staff committed on
Commit
588b2d4
·
1 Parent(s): a1a543e

update model to pythia

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -4,7 +4,8 @@ import torch
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
6
 
7
- model_id = "declare-lab/flan-alpaca-large"
 
8
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
9
  print("Running on device:", torch_device)
10
  print("CPU threads:", torch.get_num_threads())
@@ -15,6 +16,7 @@ if torch_device == "cuda":
15
  else:
16
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
17
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
18
 
19
 
20
  def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
@@ -49,15 +51,10 @@ def reset_textbox():
49
 
50
 
51
  with gr.Blocks() as demo:
52
- duplicate_link = "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
53
  gr.Markdown(
54
- "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
55
- "This demo showcases the use of the "
56
- "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
57
- "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
58
- f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
59
- f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
60
- "template! 💛"
61
  )
62
 
63
  with gr.Row():
 
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
6
 
7
+ model_id = "EleutherAI/pythia-6.9b-deduped"
8
+ assistant_id = "EleutherAI/pythia-70m-deduped"
9
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
10
  print("Running on device:", torch_device)
11
  print("CPU threads:", torch.get_num_threads())
 
16
  else:
17
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
18
  tokenizer = AutoTokenizer.from_pretrained(model_id)
19
+ assistant_model = AutoModelForSeq2SeqLM.from_pretrained(assistant_id).to(torch_device)
20
 
21
 
22
  def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
 
51
 
52
 
53
  with gr.Blocks() as demo:
 
54
  gr.Markdown(
55
+ "# 🤗 Assisted Generation Demo\n"
56
+ f"Model: {model_id}\n"
57
+ f"Assistant Model: {assistant_id}"
 
 
 
 
58
  )
59
 
60
  with gr.Row():