GGLS committed on
Commit
39e5383
·
verified ·
1 Parent(s): 13ea389

Upload app.py

Browse files

Load model in 4-bit for faster inference.

Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -31,7 +31,7 @@ def load_model(model_name):
31
  model = AutoModelForCausalLM.from_pretrained(
32
  f"{root_path}/{model_name}",
33
  device_map="auto",
34
- load_in_8bit=True,
35
  torch_dtype=torch.bfloat16,
36
  trust_remote_code=True,
37
  )
@@ -47,7 +47,7 @@ with st.sidebar:
47
  top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
48
  top_k = st.sidebar.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
49
  repetition_penalty = st.sidebar.slider('repetition penalty', min_value=1., max_value=2., value=1.2, step=0.05)
50
- max_length = st.sidebar.slider('max new tokens', min_value=32, max_value=2000, value=240, step=8)
51
 
52
  with st.spinner('loading model..'):
53
  model, tokenizer = load_model(model_name)
 
31
  model = AutoModelForCausalLM.from_pretrained(
32
  f"{root_path}/{model_name}",
33
  device_map="auto",
34
+ load_in_4bit=True,
35
  torch_dtype=torch.bfloat16,
36
  trust_remote_code=True,
37
  )
 
47
  top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
48
  top_k = st.sidebar.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
49
  repetition_penalty = st.sidebar.slider('repetition penalty', min_value=1., max_value=2., value=1.2, step=0.05)
50
+ max_length = st.sidebar.slider('max new tokens', min_value=32, max_value=2000, value=512, step=8)
51
 
52
  with st.spinner('loading model..'):
53
  model, tokenizer = load_model(model_name)