import spaces
import subprocess
import os
import torch
from dotenv import load_dotenv
from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from qdrant_client import QdrantClient, models
from langchain_openai import ChatOpenAI
import gradio as gr
import logging
from typing import List, Tuple, Generator
from dataclasses import dataclass
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_cerebras import ChatCerebras
from queue import Queue
from threading import Thread

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
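
# Chat history is kept in memory: each turn is stored as a Message
# (role, content, timestamp) and the most recent turns are flattened
# into a plain-text string that is injected into the prompt.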
@dataclass
class Message:
    role: str
    content: str
    timestamp: str

class ChatHistory:
    def __init__(self):
        self.messages: List[Message] = []

    def add_message(self, role: str, content: str):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.messages.append(Message(role=role, content=content, timestamp=timestamp))

    def get_formatted_history(self, max_messages: int = 10) -> str:
        recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
        formatted_history = "\n".join([
            f"{msg.role}: {msg.content}" for msg in recent_messages
        ])
        return formatted_history

    def clear(self):
        self.messages = []

# Load environment variables and setup
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
C_apikey = os.getenv("C_apikey")
OPENAPI_KEY = os.getenv("OPENAPI_KEY")

if not HF_TOKEN:
    logger.error("HF_TOKEN is not set in the environment variables.")
    exit(1)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
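
# all-MiniLM-L6-v2 produces 384-dimensional embeddings; this must match the
# vector size configured for the Qdrant collection below.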
try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=False
    )
except Exception as e:
    logger.error(f"Failed to connect to Qdrant: {e}")
    exit(1)

collection_name = "mawared"

try:
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,
            distance=models.Distance.COSINE
        )
    )
except Exception as e:
    if "already exists" not in str(e):
        logger.error(f"Error creating collection: {e}")
        exit(1)

db = Qdrant(
    client=client,
    collection_name=collection_name,
    embeddings=embeddings,
)

retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5}
)
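
# The retriever returns the 5 most similar chunks for each question; the
# retrieved documents fill the {context} slot of the prompt below.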

llm = ChatCerebras(
    model="llama-3.3-70b",
    api_key=C_apikey,
    streaming=True
)

template = """
You are a friendly assistant specializing in the Mawared HR System.
Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.
Your top priority is user experience and satisfaction; answer only questions about the Mawared HR system and ignore everything else.

Key Responsibilities:
Use the given chat history and retrieved context to craft accurate and detailed responses.
If necessary, ask specific and targeted clarifying questions to gather more information.
Present step-by-step instructions in a clear, numbered format when applicable.
If you cannot give a clear answer based on the user's question, ask a clarifying question and request more details.

Previous Conversation: {chat_history}
Retrieved Context: {context}
Current Question: {question}

Answer:
"""

prompt = ChatPromptTemplate.from_template(template)
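
# The chain below maps the incoming question to the three prompt inputs:
# the retriever fills {context}, RunnablePassthrough forwards the raw
# question to {question}, and a closure injects the pre-formatted
# {chat_history} string. The result is piped through the prompt, the LLM,
# and a string output parser, so .stream() yields plain text chunks.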
def create_rag_chain(chat_history: str):
    chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
            "chat_history": lambda x: chat_history
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain

chat_history = ChatHistory()

def process_stream(stream_queue: Queue, history: List[List[str]]) -> Generator[List[List[str]], None, None]:
    """Process the streaming response and update the chat interface."""
    current_response = ""
    while True:
        chunk = stream_queue.get()
        if chunk is None:  # Signal that streaming is complete
            break
        current_response += chunk
        new_history = history.copy()
        new_history[-1][1] = current_response  # Update the assistant's message
        yield new_history
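
# Gradio callback: runs the RAG chain in a background thread, pushes streamed
# chunks onto a queue, and yields incremental chatbot updates as they arrive.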
def ask_question_gradio(question: str, history: List[List[str]]) -> Generator[tuple, None, None]:
    try:
        if history is None:
            history = []

        chat_history.add_message("user", question)
        formatted_history = chat_history.get_formatted_history()
        rag_chain = create_rag_chain(formatted_history)

        # Update history with the user message and an empty assistant message
        history.append([question, ""])

        # Create a queue for streaming responses
        stream_queue = Queue()

        # Process the stream in a separate thread
        def stream_processor():
            try:
                for chunk in rag_chain.stream(question):
                    stream_queue.put(chunk)
                stream_queue.put(None)  # Signal completion
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                stream_queue.put(None)

        # Start streaming in a separate thread
        Thread(target=stream_processor).start()

        # Yield updates to the chat interface
        response = ""
        for updated_history in process_stream(stream_queue, history):
            response = updated_history[-1][1]
            yield "", updated_history

        # Add the final response to the chat history
        chat_history.add_message("assistant", response)

    except Exception as e:
        logger.error(f"Error during question processing: {e}")
        if not history:
            history = []
        history.append([question, "An error occurred. Please try again later."])
        yield "", history

def clear_chat():
    chat_history.clear()
    return [], ""

# Gradio Interface
with gr.Blocks(theme='Yntec/HaleyCH_Theme_Orange_Green') as iface:
    gr.Image("Image.jpg", width=750, height=300, show_label=False, show_download_button=False)
    gr.Markdown("# Mawared HR Assistant 2.5.1")
    gr.Markdown('### Instructions')
    gr.Markdown("Ask a question about Mawared HR and get a detailed answer. If you get an error, try again with the same prompt; it's an API issue and we are working on it 🙂")

    chatbot = gr.Chatbot(
        height=750,
        show_label=False,
        bubble_full_width=False,
    )

    with gr.Row():
        question_input = gr.Textbox(
            label="Ask a question:",
            placeholder="Type your question here...",
            scale=30
        )
        clear_button = gr.Button("Clear Chat", scale=1)

    question_input.submit(
        ask_question_gradio,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )

    clear_button.click(
        clear_chat,
        outputs=[chatbot, question_input]
    )

if __name__ == "__main__":
    iface.launch()