bwilkie committed · verified
Commit d9f0f18 · 1 Parent(s): 78f6f4d

Update myagent.py

Files changed (1):
  1. myagent.py +57 -39
myagent.py CHANGED
@@ -5,14 +5,14 @@ from tools.fetch import fetch_webpage
 from tools.yttranscript import get_youtube_transcript, get_youtube_title_description
 import myprompts
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
 
-import torch
 # --- Basic Agent Definition ---
 class BasicAgent:
     def __init__(self):
         print("BasicAgent initialized.")
+
     def __call__(self, question: str) -> str:
-
         print(f"Agent received question (first 50 chars): {question[:50]}...")
 
         try:
@@ -40,14 +40,12 @@ class BasicAgent:
            print(error)
            return error
 
-
-
 # Load model and tokenizer
 model_id = "LiquidAI/LFM2-1.2B"
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
-    torch_dtype="bfloat16",
+    torch_dtype=torch.bfloat16,  # Fixed: was string, should be torch dtype
     trust_remote_code=True,
     # attn_implementation="flash_attention_2"  # <- uncomment on compatible GPU
 )
@@ -58,52 +56,74 @@ class LocalLlamaModel:
     def __init__(self, model, tokenizer):
         self.model = model
         self.tokenizer = tokenizer
-        self.device = 'cpu'
+        self.device = model.device if hasattr(model, 'device') else 'cpu'
+
+    def _extract_text_from_messages(self, messages):
+        """Extract text content from ChatMessage objects or handle string input"""
+        if isinstance(messages, str):
+            return messages
+        elif isinstance(messages, list):
+            # Handle list of ChatMessage objects
+            text_parts = []
+            for msg in messages:
+                if hasattr(msg, 'content'):
+                    # Handle ChatMessage with content attribute
+                    if isinstance(msg.content, list):
+                        # Content is a list of content items
+                        for content_item in msg.content:
+                            if isinstance(content_item, dict) and 'text' in content_item:
+                                text_parts.append(content_item['text'])
+                            elif hasattr(content_item, 'text'):
+                                text_parts.append(content_item.text)
+                    elif isinstance(msg.content, str):
+                        text_parts.append(msg.content)
+                elif isinstance(msg, dict) and 'content' in msg:
+                    # Handle dictionary format
+                    text_parts.append(str(msg['content']))
+                else:
+                    # Fallback: convert to string
+                    text_parts.append(str(msg))
+            return '\n'.join(text_parts)
+        else:
+            return str(messages)
 
-    def generate(self, prompt: str, max_new_tokens=512*5, **kwargs):
+    def generate(self, prompt, max_new_tokens=512*5, **kwargs):
         try:
-            # Generate answer using the provided prompt - following the recommended pattern
-            # input_ids = self.tokenizer.apply_chat_template(
-            #     [{"role": "user", "content": str(prompt)}],
-            #     add_generation_prompt=True,
-            #     return_tensors="pt",
-            #     tokenize=True,
-            # ).to(self.model.device)
 
             print("Prompt: ", prompt)
             print("Prompt type: ", type(prompt))
+            # Extract text from the prompt (which might be ChatMessage objects)
+            text_prompt = self._extract_text_from_messages(prompt)
 
-            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            print("Extracted text prompt:", text_prompt[:200] + "..." if len(text_prompt) > 200 else text_prompt)
 
-            # Generate output - exactly as in recommended code
-            output = self.model.generate(
-                input_ids,
-                do_sample=True,
-                temperature=0.3,
-                min_p=0.15,
-                repetition_penalty=1.05,
-                max_new_tokens=max_new_tokens,
-            )
+            # Tokenize the text prompt
+            inputs = self.tokenizer(text_prompt, return_tensors="pt").to(self.model.device)
+            input_ids = inputs['input_ids']
 
-            # Decode the full output - as in recommended code
-            decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=False)
+            # Generate output
+            with torch.no_grad():
+                output = self.model.generate(
+                    input_ids,
+                    do_sample=True,
+                    temperature=0.3,
+                    min_p=0.15,
+                    repetition_penalty=1.05,
+                    max_new_tokens=max_new_tokens,
+                    pad_token_id=self.tokenizer.eos_token_id,  # Handle padding
+                )
 
-            # Extract only the assistant's response (after the last <|im_start|>assistant)
-            if "<|im_start|>assistant" in decoded_output:
-                assistant_response = decoded_output.split("<|im_start|>assistant")[-1]
-                # Remove any trailing special tokens
-                assistant_response = assistant_response.replace("<|im_end|>", "").strip()
-                return assistant_response
-            else:
-                # Fallback: return the full decoded output
-                return decoded_output
+            # Decode only the new tokens (exclude the input)
+            new_tokens = output[0][len(input_ids[0]):]
+            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+            return response.strip()
 
         except Exception as e:
             print(f"Error in model generation: {e}")
             return f"Error generating response: {str(e)}"
 
-
-    def __call__(self, prompt: str, max_new_tokens=512, **kwargs):
+    def __call__(self, prompt, max_new_tokens=512, **kwargs):
         """Make the model callable like a function"""
         return self.generate(prompt, max_new_tokens, **kwargs)
 
@@ -118,8 +138,6 @@ gaia_agent = CodeAgent(
     model=wrapped_model
 )
 
-
-
 if __name__ == "__main__":
     # Example usage
     question = "What was the actual enrollment of the Malko competition in 2023?"
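Reviewer note on the torch_dtype change in the second hunk: a quick way to confirm the weights really load in bfloat16 after the fix. This is a sketch, not part of the commit; it assumes enough memory to load LFM2-1.2B and uses only the model id shown above.

# Sanity-check sketch for the dtype fix (assumes LFM2-1.2B fits in memory).
import torch
from transformers import AutoModelForCausalLM

m = AutoModelForCausalLM.from_pretrained(
    "LiquidAI/LFM2-1.2B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
print(next(m.parameters()).dtype)  # expected: torch.bfloat16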
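Reviewer note on the third hunk: the bulk of this commit is the new _extract_text_from_messages helper, which lets generate() accept what the wrapper calls ChatMessage objects, plain dicts, or raw strings. A minimal sketch of how it behaves on each input shape; the SimpleNamespace stand-ins are hypothetical, and importing myagent also loads the LFM2-1.2B weights because that happens at module scope.

# Sketch: exercising the new helper on the three input shapes it handles.
# SimpleNamespace objects stand in for ChatMessage; only .content is read.
from types import SimpleNamespace
from myagent import LocalLlamaModel  # assumption: myagent.py from this commit is importable

wrapper = LocalLlamaModel(model=SimpleNamespace(device="cpu"), tokenizer=None)

print(wrapper._extract_text_from_messages("a plain string passes through unchanged"))
print(wrapper._extract_text_from_messages([{"role": "user", "content": "dict-style message"}]))
print(wrapper._extract_text_from_messages(
    [SimpleNamespace(content=[{"type": "text", "text": "ChatMessage-style content list"}])]
))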
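Reviewer note on the new decoding path: generate() now decodes only the freshly generated tokens instead of decoding the whole sequence and splitting on <|im_start|>assistant. A standalone illustration of that slicing pattern; the gpt2 tokenizer and the concatenated "fake output" are stand-ins, not part of this repo.

# Illustration of the "decode only the new tokens" pattern used in generate().
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer demonstrates the slicing
input_ids = tok("Question: what is 2 + 2?", return_tensors="pt")["input_ids"]

# Pretend the model echoed the prompt and then appended an answer.
fake_output = torch.cat(
    [input_ids, tok(" Answer: 4", return_tensors="pt")["input_ids"]], dim=1
)

new_tokens = fake_output[0][len(input_ids[0]):]  # drop the echoed prompt tokens
print(tok.decode(new_tokens, skip_special_tokens=True).strip())  # -> "Answer: 4"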