Update llm_handler.py
llm_handler.py (+19, -43)
@@ -2,15 +2,18 @@ import requests
 import json
 from openai import OpenAI
 from params import OPENAI_MODEL, OPENAI_API_KEY
-from llama_cpp import Llama
-from llama_cpp_agent import LlamaCppAgent
-from llama_cpp_agent import MessagesFormatterType
-from llama_cpp_agent.providers import LlamaCppPythonProvider
+import llamanet
 
 # Add this at the top of the file
 local_model_base_url = "http://localhost:11434/v1"
 anything_llm_workspace = "<input-workspace-name-here>"
 
+# Initialize llamanet
+llamanet.run()
+
+# Create an instance of the OpenAI class
+client = OpenAI()
+
 def set_local_model_base_url(url):
     global local_model_base_url
     local_model_base_url = url
@@ -19,35 +22,22 @@ def set_anything_llm_workspace(workspace):
     global anything_llm_workspace
     anything_llm_workspace = workspace
 
-# Create an instance of the OpenAI class for the local model
-client = OpenAI(api_key="local-model", base_url=local_model_base_url)
-
-# Initialize LlamaCpp model and agent
-llama_model = Llama("Arcee-Spark-GGUF/Arcee-Spark-Q4_K_M.gguf", n_batch=1024, n_threads=24, n_gpu_layers=33, n_ctx=4098, verbose=False)
-provider = LlamaCppPythonProvider(llama_model)
-llama_agent = LlamaCppAgent(
-    provider,
-    system_prompt="You are a helpful assistant.",
-    predefined_messages_formatter_type=MessagesFormatterType.MISTRAL,
-    debug_output=True
-)
-
-# Configure provider settings
-settings = provider.get_provider_default_settings()
-settings.max_tokens = 2000
-settings.stream = True
-
 def send_to_chatgpt(msg_list):
     try:
-        # Update the send_to_chatgpt function to use the dynamic base_url
-        client = OpenAI(api_key="local-model", base_url=local_model_base_url)
         completion = client.chat.completions.create(
-            model=OPENAI_MODEL,
+            model='https://huggingface.co/arcee-ai/Arcee-Spark-GGUF/blob/main/Arcee-Spark-IQ4_XS.gguf', # This will use the llamanet model
+            messages=msg_list,
             temperature=0.6,
-            messages=msg_list
+            stream=True
         )
-        chatgpt_response = completion.choices[0].message.content
-        chatgpt_usage = completion.usage
+
+        chatgpt_response = ""
+        for chunk in completion:
+            if chunk.choices[0].delta.content is not None:
+                chatgpt_response += chunk.choices[0].delta.content
+
+        # Note: Usage information might not be available with llamanet
+        chatgpt_usage = None
         return chatgpt_response, chatgpt_usage
     except Exception as e:
         print(f"Error in send_to_chatgpt: {str(e)}")
@@ -77,24 +67,10 @@ def send_to_anything_llm(msg_list):
         print(f"Error in send_to_anything_llm: {str(e)}")
         return f"Error: {str(e)}", None
 
-def send_to_llamacpp(msg_list):
-    try:
-        # Convert the message list to the format expected by LlamaCppAgent
-        formatted_messages = [{"role": msg["role"], "content": msg["content"]} for msg in msg_list]
-        response = llama_agent(formatted_messages, settings=settings)
-        chatgpt_response = response.message.content
-        chatgpt_usage = {"prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens, "total_tokens": response.usage.total_tokens}
-        return chatgpt_response, chatgpt_usage
-    except Exception as e:
-        print(f"Error in send_to_llamacpp: {str(e)}")
-        return f"Error: {str(e)}", None
-
 def send_to_llm(provider, msg_list):
     if provider == "local-model":
         return send_to_chatgpt(msg_list)
     elif provider == "anything-llm":
        return send_to_anything_llm(msg_list)
-    elif provider == "llamacpp":
-        return send_to_llamacpp(msg_list)
     else:
-        raise ValueError(f"Unknown provider: {provider}")
+        raise ValueError(f"Unknown provider: {provider}")