|
import html
import json
import os
import random
import re
import time
import uuid
from datetime import datetime
from os.path import join

import pandas as pd
import streamlit as st
from datasets import Dataset, get_dataset_config_info, load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from PIL import Image
from streamlit_feedback import streamlit_feedback

from src import (
    ask_question,
    get_from_user,
    preprocess_and_load_df,
)
|
|
|
|
|
# Page-level settings. Streamlit expects this before any other st.* call,
# which is why it sits directly under the imports.
st.set_page_config(
    page_title="VayuChat - AI Air Quality Assistant",
    page_icon="V",
    initial_sidebar_state="expanded",
    layout="wide",
)
|
|
|
|
|
# Global stylesheet for the app, injected as raw HTML on every run.
# - .stApp, [data-testid="stSidebar"], .stButton and .stChatInput target
#   Streamlit's own DOM; the "Hide default menu and footer" rules hide
#   Streamlit's main menu, footer and header bar.
# - .user-message, .assistant-message, .assistant-info, .processing-indicator
#   and .feedback-section are used by the chat rendering code in this file.
# NOTE(review): .main-title, .subtitle, .instructions, .quick-prompt-*,
# .user-info, .success-message, .error-message, .code-*, .toggle-text,
# .answer-*, .context-info and .main-container are not referenced anywhere in
# this script. They may be used by HTML produced elsewhere (e.g. model output);
# confirm before removing.
st.markdown("""
<style>
/* Clean app background */
.stApp {
    background-color: #ffffff;
    color: #212529;
    font-family: 'Segoe UI', sans-serif;
}

/* Sidebar */
[data-testid="stSidebar"] {
    background-color: #f8f9fa;
    border-right: 1px solid #dee2e6;
    padding: 1rem;
}

/* Main title */
.main-title {
    text-align: center;
    color: #343a40;
    font-size: 2.5rem;
    font-weight: 700;
    margin-bottom: 0.5rem;
}

/* Subtitle */
.subtitle {
    text-align: center;
    color: #6c757d;
    font-size: 1.1rem;
    margin-bottom: 1.5rem;
}

/* Instructions */
.instructions {
    background-color: #f1f3f5;
    border-left: 4px solid #0d6efd;
    padding: 1rem;
    margin-bottom: 1.5rem;
    border-radius: 6px;
    color: #495057;
    text-align: left;
}

/* Quick prompt buttons */
.quick-prompt-container {
    display: flex;
    flex-wrap: wrap;
    gap: 8px;
    margin-bottom: 1.5rem;
    padding: 1rem;
    background-color: #f8f9fa;
    border-radius: 10px;
    border: 1px solid #dee2e6;
}

.quick-prompt-btn {
    background-color: #0d6efd;
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 20px;
    font-size: 0.9rem;
    cursor: pointer;
    transition: all 0.2s ease;
    white-space: nowrap;
}

.quick-prompt-btn:hover {
    background-color: #0b5ed7;
    transform: translateY(-2px);
}

/* User message styling */
.user-message {
    background: #3b82f6;
    color: white;
    padding: 0.75rem 1rem;
    border-radius: 12px;
    max-width: 70%;
}

.user-info {
    font-size: 0.875rem;
    opacity: 0.9;
    margin-bottom: 3px;
}

/* Assistant message styling */
.assistant-message {
    background: #f1f5f9;
    color: #334155;
    padding: 0.75rem 1rem;
    border-radius: 12px;
    max-width: 70%;
}

.assistant-info {
    font-size: 0.875rem;
    color: #6b7280;
    margin-bottom: 5px;
}

/* Processing indicator */
.processing-indicator {
    background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
    color: #333;
    padding: 15px 20px;
    border-radius: 20px 20px 20px 5px;
    margin: 10px 0;
    margin-left: 0;
    margin-right: auto;
    max-width: 80%;
    position: relative;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
    animation: pulse 2s infinite;
}

@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.7; }
    100% { opacity: 1; }
}

/* Feedback box */
.feedback-section {
    background-color: #f8f9fa;
    border: 1px solid #dee2e6;
    padding: 1rem;
    border-radius: 8px;
    margin: 1rem 0;
}

/* Success and error messages */
.success-message {
    background-color: #d1e7dd;
    color: #0f5132;
    padding: 1rem;
    border-radius: 6px;
    border: 1px solid #badbcc;
}

.error-message {
    background-color: #f8d7da;
    color: #842029;
    padding: 1rem;
    border-radius: 6px;
    border: 1px solid #f5c2c7;
}

/* Chat input */
.stChatInput {
    border-radius: 6px;
    border: 1px solid #ced4da;
    background: #ffffff;
}

/* Button */
.stButton > button {
    background-color: #0d6efd;
    color: white;
    border-radius: 6px;
    padding: 0.5rem 1.25rem;
    border: none;
    font-weight: 600;
    transition: background-color 0.2s ease;
}

.stButton > button:hover {
    background-color: #0b5ed7;
}

/* Sidebar button styling */
[data-testid="stSidebar"] .stButton > button {
    background-color: #f8fafc;
    color: #475569;
    border: 1px solid #e2e8f0;
    padding: 0.375rem 0.75rem;
    font-size: 0.75rem;
    font-weight: normal;
    text-align: left;
}

[data-testid="stSidebar"] .stButton > button:hover {
    background-color: #e0f2fe;
    border-color: #0ea5e9;
    color: #0c4a6e;
}

/* Code container styling */
.code-container {
    margin: 1rem 0;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    background: #f8fafc;
}

.code-header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 0.75rem 1rem;
    background: #f1f5f9;
    border-bottom: 1px solid #e2e8f0;
    cursor: pointer;
    transition: background-color 0.2s;
}

.code-header:hover {
    background: #e2e8f0;
}

.code-title {
    font-size: 0.875rem;
    font-weight: 500;
    color: #374151;
}

.toggle-text {
    font-size: 0.75rem;
    color: #6b7280;
}

.code-block {
    background: #1e293b;
    color: #e2e8f0;
    padding: 1rem;
    font-family: 'Monaco', 'Menlo', monospace;
    font-size: 0.875rem;
    overflow-x: auto;
    line-height: 1.5;
}

.answer-container {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 1.5rem;
    margin: 1rem 0;
}

.answer-text {
    font-size: 1.125rem;
    color: #1e293b;
    line-height: 1.6;
    margin-bottom: 1rem;
}

.answer-highlight {
    background: #fef3c7;
    padding: 0.125rem 0.375rem;
    border-radius: 4px;
    font-weight: 600;
    color: #92400e;
}

.context-info {
    background: #f1f5f9;
    border-left: 4px solid #3b82f6;
    padding: 0.75rem 1rem;
    margin: 1rem 0;
    font-size: 0.875rem;
    color: #475569;
}

/* Hide default menu and footer */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}

/* Auto scroll */
.main-container {
    height: 70vh;
    overflow-y: auto;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
# Inject helper JavaScript: scrollToBottom() scrolls .main-container and the
# window to the bottom; toggleCode(header) shows/hides the element right after
# a clicked code header.
# NOTE(review): <script> elements rendered through st.markdown are most likely
# never executed by the browser (HTML inserted this way does not run scripts),
# so these functions are probably inert. In addition, no markup in this script
# carries the .main-container / .code-header classes they look for. If the
# behaviour is needed, consider streamlit.components.v1.html (runs in an
# iframe) — confirm before relying on this block.
st.markdown("""
<script>
function scrollToBottom() {
    setTimeout(function() {
        const mainContainer = document.querySelector('.main-container');
        if (mainContainer) {
            mainContainer.scrollTop = mainContainer.scrollHeight;
        }
        window.scrollTo(0, document.body.scrollHeight);
    }, 100);
}

function toggleCode(header) {
    const codeBlock = header.nextElementSibling;
    const toggleText = header.querySelector('.toggle-text');

    if (codeBlock.style.display === 'none') {
        codeBlock.style.display = 'block';
        toggleText.textContent = 'Click to collapse';
    } else {
        codeBlock.style.display = 'none';
        toggleText.textContent = 'Click to expand';
    }
}
</script>
""", unsafe_allow_html=True)
|
|
|
|
|
# Read credentials from a local .env file. override=True makes values from the
# file take precedence over variables already set in the process environment.
load_dotenv(override=True)

# API credentials; any of these may be None when not configured.
Groq_Token = os.environ.get("GROQ_API_KEY")
hf_token = os.environ.get("HF_TOKEN")
gemini_token = os.environ.get("GEMINI_TOKEN")

# Display name -> provider model identifier. Insertion order is the order the
# models are offered in the selector. Keys containing "gemini" are only offered
# when GEMINI_TOKEN is set; all other keys require GROQ_API_KEY (see the
# availability check further down).
models = {
    # Served via Groq
    "gpt-oss-20b": "openai/gpt-oss-20b",
    "gpt-oss-120b": "openai/gpt-oss-120b",
    "llama3.1": "llama-3.1-8b-instant",
    "llama3.3": "llama-3.3-70b-versatile",
    "deepseek-R1": "deepseek-r1-distill-llama-70b",
    "llama4 maverik": "meta-llama/llama-4-maverick-17b-128e-instruct",
    "llama4 scout": "meta-llama/llama-4-scout-17b-16e-instruct",
    # Served via Google
    "gemini-pro": "gemini-1.5-pro",
}

# Directory containing this script; data files are resolved relative to it.
self_path = os.path.split(os.path.abspath(__file__))[0]

# One id per browser session; used to tag uploaded feedback reports.
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
|
|
|
def upload_feedback(feedback, error, output, last_prompt, code, status):
    """Upload a feedback report for one assistant response to the Hugging Face Hub.

    A markdown report is uploaded as ``data/feedback_<YYYYmmdd_HHMMSS>_<id>.md``
    to the ``SustainabilityLabIITGN/VayuChat_Feedback`` dataset repo. When the
    response was an image and ``output`` is an existing file path, the image is
    uploaded next to it as ``..._plot.png``; that second upload is best effort
    and failures are only printed.

    Args:
        feedback: Dict with optional ``"score"`` and ``"text"`` entries.
        error: Error message attached to the response, or a falsy value.
        output: The response content (a file path for image responses).
        last_prompt: The user prompt that produced ``output``.
        code: Generated code for the response, or a falsy value.
        status: Dict from ``show_custom_response``; only ``"is_image"`` is
            read. ``None`` is treated as an empty dict.

    Returns:
        True if the report was uploaded, False otherwise. The outcome is also
        shown to the user via ``st.success`` / ``st.warning`` / ``st.error``.
    """
    if not hf_token or not hf_token.strip():
        st.warning("Cannot upload feedback - HF_TOKEN not available")
        return False

    status = status or {}
    repo_id = "SustainabilityLabIITGN/VayuChat_Feedback"

    try:
        # Read the clock once so the report timestamp and the file name agree.
        now = datetime.now()

        feedback_data = {
            "timestamp": now.isoformat(),
            "session_id": st.session_state.session_id,
            "feedback_score": feedback.get("score", ""),
            "feedback_comment": feedback.get("text", ""),
            "user_prompt": last_prompt,
            "ai_output": str(output),
            "generated_code": code or "",
            "error_message": error or "",
            "is_image_output": status.get("is_image", False),
            "success": not bool(error),
        }

        # Unique, sortable base name shared by the report and its plot.
        folder_name = f"feedback_{now.strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"

        markdown_content = f"""# VayuChat Feedback Report

## Session Information
- **Timestamp**: {feedback_data['timestamp']}
- **Session ID**: {feedback_data['session_id']}

## User Interaction
**Prompt**: {feedback_data['user_prompt']}

## AI Response
**Output**: {feedback_data['ai_output']}

## Generated Code
```python
{feedback_data['generated_code']}
```

## Technical Details
- **Error Message**: {feedback_data['error_message']}
- **Is Image Output**: {feedback_data['is_image_output']}
- **Success**: {feedback_data['success']}

## User Feedback
- **Score**: {feedback_data['feedback_score']}
- **Comments**: {feedback_data['feedback_comment']}
"""

        api = HfApi(token=hf_token)

        # Upload the report straight from memory. There is no temp file to
        # leak when the upload raises, and no dependency on a POSIX /tmp.
        api.upload_file(
            path_or_fileobj=markdown_content.encode("utf-8"),
            path_in_repo=f"data/{folder_name}.md",
            repo_id=repo_id,
            repo_type="dataset",
        )

        # The plot is extra context only; failing to upload it must not fail
        # the whole feedback submission.
        if status.get("is_image", False) and isinstance(output, str) and os.path.exists(output):
            try:
                api.upload_file(
                    path_or_fileobj=output,
                    path_in_repo=f"data/{folder_name}_plot.png",
                    repo_id=repo_id,
                    repo_type="dataset",
                )
            except Exception as img_error:
                print(f"Error uploading image: {img_error}")

        st.success("Feedback uploaded successfully!")
        return True

    except Exception as e:
        st.error(f"Error uploading feedback: {e}")
        print(f"Feedback upload error: {e}")
        return False
|
|
|
|
|
# Partition the configured models by the credential they need: keys that
# mention "gemini" need GEMINI_TOKEN, everything else needs GROQ_API_KEY.
model_names = list(models.keys())
groq_models = [name for name in model_names if "gemini" not in name]
gemini_models = [name for name in model_names if "gemini" in name]

# Only offer models whose credential is present and non-blank.
available_models = []
if Groq_Token and Groq_Token.strip():
    available_models.extend(groq_models)
if gemini_token and gemini_token.strip():
    available_models.extend(gemini_models)

if not available_models:
    st.error("No API keys available! Please set up your API keys in the .env file")
    st.stop()

# Preselect DeepSeek-R1 when it is on offer, otherwise the first model.
default_index = (
    available_models.index("deepseek-R1") if "deepseek-R1" in available_models else 0
)
|
|
|
|
|
header_col1, header_col2 = st.columns([2, 1])

# Left column: logo badge plus app name and tagline.
header_col1.markdown("""
<div style='display: flex; align-items: center; gap: 0.75rem; margin-bottom: 0.5rem;'>
    <div style='width: 28px; height: 28px; background: #3b82f6; border-radius: 6px; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 0.875rem;'>V</div>
    <div>
        <h1 style='margin: 0; font-size: 1.125rem; font-weight: 600; color: #1e293b;'>VayuChat</h1>
        <p style='margin: 0; font-size: 0.75rem; color: #64748b;'>Environmental Data Analysis</p>
    </div>
</div>
""", unsafe_allow_html=True)

# Right column: model picker. Its own label is collapsed; the small HTML
# caption above serves as the visible label.
header_col2.markdown("<p style='margin: 0 0 0.25rem 0; font-size: 0.75rem; color: #6b7280;'>AI Model:</p>", unsafe_allow_html=True)
model_name = header_col2.selectbox(
    "Model:",
    available_models,
    index=default_index,
    help="Choose your AI model",
    label_visibility="collapsed",
)

# Thin divider between the header and the chat area.
st.markdown("<hr style='margin: 0.5rem 0; border: none; border-top: 1px solid #e2e8f0;'>", unsafe_allow_html=True)
|
|
|
|
|
|
|
# The dataset ships next to this script. Without it the app has nothing to
# answer from, so any load failure halts the page.
data_path = join(self_path, "Data.csv")
try:
    df = preprocess_and_load_df(data_path)
except Exception as e:
    st.error(f"Error loading data: {e}")
    st.stop()

# NOTE(review): neither constant is referenced elsewhere in this script.
inference_server = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
image_path = "IITGN_Logo.png"
|
|
|
|
|
with st.sidebar:
    st.markdown("### Quick Queries")

    # Quick queries come from questions.txt next to this script (one per
    # line). Fall back to a built-in list if the file is missing, unreadable
    # or empty.
    questions = []
    questions_file = join(self_path, "questions.txt")
    if os.path.exists(questions_file):
        try:
            with open(questions_file, 'r', encoding='utf-8') as f:
                content = f.read()
            questions = [q.strip() for q in content.split("\n") if q.strip()]
        except Exception as e:
            # Best effort: log and use the defaults rather than failing the page.
            print(f"Could not read {questions_file}: {e}")
            questions = []

    if not questions:
        questions = [
            "Which month had highest pollution?",
            "Which city has worst air quality?",
            "Show annual PM2.5 average",
            "Compare winter vs summer pollution",
            "List all cities by pollution level",
            "Plot monthly average PM2.5 for 2023",
        ]

    # Only the first six queries get a button. Long queries are truncated on
    # the button, and the full text is shown as its tooltip.
    selected_prompt = None
    for i, question in enumerate(questions[:6]):
        display_text = question[:35] + "..." if len(question) > 35 else question
        if st.button(display_text, key=f"sidebar_prompt_{i}", help=question, use_container_width=True):
            selected_prompt = question

    st.markdown("---")

    st.markdown("### Dataset Info")
    st.markdown("""
<div style='background-color: #f1f5f9; padding: 0.75rem; border-radius: 6px; margin-bottom: 1rem;'>
    <h4 style='margin: 0 0 0.25rem 0; color: #1e293b; font-size: 0.9rem;'>PM2.5 Air Quality Data</h4>
    <p style='margin: 0.125rem 0; font-size: 0.75rem;'><strong>Locations:</strong> Gujarat cities</p>
    <p style='margin: 0.125rem 0; font-size: 0.75rem;'><strong>Parameters:</strong> PM2.5, PM10</p>
</div>
""", unsafe_allow_html=True)

    st.markdown("### Current Model")
    st.markdown(f"**{model_name}**")

    # Short blurbs for the selector keys; models without an entry get no caption.
    model_descriptions = {
        "llama3.1": "Fast and efficient for general queries",
        "llama3.3": "Most advanced LLaMA model for complex reasoning",
        "mistral": "Balanced performance and speed",
        "gemma": "Google's lightweight model",
        "gemini-pro": "Google's most powerful model",
        "gpt-oss-20b": "OpenAI's compact open-weight GPT for everyday tasks",
        "gpt-oss-120b": "OpenAI's massive open-weight GPT for nuanced responses",
        "deepseek-R1": "DeepSeek's distilled LLaMA model for efficient reasoning",
        "llama4 maverik": "Meta's LLaMA 4 Maverick — high-performance instruction model",
        "llama4 scout": "Meta's LLaMA 4 Scout — optimized for adaptive reasoning",
    }

    if model_name in model_descriptions:
        st.caption(model_descriptions[model_name])

    st.markdown("---")

    if st.button("Clear Chat", use_container_width=True):
        st.session_state.responses = []
        st.session_state.processing = False
        st.session_state.session_id = str(uuid.uuid4())
        # Also forget the previous question. Otherwise the duplicate-prompt
        # check in the main chat flow would silently drop a re-ask of the same
        # question after the chat has been cleared.
        st.session_state.pop("last_prompt", None)
        st.session_state.pop("last_model_name", None)
        try:
            st.rerun()
        except AttributeError:
            # Older Streamlit releases only provide the experimental API.
            st.experimental_rerun()
|
|
|
|
|
|
|
|
|
# Seed per-session chat state on the first run of each browser session.
for _state_key, _initial_value in {"responses": [], "processing": False}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
|
|
|
# File extensions treated as renderable images (matched case-sensitively).
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Finds a slash-free file name ending in an image extension inside free text.
_IMAGE_FILENAME_RE = re.compile(r"([^/\\]+\.(?:png|jpg|jpeg))")


def _find_image_path(content):
    """Return the path of an existing image file referenced by ``content``, else None.

    ``content`` is first tried as a path in its own right. After that, the first
    image file name embedded in the text is tried relative to the current
    working directory.
    """
    if not isinstance(content, str):
        return None
    if content.endswith(_IMAGE_EXTENSIONS) and os.path.exists(content):
        return content
    match = _IMAGE_FILENAME_RE.search(content)
    if match and os.path.exists(match.group(1)):
        return match.group(1)
    return None


def show_custom_response(response):
    """Render one chat message and report whether an image was displayed.

    User messages become a right-aligned bubble. Assistant messages become a
    left-aligned bubble, unless their content resolves to an existing image
    file, in which case the image is shown instead. Generated code, if any, is
    shown in a collapsed expander.

    Args:
        response: Chat entry dict with ``"role"`` (default ``"assistant"``),
            ``"content"`` and optionally ``"gen_code"``.

    Returns:
        ``{"is_image": True}`` if an image was rendered, else
        ``{"is_image": False}``.
    """
    role = response.get("role", "assistant")
    content = response.get("content", "")

    if role == "user":
        # User text is untrusted, so escape it before embedding it in raw HTML.
        # Newlines become <br> so a blank line cannot end the HTML block early.
        user_html = html.escape(str(content)).replace("\n", "<br>")
        st.markdown(f"""
<div style='display: flex; justify-content: flex-end; margin: 1rem 0;'>
    <div class='user-message'>
        {user_html}
    </div>
</div>
""", unsafe_allow_html=True)
        return {"is_image": False}

    if role != "assistant":
        return {"is_image": False}

    image_path = _find_image_path(content)

    # Show the text bubble unless an actual image will replace it. Content that
    # merely mentions an image name (or a missing file) still gets rendered.
    if image_path is None:
        # NOTE(review): assistant content is embedded unescaped, so HTML in the
        # model output is rendered as-is. Presumably intentional; confirm.
        assistant_text = content if isinstance(content, str) else str(content)
        st.markdown(f"""
<div style='display: flex; justify-content: flex-start; margin: 1rem 0;'>
    <div class='assistant-message'>
        <div class='assistant-info'>VayuChat</div>
        {assistant_text}
    </div>
</div>
""", unsafe_allow_html=True)

    if response.get("gen_code"):
        with st.expander("View Generated Code", expanded=False):
            st.code(response["gen_code"], language="python")

    if image_path is not None:
        try:
            st.image(image_path, use_column_width=True)
            return {"is_image": True}
        except Exception as exc:
            # Best effort: a broken image must not break the chat history.
            st.warning(f"Could not display image {image_path}: {exc}")

    return {"is_image": False}
|
|
|
def show_processing_indicator(model_name, question):
    """Render the pulsing placeholder bubble for a question that is in flight.

    Args:
        model_name: Name of the model handling the request (shown as text).
        question: The user's question. It is HTML-escaped because it is
            user-typed and embedded in raw HTML.
    """
    safe_model = html.escape(str(model_name))
    # Newlines become <br> so a blank line cannot terminate the HTML block.
    safe_question = html.escape(str(question)).replace("\n", "<br>")
    st.markdown(f"""
<div class='processing-indicator'>
    <div class='assistant-info'>VayuChat &bull; Processing with {safe_model}</div>
    <strong>Question:</strong> {safe_question}<br>
    <em>Generating response...</em>
</div>
""", unsafe_allow_html=True)
|
|
|
|
|
chat_container = st.container()

with chat_container:
    for response_id, response in enumerate(st.session_state.responses):
        status = show_custom_response(response)

        # Feedback controls only go on assistant turns. A missing role counts as
        # "assistant", matching show_custom_response.
        if response.get("role", "assistant") != "assistant":
            continue

        # Keys combine the session id (regenerated by "Clear Chat") with this
        # message's own index. They stay unique across cleared chats and even
        # when the history is not strictly user/assistant alternating.
        feedback_key = f"feedback_{st.session_state.session_id}_{response_id}"
        choice_key = f"{feedback_key}_choice"

        error = response.get("error", "")
        output = response.get("content", "")
        last_prompt = response.get("last_prompt", "")
        code = response.get("gen_code", "")

        if "feedback" in response:
            # Already submitted: show it read-only. Both parts are escaped,
            # since the comment is user-typed text.
            submitted = response["feedback"]
            score_html = html.escape(str(submitted.get("score", "")))
            comment = submitted.get("text")
            comment_html = ""
            if comment:
                comment_html = "- " + html.escape(str(comment)).replace("\n", "<br>")
            st.markdown(f"""
<div class='feedback-section'>
    <strong>Your Feedback:</strong> {score_html} {comment_html}
</div>
""", unsafe_allow_html=True)
            continue

        st.markdown("---")
        st.markdown("**How was this response?**")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("👍 Helpful", key=f"{feedback_key}_up", use_container_width=True):
                st.session_state[choice_key] = "👍 Helpful"
        with col2:
            if st.button("👎 Not Helpful", key=f"{feedback_key}_down", use_container_width=True):
                st.session_state[choice_key] = "👎 Not Helpful"

        # st.button is True only on the rerun triggered by its own click, so
        # the thumbs choice must persist in session state. Without that, the
        # "Submit Feedback" click reruns with no thumbs selected, the submit
        # button is not rendered again, and the feedback can never be sent.
        thumbs = st.session_state.get(choice_key)
        if thumbs:
            comments = st.text_area(
                "Tell us more (optional):",
                key=f"{feedback_key}_comments",
                placeholder="What could be improved? Any suggestions?",
                max_chars=500,
            )

            if st.button("Submit Feedback", key=f"{feedback_key}_submit"):
                feedback = {"score": thumbs, "text": comments}
                if upload_feedback(feedback, error, output, last_prompt, code, status or {}):
                    response["feedback"] = feedback
                    st.session_state.pop(choice_key, None)
                    # Give the success message a moment before the rerun clears it.
                    time.sleep(1)
                    st.rerun()
                else:
                    st.error("Failed to submit feedback. Please try again.")

    # While a question is in flight, show the placeholder bubble under the history.
    if st.session_state.get("processing"):
        show_processing_indicator(
            st.session_state.get("current_model", "Unknown"),
            st.session_state.get("current_question", "Processing..."),
        )
|
|
|
|
|
prompt = st.chat_input("Ask me anything about air quality!", key="main_chat")

# A sidebar quick-query click on this run overrides typed input.
if selected_prompt:
    prompt = selected_prompt

# Phase one of a question: record the user turn, flag the session as busy and
# rerun. The next run then shows the history plus the processing indicator
# before the model is actually called.
if prompt and not st.session_state.get("processing"):
    # Ignore an exact repeat of the previous question to the same model.
    is_repeat = "last_prompt" in st.session_state and (
        prompt == st.session_state["last_prompt"]
        and model_name == st.session_state.get("last_model_name", "")
    )
    if is_repeat:
        prompt = None
    else:
        st.session_state.responses.append(get_from_user(prompt))
        st.session_state.processing = True
        st.session_state.current_model = model_name
        st.session_state.current_question = prompt
        st.rerun()
|
|
|
|
|
# Phase two of a question: the previous run queued it (processing=True). Call
# the model, append a well-formed assistant entry, clear the in-flight markers
# and rerun to display the answer.
if st.session_state.get("processing"):
    prompt = st.session_state.get("current_question")
    model_name = st.session_state.get("current_model")

    try:
        response = ask_question(model_name=model_name, question=prompt)

        if not isinstance(response, dict):
            response = {
                "role": "assistant",
                "content": "Error: Invalid response format",
                "gen_code": "",
                "ex_code": "",
                "last_prompt": prompt,
                "error": "Invalid response format",
            }

        # Fill in any fields the backend left out.
        for field_name, fallback in (
            ("role", "assistant"),
            ("content", "No content generated"),
            ("gen_code", ""),
            ("ex_code", ""),
            ("last_prompt", prompt),
            ("error", None),
        ):
            response.setdefault(field_name, fallback)

    except Exception as exc:
        response = {
            "role": "assistant",
            "content": f"Sorry, I encountered an error: {str(exc)}",
            "gen_code": "",
            "ex_code": "",
            "last_prompt": prompt,
            "error": str(exc),
        }

    st.session_state.responses.append(response)
    st.session_state["last_prompt"] = prompt
    st.session_state["last_model_name"] = model_name
    st.session_state.processing = False

    # Drop the in-flight markers used by the processing indicator.
    for stale_key in ("current_model", "current_question"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    st.rerun()
|
|
|
|
|
# No auto-scroll call here. Whenever `processing` is True, the branch above ends
# in st.rerun(), so it is never True at this point. In addition, <script> tags
# injected through st.markdown are not executed by the browser. A working
# auto-scroll would need streamlit.components.v1.html instead.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-session usage counters, shown only when feedback uploading is configured.
if hf_token and hf_token.strip():
    st.markdown("### 📊 Session Stats")

    # Filter the history once and derive both counters from it.
    assistant_turns = [
        r for r in st.session_state.get("responses", []) if r.get("role") == "assistant"
    ]
    st.metric("Interactions", len(assistant_turns))
    st.metric("Feedbacks Given", sum(1 for r in assistant_turns if "feedback" in r))
|
|
|
|
|
# Page footer. The copyright sign is written as an HTML entity so it cannot be
# garbled by a file re-encoding.
st.markdown("""
<div style='text-align: center; margin-top: 3rem; padding: 2rem; background: rgba(255,255,255,0.1); border-radius: 15px;'>
    <h3>Together for Cleaner Air</h3>
    <p>VayuChat - Empowering environmental awareness through AI</p>
    <small>&copy; 2024 IIT Gandhinagar Sustainability Lab</small>
</div>
""", unsafe_allow_html=True)