bwilkie committed
Commit e4f7d1f · verified · 1 parent: 9ff076a

Update myagent.py

Files changed (1)
myagent.py: +15 -18
myagent.py CHANGED
@@ -49,11 +49,10 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
     torch_dtype="bfloat16",
     trust_remote_code=True,
-    # attn_implementation="flash_attention_2" <- uncomment on compatible GPU
+    # attn_implementation="flash_attention_2" # <- uncomment on compatible GPU
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)

-
 # Create a wrapper class that matches the expected interface
 class LocalLlamaModel:
     def __init__(self, model, tokenizer):
@@ -61,29 +60,25 @@ class LocalLlamaModel:
         self.tokenizer = tokenizer
         self.device = model.device if hasattr(model, 'device') else 'cpu'

-    def generate(self, prompt: str, max_new_tokens=512*10, **kwargs):
-
-
-        # Generate answer
-        prompt = "What is C. elegans?"
-        input_ids = tokenizer.apply_chat_template(
+    def generate(self, prompt: str, max_new_tokens=512, **kwargs):
+        # Generate answer using the provided prompt
+        input_ids = self.tokenizer.apply_chat_template(
             [{"role": "user", "content": prompt}],
             add_generation_prompt=True,
             return_tensors="pt",
             tokenize=True,
-        ).to(model.device)
+        ).to(self.model.device)

-        output = model.generate(
+        output = self.model.generate(
             input_ids,
             do_sample=True,
             temperature=0.3,
             min_p=0.15,
             repetition_penalty=1.05,
-            max_new_tokens=512,
+            max_new_tokens=max_new_tokens,
         )

-        output =tokenizer.decode(output[0], skip_special_tokens=False)
-
+        output = self.tokenizer.decode(output[0], skip_special_tokens=False)
         return output

     def __call__(self, prompt: str, max_new_tokens=512, **kwargs):
@@ -91,16 +86,18 @@ class LocalLlamaModel:
         return self.generate(prompt, max_new_tokens, **kwargs)

 # Create the model instance
-model = LocalLlamaModel(model_init, tokenizer)
+wrapped_model = LocalLlamaModel(model, tokenizer)

 # Now create your agents - these should work with the wrapped model
-reviewer_agent = ToolCallingAgent(model=model, tools=[])
-model_agent = ToolCallingAgent(model=model, tools=[fetch_webpage])
+reviewer_agent = ToolCallingAgent(model=wrapped_model, tools=[])
+model_agent = ToolCallingAgent(model=wrapped_model, tools=[fetch_webpage])
 gaia_agent = CodeAgent(
-    tools=[fetch_webpage, get_youtube_title_description, get_youtube_transcript],
-    model=model
+    tools=[fetch_webpage, get_youtube_title_description, get_youtube_transcript],
+    model=wrapped_model
 )

+
+
 if __name__ == "__main__":
     # Example usage
     question = "What was the actual enrollment of the Malko competition in 2023?"