bwilkie committed
Commit c06e6b5 · verified · 1 parent: da62d97

Update myagent.py

Files changed (1)
myagent.py  +55  -113
myagent.py CHANGED
@@ -8,139 +8,81 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch

 # --- Basic Agent Definition ---
-class BasicAgent:
-    def __init__(self):
-        print("BasicAgent initialized.")
-
-    def __call__(self, question: str) -> str:
-        print(f"Agent received question (first 50 chars): {question[:50]}...")
-
-        try:
-            # Use the reviewer agent to determine if the question can be answered by a model or requires code
-            print("Calling reviewer agent...")
-            reviewer_answer = reviewer_agent.run(myprompts.review_prompt + "\nThe question is:\n" + question)
-            print(f"Reviewer agent answer: {reviewer_answer}")
-
-            question = question + '\n' + myprompts.output_format
-            fixed_answer = ""
-
-            if reviewer_answer == "code":
-                fixed_answer = gaia_agent.run(question)
-                print(f"Code agent answer: {fixed_answer}")
-
-            elif reviewer_answer == "model":
-                # If the reviewer agent suggests using the model, we can proceed with the model agent
-                print("Using model agent to answer the question.")
-                fixed_answer = model_agent.run(myprompts.model_prompt + "\nThe question is:\n" + question)
-                print(f"Model agent answer: {fixed_answer}")
-
-            return fixed_answer
-        except Exception as e:
-            error = f"An error occurred while processing the question: {e}"
-            print(error)
-            return error
-
-# Load model and tokenizer
-model_id = "LiquidAI/LFM2-1.2B"
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16, # Fixed: was string, should be torch dtype
-    trust_remote_code=True,
-    # attn_implementation="flash_attention_2" # <- uncomment on compatible GPU
-)
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-# Create a wrapper class that matches the expected interface
+
+# Basic model wrapper for local inference with debug info
 class LocalLlamaModel:
     def __init__(self, model, tokenizer):
         self.model = model
         self.tokenizer = tokenizer
         self.device = model.device if hasattr(model, 'device') else 'cpu'
-
-    def _extract_text_from_messages(self, messages):
-        """Extract text content from ChatMessage objects or handle string input"""
-        if isinstance(messages, str):
-            return messages
-        elif isinstance(messages, list):
-            # Handle list of ChatMessage objects
-            text_parts = []
-            for msg in messages:
-                if hasattr(msg, 'content'):
-                    # Handle ChatMessage with content attribute
-                    if isinstance(msg.content, list):
-                        # Content is a list of content items
-                        for content_item in msg.content:
-                            if isinstance(content_item, dict) and 'text' in content_item:
-                                text_parts.append(content_item['text'])
-                            elif hasattr(content_item, 'text'):
-                                text_parts.append(content_item.text)
-                    elif isinstance(msg.content, str):
-                        text_parts.append(msg.content)
-                elif isinstance(msg, dict) and 'content' in msg:
-                    # Handle dictionary format
-                    text_parts.append(str(msg['content']))
-                else:
-                    # Fallback: convert to string
-                    text_parts.append(str(msg))
-            return '\n'.join(text_parts)
+        print(f"Model device: {self.device}")
+
+    def _extract_prompt(self, prompt):
+        if isinstance(prompt, str):
+            return prompt
+        elif isinstance(prompt, list):
+            # Convert list of ChatMessages or dicts to plain text
+            return "\n".join(
+                msg.content if hasattr(msg, "content") else msg.get("content", str(msg))
+                for msg in prompt
+            )
         else:
-            return str(messages)
-
-    def generate(self, prompt, max_new_tokens=512*5, **kwargs):
+            return str(prompt)
+
+    def generate(self, prompt, max_new_tokens=512):
         try:
-
-            print("Prompt: ", prompt)
-            print("Prompt type: ", type(prompt))
-            # Extract text from the prompt (which might be ChatMessage objects)
-            text_prompt = self._extract_text_from_messages(prompt)
-
-            print("Extracted text prompt:", text_prompt[:200] + "..." if len(text_prompt) > 200 else text_prompt)
-
-            # Tokenize the text prompt
-            inputs = self.tokenizer(text_prompt, return_tensors="pt").to(self.model.device)
-            input_ids = inputs['input_ids']
-
-            # Generate output
+            print("\n[DEBUG] Raw prompt input:", prompt)
+            text_prompt = self._extract_prompt(prompt)
+            print("[DEBUG] Extracted prompt text:", text_prompt[:200] + "..." if len(text_prompt) > 200 else text_prompt)
+
+            inputs = self.tokenizer(text_prompt, return_tensors="pt").to(self.device)
+            input_ids = inputs["input_ids"]
+
+            print("[DEBUG] Tokenized input shape:", input_ids.shape)
+
             with torch.no_grad():
                 output = self.model.generate(
-                    input_ids,
+                    input_ids=input_ids,
                     do_sample=True,
                     temperature=0.3,
                     min_p=0.15,
                     repetition_penalty=1.05,
                     max_new_tokens=max_new_tokens,
-                    pad_token_id=self.tokenizer.eos_token_id, # Handle padding
+                    pad_token_id=self.tokenizer.eos_token_id,
                 )
-
-            # Decode only the new tokens (exclude the input)
+
             new_tokens = output[0][len(input_ids[0]):]
-            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-            return response.strip()
-
-        except Exception as e:
-            print(f"Error in model generation: {e}")
-            return f"Error generating response: {str(e)}"
+            decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+            print("[DEBUG] Decoded output:", decoded.strip())

-    def __call__(self, prompt, max_new_tokens=512, **kwargs):
-        """Make the model callable like a function"""
-        return self.generate(prompt, max_new_tokens, **kwargs)
+            return decoded.strip()

-# Create the model instance
-wrapped_model = LocalLlamaModel(model, tokenizer)
+        except Exception as e:
+            print(f"[ERROR] Generation failed: {e}")
+            return f"Error generating response: {e}"
+
+    def __call__(self, prompt, max_new_tokens=512):
+        return self.generate(prompt, max_new_tokens)

-# Now create your agents - these should work with the wrapped model
-reviewer_agent = ToolCallingAgent(model=wrapped_model, tools=[])
-model_agent = ToolCallingAgent(model=wrapped_model, tools=[fetch_webpage])
-gaia_agent = CodeAgent(
-    tools=[fetch_webpage, get_youtube_title_description, get_youtube_transcript],
-    model=wrapped_model
-)
+# Load your model and tokenizer
+def load_model(model_id="LiquidAI/LFM2-1.2B"):
+    print(f"Loading model: {model_id}")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    return LocalLlamaModel(model, tokenizer)

+# Run minimal test
 if __name__ == "__main__":
-    # Example usage
-    question = "What was the actual enrollment of the Malko competition in 2023?"
-    agent = BasicAgent()
-    answer = agent(question)
-    print(f"Answer: {answer}")
+    model = load_model()
+
+    # Example prompt
+    prompt = "What is the capital of France?"
+
+    print("\n[TEST] Asking a simple question...")
+    response = model(prompt)
+    print("\nFinal Answer:", response)