Update app.py
app.py
CHANGED
@@ -4,6 +4,7 @@ from transformers import OlmoeForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
+import os
 
 # Force upgrade transformers to the latest version
 subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
@@ -36,25 +37,23 @@ def generate_response(message, history, temperature, max_new_tokens):
     if model is None or tokenizer is None:
         return "Model or tokenizer not loaded properly. Please check the logs."
 
-    messages = [{"role": "
-                {"role": "user", "content": message}]
-
+    messages = [{"role": "user", "content": message}]
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
     with torch.no_grad():
         generate_ids = model.generate(
-            inputs,
+            **inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
             eos_token_id=tokenizer.eos_token_id,
         )
-    response = tokenizer.decode(generate_ids[0
+    response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
     return response.strip()
 
 css = """
 #output {
-    height:
+    height: 500px;
     overflow: auto;
     border: 1px solid #ccc;
 }
@@ -85,4 +84,4 @@ with gr.Blocks(css=css) as demo:
 
 if __name__ == "__main__":
     demo.queue(api_open=False)
-    demo.launch(debug=True, show_api=
+    demo.launch(debug=True, show_api=False, share=True )
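Note on the `**inputs` change: `apply_chat_template(..., tokenize=True, return_tensors="pt")` returns a bare `input_ids` tensor unless `return_dict=True` is also passed, so unpacking the result with `**inputs` as the new line 45 does only works if the template call also requests a dict. Below is a minimal sketch of a consistent version of this generation path; the checkpoint ID and the `DEVICE` setup are assumptions for illustration, since neither is shown in this diff.

import torch
from transformers import AutoTokenizer, OlmoeForCausalLM

# Assumed stand-ins: the diff never names the checkpoint or shows DEVICE.
MODEL_ID = "allenai/OLMoE-1B-7B-0924"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = OlmoeForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

messages = [{"role": "user", "content": "Hello!"}]

# return_dict=True makes apply_chat_template return a BatchEncoding
# (input_ids + attention_mask), which is what **inputs unpacks below.
# Without it the call returns a bare tensor and **inputs would raise.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(DEVICE)

with torch.no_grad():
    generate_ids = model.generate(
        **inputs,  # forwards input_ids and attention_mask
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )

# Skip special tokens so chat-template markers don't leak into the reply.
response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
print(response.strip())

One caveat: decoding `generate_ids[0]` returns the prompt plus the completion. Slicing off the first `inputs["input_ids"].shape[1]` tokens before decoding is a common way to keep only the model's reply.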