bwilkie committed on
Commit 455866a · verified · 1 Parent(s): 9f08c4f

Update myagent.py

Files changed (1)
  1. myagent.py +8 -8
myagent.py CHANGED
@@ -5,6 +5,7 @@ from tools.fetch import fetch_webpage
 from tools.yttranscript import get_youtube_transcript, get_youtube_title_description
 import myprompts
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
 import torch
 # --- Basic Agent Definition ---
 class BasicAgent:
@@ -49,18 +50,17 @@ class BasicAgent:
 # )
 
 
-MODEL_NAME = "meta-llama/Llama-3.2-3B" # 3B isn't released by Meta officially, but use 8B or a 3B variant like TinyLlama if needed
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model_init = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    device_map="auto",
-    torch_dtype=torch.float16 # or bfloat16
-)
+model_id = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+filename = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
+
+torch_dtype = torch.float32 # could be torch.float16 or torch.bfloat16 too
+tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
+model_init = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename, torch_dtype=torch_dtype)
+
 
 def model(prompt: str, max_new_tokens=512):
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
+    output_ids = model_init.generate(input_ids, max_new_tokens=max_new_tokens)
     output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return output
 
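The new loading path relies on transformers' built-in GGUF support (the gguf_file argument, available in recent transformers releases with the gguf package installed), which dequantizes the Q4_K_M checkpoint into an ordinary torch model. Below is a minimal standalone sketch of that path, not the exact file from the commit. One assumption is flagged in the comments: the diff's unchanged context line still reads model.device, but inside the helper `model` names the function itself and has no .device attribute, so the sketch uses model_init.device as the presumed intent.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "bartowski/Llama-3.2-3B-Instruct-GGUF"
filename = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"

# transformers dequantizes the GGUF weights into regular torch tensors,
# so torch_dtype sets the dtype they are dequantized to.
torch_dtype = torch.float32  # could be torch.float16 or torch.bfloat16 too
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
model_init = AutoModelForCausalLM.from_pretrained(
    model_id, gguf_file=filename, torch_dtype=torch_dtype
)

def model(prompt: str, max_new_tokens: int = 512) -> str:
    # Assumed fix: model_init.device rather than the diff's model.device --
    # inside this function, `model` resolves to the function object itself.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model_init.device)
    output_ids = model_init.generate(input_ids, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(model("Summarize what a GGUF file is.", max_new_tokens=64))

Note the trade-off in this approach: because transformers dequantizes the GGUF at load time, memory use tracks the full torch_dtype model rather than the quantized file size; the quantized checkpoint mainly saves download bandwidth here, unlike running it under a llama.cpp-based runtime.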