Akshat1000 committed
Commit 9c31c2b · verified · 1 Parent(s): cb1218c

Update getans.py

Files changed (1)
  1. getans.py +15 -14
getans.py CHANGED
@@ -1,14 +1,15 @@
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-
-
- def get_response(prompt, max_new_tokens=50):
-     inputs = tokenizer(prompt, return_tensors="pt")
-     outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=0.0001, do_sample=True)
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)  # Use indexing instead of calling
-     return response
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+ model.to(device)  # place the model on the GPU when one is available
+
+ def get_response(prompt, max_new_tokens=50):
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)  # inputs must be on the same device as the model
+     outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=0.0001, do_sample=True)  # near-zero temperature makes sampling effectively greedy
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)  # Use indexing instead of calling
+     return response
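
For reference, a minimal usage sketch of the updated helper. This is illustrative only: it assumes getans.py is importable from the working directory, and the prompt string and token budget here are made up, not part of the commit.

from getans import get_response  # assumes getans.py is on the Python path

# Quick smoke test: with temperature=0.0001 and do_sample=True the
# generation is effectively greedy, so output should be near-deterministic.
answer = get_response("Explain what a tokenizer does in one sentence.", max_new_tokens=64)
print(answer)

Since a near-zero temperature with do_sample=True approximates greedy decoding, passing do_sample=False (and omitting temperature) would express the same intent more directly.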