Ibraaheem committed on
Commit 42247c8 · 1 Parent(s): 420f651

Update private_gpt/components/llm/llm_component.py

private_gpt/components/llm/llm_component.py CHANGED
@@ -60,16 +60,48 @@ class LLMComponent:
            case "mock":
                self.llm = MockLLM()

+
+            case "dynamic":
+                from llama_index.llms import OpenAI
+                openai_settings = settings.openai.api_key
+
+                #default startup
+                logger.info("Initializing the GPT Model in=%s", "gpt-3.5-turbo")
+                self.llm = OpenAI(model="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"))
+
+
    @inject
    def switch_model(self, new_model: str, settings: Settings) -> None:
+        from llama_index.llms import LlamaCPP
        openai_settings = settings.openai.api_key
-        if type(self.llm) == OpenAI:
-            if new_model == "gpt-3.5-turbo":
-                self.llm = OpenAI(model="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"))
-            elif new_model == "gpt-4":
-                # Initialize with the new model
-                self.llm = OpenAI(model="gpt-4", api_key=os.environ.get("OPENAI_API_KEY"))
-                logger.info("Initializing the GPT Model in=%s", "gpt-4")
+
+        if new_model == "gpt-3.5-turbo":
+            self.llm = OpenAI(model="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"))
+
+        elif new_model == "gpt-4":
+            # Initialize with the new model
+            self.llm = OpenAI(model="gpt-4", api_key=os.environ.get("OPENAI_API_KEY"))
+            logger.info("Initializing the GPT Model in=%s", "gpt-4")
+
+
+        elif new_model == "mistral-7B":
+            prompt_style_cls = get_prompt_style(settings.local.prompt_style)
+            prompt_style = prompt_style_cls(
+                default_system_prompt=settings.local.default_system_prompt
+            )
+            self.llm = LlamaCPP(
+                model_path=str(models_path / settings.local.llm_hf_model_file),
+                #model_url= "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf?download=true",
+                temperature=0.1,
+                max_new_tokens=settings.llm.max_new_tokens,
+                context_window=3900,
+                generate_kwargs={},
+                model_kwargs={"n_gpu_layers": -1},
+                messages_to_prompt=prompt_style.messages_to_prompt,
+                completion_to_prompt=prompt_style.completion_to_prompt,
+                verbose=True,
+            )
+

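The commit adds switch_model() as a public method on LLMComponent but shows no caller. Below is a minimal sketch of how a caller might drive it, assuming the usual private_gpt module layout (Settings in private_gpt.settings.settings) and an already-constructed LLMComponent; the helper function name here is hypothetical and not part of this commit.

# Hypothetical caller, not part of this commit: swaps the active backend before serving a request.
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.settings.settings import Settings  # module path assumed from the repo layout

def answer_with_model(component: LLMComponent, settings: Settings, requested_model: str, prompt: str) -> str:
    # switch_model() only reacts to "gpt-3.5-turbo", "gpt-4" and "mistral-7B";
    # any other name leaves component.llm unchanged.
    component.switch_model(requested_model, settings)
    # component.llm is a llama_index LLM (OpenAI, LlamaCPP or MockLLM), so complete() is available.
    return str(component.llm.complete(prompt))

Note that in the committed code, OpenAI is imported only inside the "dynamic" startup branch and models_path / get_prompt_style are referenced inside switch_model(), so the caller above assumes those names are importable in llm_component.py at runtime.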