Spestly committed on
Commit
018f2bb
·
verified ·
1 Parent(s): eeda09f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -54,9 +54,13 @@ def generate_response(model_id, conversation, user_message, max_length=512, temp
54
  # Create prompt
55
  prompt = "\n".join(conversation_history)
56
 
57
- # Tokenize and generate
58
  inputs = tokenizer(prompt, return_tensors="pt")
59
 
 
 
 
 
60
  generation_start = time.time()
61
  with torch.no_grad():
62
  outputs = model.generate(
 
54
  # Create prompt
55
  prompt = "\n".join(conversation_history)
56
 
57
+ # Tokenize and move to GPU
58
  inputs = tokenizer(prompt, return_tensors="pt")
59
 
60
+ # Move inputs to the same device as the model
61
+ device = next(model.parameters()).device
62
+ inputs = {k: v.to(device) for k, v in inputs.items()}
63
+
64
  generation_start = time.time()
65
  with torch.no_grad():
66
  outputs = model.generate(