Spaces:
Configuration error
Configuration error
Update app.py
Browse files
app.py
CHANGED
@@ -90,6 +90,12 @@ def interact(user_input, history, interaction_count, model_name):
 90      if tokenizer is None or model is None:
 91          raise ValueError("Tokenizer or model is not initialized.")
 92
 93      if interaction_count >= MAX_INTERACTIONS:
 94          user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
 95
@@ -102,11 +108,8 @@ def interact(user_input, history, interaction_count, model_name):
102
103      prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
104
105 -    #
106 -    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
107 -    if model_name not in quantized_models:
108 -        input_ids.to("cuda")
109 -
110      chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
111      response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
112
 90      if tokenizer is None or model is None:
 91          raise ValueError("Tokenizer or model is not initialized.")
 92
 93 +    # Determine the device to use (either CUDA if available, or CPU)
 94 +    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 95 +
 96 +    # Ensure the model is on the correct device
 97 +    model.to(device)
 98 +
 99      if interaction_count >= MAX_INTERACTIONS:
100          user_input += ". Thank you for your questions. Our session is now over. Goodbye!"
101
108
109      prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
111 +    # Move input tensor to the same device as the model
112 +    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
113      chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
114      response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
115