# app.py - "Groq Chat with LLaMA3.1" Streamlit app
# (Hugging Face Space: dromerosm/groq-llama3, ~4.5 kB)

import os
from dotenv import find_dotenv, load_dotenv
import streamlit as st
from typing import Generator
from groq import Groq

_ = load_dotenv(find_dotenv())
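# load_dotenv reads a local .env file (if present) into the environment; it is
# expected to provide the API key used below, e.g. (placeholder value):
#   GROQ_API_KEY=gsk_your_key_here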
st.set_page_config(page_icon="πŸ“ƒ", layout="wide", page_title="Groq & LLaMA3.1 Chat Bot...")


def icon(emoji: str):
    """Shows an emoji as a Notion-style page icon."""
    st.write(
        f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
        unsafe_allow_html=True,
    )


# icon("⚑️")

st.subheader("Groq Chat with LLaMA3.1 App", divider="rainbow", anchor=False)

client = Groq(
    api_key=os.environ['GROQ_API_KEY'],
)
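
# For reference, a one-off non-streaming call against this client would look
# roughly like the sketch below (not executed by this app; the model id is an
# example, check Groq's docs for currently available ids):
#
#   resp = client.chat.completions.create(
#       model="llama-3.1-8b-instant",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)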

# Initialize chat history and selected model
if "messages" not in st.session_state:
    st.session_state.messages = []

if "selected_model" not in st.session_state:
    st.session_state.selected_model = None

# Define model details
models = {
    "llama-3.1-70b-versatile": {"name": "LLaMA3.1-70b", "tokens": 4096, "developer": "Meta"},
    "llama-3.1-8b-instant": {"name": "LLaMA3.1-8b", "tokens": 4096, "developer": "Meta"},
    "llama3-70b-8192": {"name": "Meta Llama 3 70B", "tokens": 4096, "developer": "Meta"},
    "llama3-8b-8192": {"name": "Meta Llama 3 8B", "tokens": 4096, "developer": "Meta"},
    "llama3-groq-70b-8192-tool-use-preview": {"name": "Llama 3 Groq 70B Tool Use (Preview)", "tokens": 4096, "developer": "Groq"},
    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 4096, "developer": "Google"},
    "mixtral-8x7b-32768": {
        "name": "Mixtral-8x7b-Instruct-v0.1",
        "tokens": 32768,
        "developer": "Mistral",
    },
}
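
# To expose another Groq-hosted model, add an entry keyed by its API model id,
# e.g. (hypothetical entry; check Groq's model list for current ids and limits):
#   "gemma2-9b-it": {"name": "Gemma2-9b-it", "tokens": 4096, "developer": "Google"},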

# Layout for model selection and max_tokens slider
col1, col2 = st.columns([1, 3])  # Adjust the ratio to make the first column smaller


with col1:
    model_option = st.selectbox(
        "Choose a model:",
        options=list(models.keys()),
        format_func=lambda x: models[x]["name"],
        index=0,  # Default to the first model in the list
    )
    max_tokens_range = models[model_option]["tokens"]
    max_tokens = st.slider(
        "Max Tokens:",
        min_value=512,
        max_value=max_tokens_range,
        value=max_tokens_range,  # default to the selected model's maximum
        step=512,
        help=f"Adjust the maximum number of tokens in the model's response. Max for the selected model: {max_tokens_range}",
    )

# Detect model change and clear chat history if model has changed
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.selected_model = model_option

# Add a "Clear Chat" button
if st.button("Clear Chat"):
    st.session_state.messages = []
    
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    avatar = "πŸ”‹" if message["role"] == "assistant" else "πŸ§‘β€πŸ’»"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])


def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
    """Yield chat response content from the Groq API response."""
    for chunk in chat_completion:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

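# st.write_stream (used below) consumes this generator, rendering each chunk
# as it arrives and returning the concatenated text once the stream ends.
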

if prompt := st.chat_input("Enter your prompt here..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user", avatar="πŸ§‘β€πŸ’»"):  
        st.markdown(prompt)

    # Fetch response from Groq API. full_response is initialised first so the
    # history-append below is skipped cleanly if the request raises.
    full_response = None
    try:
        chat_completion = client.chat.completions.create(
            model=model_option,
            messages=[
                {"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages
            ],
            max_tokens=max_tokens,
            stream=True,
        )

        # Use the generator function with st.write_stream
        with st.chat_message("assistant", avatar="πŸ”‹"):
            chat_responses_generator = generate_chat_responses(chat_completion)
            full_response = st.write_stream(chat_responses_generator)
    except Exception as e:
        st.error(e, icon="❌")

    # Append the full response to session_state.messages
    if isinstance(full_response, str):
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
    elif full_response is not None:
        # st.write_stream may return a list of rendered chunks; join them
        combined_response = "\n".join(str(item) for item in full_response)
        st.session_state.messages.append(
            {"role": "assistant", "content": combined_response}
        )
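
# A minimal way to run this app locally (assuming the dependencies are
# installed and GROQ_API_KEY is set in .env):
#
#   pip install streamlit groq python-dotenv
#   streamlit run app.py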