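"""local_llm.py

Wraps a small Hugging Face causal language model behind a LlamaIndex-compatible
HuggingFaceLLM. Tries microsoft/DialoGPT-small first, falls back to gpt2, and
finally to a mock LLM so the rest of the application can still run.
"""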
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from llama_index.llms.huggingface import HuggingFaceLLM
    import torch
except ImportError as e:
    print(f"Import error in local_llm.py: {e}")
    raise

class LocalLLM:
    def __init__(self):
        # Default to a small conversational model that loads quickly on CPU or GPU
        self.model_name = "microsoft/DialoGPT-small"
        print(f"Initializing LocalLLM with model: {self.model_name}")
        self.llm = self._create_llama_index_llm()
    
    def _create_llama_index_llm(self):
        """Create LlamaIndex compatible LLM"""
        try:
            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            print("Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                low_cpu_mem_usage=True
            )
            
            print("Creating LlamaIndex LLM...")
            # Fix the generate_kwargs to avoid conflicts
            llm = HuggingFaceLLM(
                model=model,
                tokenizer=tokenizer,
                # Simplified generate_kwargs to avoid conflicts
                generate_kwargs={
                    "do_sample": True,
                    "temperature": 0.7,
                    "pad_token_id": tokenizer.eos_token_id
                },
                # Set these parameters at the LLM level instead
                max_new_tokens=256,
                device_map="auto" if torch.cuda.is_available() else None
            )
            
            print("LLM created successfully!")
            return llm
            
        except Exception as e:
            print(f"Failed to load model {self.model_name}: {str(e)}")
            # Fallback to even simpler model
            return self._create_fallback_llm()
    
    def _create_fallback_llm(self):
        """Fallback to a very basic model"""
        print("Using fallback model: gpt2")
        model_name = "gpt2"
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            tokenizer.pad_token = tokenizer.eos_token
            
            model = AutoModelForCausalLM.from_pretrained(model_name)
            
            return HuggingFaceLLM(
                model=model,
                tokenizer=tokenizer,
                generate_kwargs={
                    "do_sample": True,
                    "temperature": 0.7,
                    "pad_token_id": tokenizer.eos_token_id
                },
                max_new_tokens=256,
                context_window=1024  # gpt2 also has a 1024-token position limit
            )
        except Exception as e:
            print(f"Even fallback model failed: {str(e)}")
            # Return a mock LLM for testing
            return self._create_mock_llm()
    
    def _create_mock_llm(self):
        """Create a mock LLM for testing when models fail"""
        print("Creating mock LLM for testing...")
        
        from types import SimpleNamespace

        class MockLLM:
            """Minimal stand-in exposing the chat()/complete() shapes used by callers."""

            def chat(self, messages, **kwargs):
                # Mimic a chat response: an object with .message.content
                return SimpleNamespace(
                    message=SimpleNamespace(
                        content="This is a mock response. The actual LLM failed to load."
                    )
                )

            def complete(self, prompt, **kwargs):
                # Mimic a completion response: an object with .text
                return SimpleNamespace(text="Mock completion response.")

        return MockLLM()
    
    def get_llm(self):
        """Return the LlamaIndex LLM instance"""
        return self.llm
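

# --- Usage sketch ---
# A minimal smoke test, assuming this module is run directly; the prompt below is
# illustrative only. In the wider application the instance would typically be
# handed to LlamaIndex instead (e.g. Settings.llm = LocalLLM().get_llm()).
if __name__ == "__main__":
    local_llm = LocalLLM()
    llm = local_llm.get_llm()

    # complete() exists on both HuggingFaceLLM and the mock fallback; the real LLM
    # returns a CompletionResponse with .text, and the mock mirrors that shape.
    response = llm.complete("Hello, how are you?")
    print(getattr(response, "text", response))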