import os
import logging
import streamlit as st

# Ensure third-party dependencies are present. Note that these os.system calls
# run on every Streamlit rerun; in a real deployment the packages belong in
# requirements.txt instead.
os.system("pip install --upgrade pip")
os.system("pip install streamlit llama-cpp-agent huggingface_hub trafilatura beautifulsoup4 requests duckduckgo-search googlesearch-python")

logging.basicConfig(level=logging.INFO)

# Attempt to import all required modules; fall back to lightweight mocks for
# the optional local helpers (utils, settings) if they are missing.
try:
    from llama_cpp import Llama
    from llama_cpp_agent.providers import LlamaCppPythonProvider
    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
    from llama_cpp_agent.chat_history import BasicChatHistory
    from llama_cpp_agent.chat_history.messages import Roles
    from llama_cpp_agent.llm_output_settings import (
        LlmStructuredOutputSettings,
        LlmStructuredOutputType,
    )
    from llama_cpp_agent.tools import WebSearchTool
    from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
    from utils import CitingSources
    from settings import get_context_by_model, get_messages_formatter_type
except ImportError as e:
    st.error(f"Error importing modules: {e}")
    # These fallbacks only cover the local helper modules; they assume the
    # llama_cpp_agent imports above succeeded.
    if 'utils' in str(e):
        st.warning("Mocking utils.CitingSources")

        class CitingSources:
            sources = []

    if 'settings' in str(e):
        st.warning("Mocking settings functions")

        def get_context_by_model(model):
            return 4096

        def get_messages_formatter_type(model):
            # MISTRAL serves as a generic default formatter here.
            return MessagesFormatterType.MISTRAL

from huggingface_hub import hf_hub_download

# Download the GGUF model weights into ./models (files that are already
# present are not re-downloaded). Only the Mixtral file is used below.
hf_hub_download(
    repo_id="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
    filename="Mistral-7B-Instruct-v0.3-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
    filename="Meta-Llama-3-8B-Instruct-Q6_K.gguf",
    local_dir="./models"
)
hf_hub_download(
    repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
    filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf",
    local_dir="./models"
)

# Generate a streaming response. A web-search agent first gathers results,
# then a research agent writes the answer and finally cites its sources.
# Note: system_message is accepted from the UI but not currently used; the
# agents rely on the library's own prompt templates.
def respond(message, history, system_message, temperature, top_p, top_k, repeat_penalty):
    model = "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf"
    max_tokens = 3000
    chat_template = get_messages_formatter_type(model)
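    # Load the model with flash attention and full GPU offload
    # (n_gpu_layers=81 exceeds the model's layer count, so every layer is
    # placed on the GPU).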
    llm = Llama(
        model_path=f"models/{model}",
        flash_attn=True,
        n_gpu_layers=81,
        n_batch=1024,
        n_ctx=get_context_by_model(model),
    )
    provider = LlamaCppPythonProvider(llm)
    logging.info(f"Using chat formatter: {chat_template}")
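    # Web search tool; the token limits bound the size of the collected
    # search results and of each per-page summary.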
    search_tool = WebSearchTool(
        llm_provider=provider,
        message_formatter_type=chat_template,
        max_tokens_search_results=12000,
        max_tokens_per_summary=2048,
    )

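    # Two agents share the same provider: one performs the search, the other
    # writes the final research document.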
    web_search_agent = LlamaCppAgent(
        provider,
        system_prompt=web_search_system_prompt,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )

    answer_agent = LlamaCppAgent(
        provider,
        system_prompt=research_system_prompt,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )

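    # Sampling settings from the UI sliders; streaming stays off for the
    # tool-call pass below.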
    settings = provider.get_provider_default_settings()
    settings.stream = False
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty

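    # Constrain the first response to a structured call of the search tool.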
    output_settings = LlmStructuredOutputSettings.from_functions(
        [search_tool.get_tool()]
    )

    messages = BasicChatHistory()

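    # Replay prior (user, assistant) turns into the agent-side chat history.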
    for user_msg, assistant_msg in history:
        messages.add_message({"role": Roles.user, "content": user_msg})
        messages.add_message({"role": Roles.assistant, "content": assistant_msg})

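    # Pass 1: the search agent issues the tool call; its findings come back
    # in result[0]["return_value"].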
    result = web_search_agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        structured_output_settings=output_settings,
        add_message_to_chat_history=False,
        add_response_to_chat_history=False,
        print_output=False,
    )

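    # Pass 2: stream a research document grounded in the retrieved results.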
    outputs = ""

    settings.stream = True
    response_text = answer_agent.get_chat_response(
        f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
        result[0]["return_value"],
        role=Roles.tool,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False,
    )

    for text in response_text:
        outputs += text
        yield outputs

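    # Pass 3: request a structured CitingSources object and append its
    # sources to the output.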
    output_settings = LlmStructuredOutputSettings.from_pydantic_models(
        [CitingSources], LlmStructuredOutputType.object_instance
    )

    citing_sources = answer_agent.get_chat_response(
        "Cite the sources you used in your response.",
        role=Roles.tool,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=False,
        structured_output_settings=output_settings,
        print_output=False,
    )
    outputs += "\n\nSources:\n"
    outputs += "\n".join(citing_sources.sources)
    yield outputs

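# --- Streamlit UI ---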
st.title("Novav2 Web Engine")

message = st.text_input("Enter your message:")
history = st.session_state.get("history", [])
system_message = st.text_area("System message", value=web_search_system_prompt)
temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.45, step=0.1)
top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.95, step=0.05)
top_k = st.slider("Top-k", min_value=0, max_value=100, value=40, step=1)
repeat_penalty = st.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.1, step=0.1)

if st.button("Send"):
    response_generator = respond(message, history, system_message, temperature, top_p, top_k, repeat_penalty)
    # Stream into a single placeholder rather than writing a new element per
    # chunk, and record only the final response in the chat history.
    placeholder = st.empty()
    response = ""
    for response in response_generator:
        placeholder.write(response)
    history.append((message, response))
    st.session_state["history"] = history