lvwerra HF staff commited on
Commit
3c06e38
·
1 Parent(s): 39e12e7

refactor pipeline loading

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -4,13 +4,10 @@ from transformers import pipeline
4
  import torch
5
  import json
6
 
7
- @st.cache(allow_output_mutation=True)
8
- def load_tokenizer(model_ckpt):
9
- return AutoTokenizer.from_pretrained(model_ckpt)
10
 
11
  @st.cache(allow_output_mutation=True)
12
- def load_model(model_ckpt):
13
- model = AutoModelForCausalLM.from_pretrained(model_ckpt)
14
  return model
15
 
16
  @st.cache()
@@ -24,8 +21,7 @@ st.set_page_config(page_icon=':parrot:', layout="wide")
24
  device = 1 if torch.cuda.is_available() else 0
25
  device_name = "GPU" if torch.cuda.is_available() else "CPU"
26
  model_ckpt = "lvwerra/codeparrot"
27
- tokenizer = load_tokenizer(model_ckpt)
28
- model = load_model(model_ckpt)
29
  examples = load_examples()
30
  example_names = [example["name"] for example in examples]
31
  name2id = dict([(name, i) for i, name in enumerate(example_names)])
@@ -35,7 +31,6 @@ gen_kwargs = {}
35
  st.title("CodeParrot 🦜")
36
  st.markdown('##')
37
 
38
- pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
39
  st.sidebar.header("Examples:")
40
  selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
41
  example_text = examples[name2id[selected_example]]["value"]
 
4
  import torch
5
  import json
6
 
 
 
 
7
 
8
  @st.cache(allow_output_mutation=True)
9
+ def load_pipeline(model_ckpt, device):
10
+ pipe = pipeline('text-generation', model=model_ckpt, device=device)
11
  return model
12
 
13
  @st.cache()
 
21
  device = 1 if torch.cuda.is_available() else 0
22
  device_name = "GPU" if torch.cuda.is_available() else "CPU"
23
  model_ckpt = "lvwerra/codeparrot"
24
+ pipe = load_pipe(model_ckpt, device)
 
25
  examples = load_examples()
26
  example_names = [example["name"] for example in examples]
27
  name2id = dict([(name, i) for i, name in enumerate(example_names)])
 
31
  st.title("CodeParrot 🦜")
32
  st.markdown('##')
33
 
 
34
  st.sidebar.header("Examples:")
35
  selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
36
  example_text = examples[name2id[selected_example]]["value"]