Tri4 committed on
Commit 5ec5cb4 · verified · 1 Parent(s): 7c5a24d

Update main.py

Files changed (1)
  main.py +3 -2
main.py CHANGED
@@ -21,12 +21,12 @@ model_id = "google/gemma-2-2b-it"
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 # Load tokenizer and model with authentication token
-tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype=torch.float16,
-    use_auth_token=HF_TOKEN
+    token=HF_TOKEN
 )
 
 app_pipeline = pipeline(
@@ -44,6 +44,7 @@ def generate_text():
     temperature = data.get("temperature", 0.1)
     top_k = data.get("top_k", 50)
     top_p = data.get("top_p", 0.95)
+    print(f"{prompt}: ")
 
     try:
         outputs = app_pipeline(
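
In short, the commit swaps the deprecated `use_auth_token` argument for `token` when loading the tokenizer and model, and adds a debug print of the incoming prompt. A minimal sketch of the updated loading path is below; it assumes HF_TOKEN is read from an environment variable and that app_pipeline is a text-generation pipeline, since the diff truncates the pipeline() call.

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Assumption: the Hugging Face access token is supplied via an environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN")

model_id = "google/gemma-2-2b-it"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model; recent transformers releases deprecate
# `use_auth_token` in favor of `token`, which is what this commit adopts.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    token=HF_TOKEN,
)

# Assumption: a text-generation pipeline built from the loaded model and tokenizer;
# the remaining arguments are not shown in the diff.
app_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)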