"""Mawared HR Assistant.

Gradio chat app that answers questions about the Mawared HR System with
retrieval-augmented generation: context is retrieved from a Qdrant vector
store (MiniLM embeddings) and answers are streamed from a Cerebras-hosted
Llama 3.3 70B model via LangChain.
"""

import spaces  # Hugging Face Spaces GPU helper (used by @spaces.GPU below)

import subprocess
import os
import torch
from dotenv import load_dotenv
from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from qdrant_client import QdrantClient, models
from langchain_openai import ChatOpenAI
import gradio as gr
import logging
from typing import List, Tuple, Generator
from dataclasses import dataclass
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_cerebras import ChatCerebras
from queue import Queue
from threading import Thread


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

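# For reference, the third-party imports above roughly correspond to these pip
# packages (an approximate sketch, not taken from this repo's requirements.txt):
#
#   spaces python-dotenv torch transformers gradio qdrant-client
#   langchain langchain-community langchain-huggingface langchain-openai
#   langchain-cerebras
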
@dataclass
class Message:
    """A single chat turn: who said it, what was said, and when."""

    role: str
    content: str
    timestamp: str


class ChatHistory:
    """In-memory conversation history used to fill the {chat_history} prompt field."""

    def __init__(self):
        self.messages: List[Message] = []

    def add_message(self, role: str, content: str):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.messages.append(Message(role=role, content=content, timestamp=timestamp))

    def get_formatted_history(self, max_messages: int = 10) -> str:
        """Return the most recent `max_messages` messages as "role: content" lines."""
        recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
        return "\n".join(f"{msg.role}: {msg.content}" for msg in recent_messages)

    def clear(self):
        self.messages = []

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
C_apikey = os.getenv("C_apikey")
OPENAPI_KEY = os.getenv("OPENAPI_KEY")

if not HF_TOKEN:
    logger.error("HF_TOKEN is not set in the environment variables.")
    exit(1)

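# Illustrative .env layout (placeholder values, not real credentials; QDRANT_URL
# and QDRANT_API_KEY are read further down when the Qdrant client is created):
#
#   HF_TOKEN=<huggingface token>
#   C_apikey=<cerebras api key>
#   OPENAPI_KEY=<api key, loaded but not used below>
#   QDRANT_URL=<qdrant cluster url>
#   QDRANT_API_KEY=<qdrant api key>
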
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=False
    )
except Exception as e:
    logger.error(f"Failed to connect to Qdrant: {e}")
    exit(1)

collection_name = "mawared"

try:
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,  # embedding dimension of all-MiniLM-L6-v2
            distance=models.Distance.COSINE
        )
    )
except Exception as e:
    # Re-running against an existing collection is fine; anything else is fatal.
    if "already exists" not in str(e):
        logger.error(f"Error creating collection: {e}")
        exit(1)

db = Qdrant(
    client=client,
    collection_name=collection_name,
    embeddings=embeddings,
)

retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5}
)

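# Quick sanity check (illustrative only; the query string is hypothetical, and the
# call name depends on the installed LangChain version: retriever.invoke(...) on
# recent releases, retriever.get_relevant_documents(...) on older ones):
#
#   docs = retriever.invoke("How do I approve a leave request?")
#   for doc in docs:
#       print(doc.page_content[:100])
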
llm = ChatCerebras(
    model="llama-3.3-70b",
    api_key=C_apikey,
    streaming=True
)

template = """
You are a friendly assistant specializing in the Mawared HR System.
Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.
Your top priority is user experience and satisfaction: only answer questions about the Mawared HR System and ignore everything else.

Key Responsibilities:

Use the given chat history and retrieved context to craft accurate and detailed responses.
If necessary, ask specific and targeted clarifying questions to gather more information.
Present step-by-step instructions in a clear, numbered format when applicable.
If you cannot provide a clear answer based on the user's question, ask a clarifying question and request more details.

Previous Conversation: {chat_history}
Retrieved Context: {context}
Current Question: {question}
Answer:
"""

prompt = ChatPromptTemplate.from_template(template)

def create_rag_chain(chat_history: str):
    """Build an LCEL chain that retrieves context, fills the prompt, and calls the LLM."""
    chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
            "chat_history": lambda x: chat_history
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain

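# Illustrative usage (the question below is hypothetical): a chain is rebuilt per
# request with the formatted history baked in, then streamed chunk by chunk, which
# is what ask_question_gradio() does further down.
#
#   rag_chain = create_rag_chain(chat_history.get_formatted_history())
#   for chunk in rag_chain.stream("How do I submit a leave request in Mawared?"):
#       print(chunk, end="", flush=True)
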
chat_history = ChatHistory()  # module-level history, shared by all sessions

def process_stream(stream_queue: Queue, history: List[List[str]]) -> Generator[List[List[str]], None, None]:
    """Process the streaming response and update the chat interface."""
    current_response = ""

    while True:
        chunk = stream_queue.get()
        if chunk is None:  # None is the end-of-stream sentinel pushed by the producer thread
            break

        current_response += chunk
        new_history = history.copy()
        new_history[-1][1] = current_response
        yield new_history

@spaces.GPU()
def ask_question_gradio(question: str, history: List[List[str]]) -> Generator[tuple, None, None]:
    try:
        if history is None:
            history = []

        chat_history.add_message("user", question)
        formatted_history = chat_history.get_formatted_history()
        rag_chain = create_rag_chain(formatted_history)

        # Add the user turn with an empty assistant slot that streaming will fill in.
        history.append([question, ""])

        stream_queue = Queue()

        def stream_processor():
            """Producer: push LLM chunks onto the queue, then a None sentinel."""
            try:
                for chunk in rag_chain.stream(question):
                    stream_queue.put(chunk)
                stream_queue.put(None)
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                stream_queue.put(None)

        Thread(target=stream_processor).start()

        # Consumer: yield partial histories so the chatbot updates as chunks arrive.
        response = ""
        for updated_history in process_stream(stream_queue, history):
            response = updated_history[-1][1]
            yield "", updated_history

        chat_history.add_message("assistant", response)

    except Exception as e:
        logger.error(f"Error during question processing: {e}")
        if not history:
            history = []
        history.append([question, "An error occurred. Please try again later."])
        yield "", history

def clear_chat():
    """Reset both the server-side history and the Gradio widgets."""
    chat_history.clear()
    return [], ""

with gr.Blocks(theme='Hev832/Applio') as iface:
    gr.Image("Image.jpg", width=750, height=300, show_label=False, show_download_button=False)
    gr.Markdown("# Mawared HR Assistant 2.6.4")
    gr.Markdown('### Instructions')
    gr.Markdown("Ask a question about Mawared HR and get a detailed answer. If you get an error, try again with the same prompt; it is an API issue and we are working on it 😀")

    chatbot = gr.Chatbot(
        height=750,
        show_label=False,
        bubble_full_width=False,
    )

    with gr.Row():
        with gr.Column(scale=20):
            question_input = gr.Textbox(
                label="Ask a question:",
                placeholder="Type your question here...",
                show_label=False
            )
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column():
                    send_button = gr.Button("Send", variant="primary", size="sm")
                    clear_button = gr.Button("Clear Chat", size="sm")

    # Wire both Enter-in-textbox and the Send button to the same streaming handler.
    submit_events = [question_input.submit, send_button.click]
    for submit_event in submit_events:
        submit_event(
            ask_question_gradio,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot]
        )

    clear_button.click(
        clear_chat,
        outputs=[chatbot, question_input]
    )

if __name__ == "__main__":
    iface.launch()