LamiaYT committed
Commit 2828102 · 1 Parent(s): e0860a0

Fix device mapping and improve error handling

Files changed (2):
  1. agent/local_llm.py  +14 -14
  2. app.py  +11 -3
agent/local_llm.py CHANGED
@@ -4,8 +4,7 @@ from accelerate import Accelerator
 
 class LocalLLM:
     def __init__(self):
-        self.model_name = "HuggingFaceH4/zephyr-7b-beta"
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Using smaller model
         self.pipeline = self._load_model()
 
     def _load_model(self):
@@ -13,7 +12,7 @@ class LocalLLM:
             # First try with 4-bit quantization
             return self._load_quantized_model()
         except Exception as e:
-            print(f"4-bit loading failed: {str(e)}. Trying without quantization...")
+            print(f"Quantized loading failed: {str(e)}. Trying without quantization...")
             return self._load_fallback_model()
 
     def _load_quantized_model(self):
@@ -28,8 +27,7 @@
         return pipeline(
             "text-generation",
             model=model,
-            tokenizer=tokenizer,
-            device=self.device
+            tokenizer=tokenizer  # Removed device parameter
         )
 
     def _load_fallback_model(self):
@@ -42,15 +40,17 @@
         return pipeline(
             "text-generation",
             model=model,
-            tokenizer=tokenizer,
-            device=self.device
+            tokenizer=tokenizer  # Removed device parameter
         )
 
     def generate(self, prompt: str) -> str:
-        outputs = self.pipeline(
-            prompt,
-            max_new_tokens=256,
-            do_sample=True,
-            temperature=0.7
-        )
-        return outputs[0]['generated_text']
+        try:
+            outputs = self.pipeline(
+                prompt,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.7
+            )
+            return outputs[0]['generated_text']
+        except Exception as e:
+            return f"Error generating response: {str(e)}"
app.py CHANGED
@@ -6,13 +6,21 @@ from utils.gaia_api import GaiaAPI
 
 # Initialize components
 try:
-    from agent.local_llm import LocalLLM
     llm = LocalLLM()
     agent = ReActAgent.from_tools(gaia_tools, llm=llm.pipeline)
 except Exception as e:
-    print(f"Failed to initialize LLM: {str(e)}")
-    # Fallback to a simpler agent if needed
+    print(f"Agent initialization failed: {str(e)}")
     agent = None
+
+def process_question(question_text: str) -> str:
+    if not agent:
+        return "Agent initialization failed - please check logs"
+    try:
+        response = agent.query(question_text)
+        return str(response)
+    except Exception as e:
+        return f"Error processing question: {str(e)}"
+
 def process_question(question_text: str) -> str:
     """Process GAIA question through agent"""
     try:
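A quick, hypothetical usage sketch of the error handling added to app.py (not part of the commit): process_question now returns a readable error string instead of raising, both when the agent failed to initialize and when a query fails.

# Hypothetical illustration; assumes app.py is importable as the module `app`.
from app import process_question

answer = process_question("What is the capital of France?")
print(answer)  # the agent's answer, or an error message if initialization/query failed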