bwilkie committed
Commit 391fe34 · verified · 1 Parent(s): a94813f

Update myagent.py

Files changed (1): myagent.py (+13 -18)
myagent.py CHANGED
@@ -41,7 +41,15 @@ class BasicAgent:
         return error
 
 
-
+
+# Model configuration
+model_id = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+filename = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
+
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
+model_init = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename, torch_dtype=torch_dtype)
+
 
 # Create a wrapper class that matches the expected interface
 class LocalLlamaModel:
@@ -50,23 +58,10 @@ class LocalLlamaModel:
         self.tokenizer = tokenizer
         self.device = model.device if hasattr(model, 'device') else 'cpu'
 
-    def generate(self, prompt: str, max_new_tokens=512, **kwargs):
-        """Generate text using the local model"""
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
-
-        with torch.no_grad():
-            output_ids = self.model.generate(
-                input_ids,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                temperature=0.7,
-                pad_token_id=self.tokenizer.eos_token_id,
-                **kwargs
-            )
-
-        # Decode only the new tokens (excluding the input)
-        new_tokens = output_ids[0][input_ids.shape[1]:]
-        output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+    def generate(self, prompt: str, max_new_tokens=512*10, **kwargs):
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+        output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
+        output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
         return output
 
     def __call__(self, prompt: str, max_new_tokens=512, **kwargs):
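
The added lines load a GGUF checkpoint through the gguf_file= argument of transformers. Two things are worth noting: transformers dequantizes GGUF weights when loading them this way, so the Q4_K_M file does not actually run 4-bit through this route, and the commit references a torch_dtype variable defined elsewhere in myagent.py. A minimal self-contained sketch of the loading step (the torch_dtype value here is an assumption, not from the commit):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "bartowski/Llama-3.2-3B-Instruct-GGUF"
filename = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
torch_dtype = torch.float16  # assumption: defined elsewhere in myagent.py, not shown in the diff

# transformers parses the GGUF file and dequantizes its tensors into a
# regular PyTorch model, trading the 4-bit memory footprint for
# compatibility with the standard generate() API.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
model_init = AutoModelForCausalLM.from_pretrained(
    model_id, gguf_file=filename, torch_dtype=torch_dtype
)
```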
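The rewritten generate() also closes over the module-level tokenizer and model instead of the self.tokenizer/self.model stored in __init__, drops the torch.no_grad() guard, accepts **kwargs without forwarding them, and decodes the full sequence, so the prompt comes back along with the completion (the removed version sliced it off). A sketch of the same simplified method written against the instance attributes, as a hypothetical variant rather than what the commit contains:

```python
import torch

class LocalLlamaModel:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device if hasattr(model, 'device') else 'cpu'

    def generate(self, prompt: str, max_new_tokens=512 * 10, **kwargs):
        # Use the wrapper's own model/tokenizer so the class works with any
        # pair passed to __init__, not just the module-level globals.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output_ids = self.model.generate(
                input_ids, max_new_tokens=max_new_tokens, **kwargs
            )
        # Slice off the prompt tokens so callers get just the completion.
        new_tokens = output_ids[0][input_ids.shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
```

Forwarding **kwargs also keeps options passed through __call__ meaningful; as committed, generate() silently ignores them.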