nisten committed
Commit adaf527 · verified · 1 Parent(s): 3d92619

Update app.py

Files changed (1):
  1. app.py (+6, -7)
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import OlmoeForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
+import os
 
 # Force upgrade transformers to the latest version
 subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
@@ -36,25 +37,23 @@ def generate_response(message, history, temperature, max_new_tokens):
     if model is None or tokenizer is None:
         return "Model or tokenizer not loaded properly. Please check the logs."
 
-    messages = [{"role": "system", "content": system_prompt},
-                {"role": "user", "content": message}]
-
+    messages = [{"role": "user", "content": message}]
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
     with torch.no_grad():
         generate_ids = model.generate(
-            inputs,
+            **inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
             eos_token_id=tokenizer.eos_token_id,
         )
-    response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
+    response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
     return response.strip()
 
 css = """
 #output {
-    height: 900px;
+    height: 500px;
     overflow: auto;
     border: 1px solid #ccc;
 }
@@ -85,4 +84,4 @@ with gr.Blocks(css=css) as demo:
 
 if __name__ == "__main__":
     demo.queue(api_open=False)
-    demo.launch(debug=True, show_api=True, share=True)
+    demo.launch(debug=True, show_api=False, share=True )
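For reference, a minimal runnable sketch of the generation path this commit rewires. The model id, prompt, and sampling values are placeholders, not from the diff; note that `apply_chat_template(..., return_tensors="pt")` returns a bare tensor unless `return_dict=True` is also passed, which is the form the new `**inputs` unpacking expects.

```python
# Minimal sketch of the new generate_response flow (model id and sampling
# values are assumptions; only the call structure mirrors the commit).
import torch
from transformers import AutoTokenizer, OlmoeForCausalLM

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "allenai/OLMoE-1B-7B-0924-Instruct"  # placeholder, not from the diff

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = OlmoeForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

messages = [{"role": "user", "content": "Hello!"}]

# return_dict=True yields {"input_ids", "attention_mask"}, which is what
# model.generate(**inputs) requires; without it, apply_chat_template
# returns a bare tensor that cannot be **-unpacked.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)

with torch.no_grad():
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )

# generate_ids[0] contains prompt + completion; slicing off the prompt
# length (as the pre-commit decode did) keeps only the new text.
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(generate_ids[0, prompt_len:], skip_special_tokens=True)
print(response.strip())
```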
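Likewise, a hedged sketch of the Gradio scaffolding around `generate_response`. Only the css block, `queue(api_open=False)`, and the launch flags appear in this diff; the component layout, slider ranges, and the `respond` helper below are assumptions for illustration.

```python
import gradio as gr

css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

def generate_response(message, history, temperature, max_new_tokens):
    # Stand-in for the model call in the sketch above; the real app
    # returns the decoded completion as a string.
    return f"echo: {message}"

def respond(message, history, temperature, max_new_tokens):
    # Append the (user, assistant) turn and clear the input box.
    reply = generate_response(message, history, temperature, max_new_tokens)
    return history + [(message, reply)], ""

with gr.Blocks(css=css) as demo:
    chatbot = gr.Chatbot(elem_id="output")  # styled by the #output css rule
    msg = gr.Textbox(label="Message")
    temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
    max_new_tokens = gr.Slider(16, 1024, value=256, step=16, label="Max new tokens")
    msg.submit(
        respond,
        inputs=[msg, chatbot, temperature, max_new_tokens],
        outputs=[chatbot, msg],
    )

if __name__ == "__main__":
    demo.queue(api_open=False)
    demo.launch(debug=True, show_api=False, share=True)
```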