import spaces
import os
import subprocess

# Install flash-attn at runtime; merge os.environ so pip stays on PATH.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
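# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells flash-attn's setup to skip
# compiling the CUDA kernels at install time (a common Spaces workaround);
# flash-attn is only needed by the commented-out local-model path below.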
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple

import gradio as gr
import torch
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_cerebras import ChatCerebras
from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_openai import ChatOpenAI
from qdrant_client import QdrantClient, models
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    pipeline,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class Message:
    role: str
    content: str
    timestamp: str
class ChatHistory:
    def __init__(self):
        self.messages: List[Message] = []

    def add_message(self, role: str, content: str):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.messages.append(Message(role=role, content=content, timestamp=timestamp))

    def get_formatted_history(self, max_messages: int = 10) -> str:
        """Returns the most recent conversation history formatted as a string."""
        recent_messages = self.messages[-max_messages:]
        return "\n".join(f"{msg.role}: {msg.content}" for msg in recent_messages)

    def clear(self):
        self.messages = []
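# Illustrative usage of ChatHistory (hypothetical values):
#   history = ChatHistory()
#   history.add_message("user", "How do I submit a leave request?")
#   history.add_message("assistant", "Open the Leaves module, then ...")
#   print(history.get_formatted_history())  # "user: ...\nassistant: ..."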
# Load environment variables
load_dotenv()

# API keys: HF_TOKEN (HuggingFace), C_apikey (Cerebras), OPENAPI_KEY (OpenAI-compatible endpoints)
HF_TOKEN = os.getenv("HF_TOKEN")
C_apikey = os.getenv("C_apikey")
OPENAPI_KEY = os.getenv("OPENAPI_KEY")

if not HF_TOKEN:
    logger.error("HF_TOKEN is not set in the environment variables.")
    exit(1)
# HuggingFace embeddings (BAAI/bge-large-en-v1.5 produces 1024-dimensional vectors)
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
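# Sanity check (illustrative): the embedding dimension must match the vector
# size of the Qdrant collection created below.
#   assert len(embeddings.embed_query("test")) == 1024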
# Qdrant client setup
try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=False,
    )
except Exception as e:
    logger.error(f"Failed to connect to Qdrant ({e}). Ensure QDRANT_URL and QDRANT_API_KEY are correctly set.")
    exit(1)
# Define collection name
collection_name = "mawared"

# Create the collection if it does not already exist
try:
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=1024,  # bge-large-en-v1.5 embedding size
            distance=models.Distance.COSINE,
        ),
    )
    logger.info(f"Created new collection: {collection_name}")
except Exception as e:
    if "already exists" in str(e):
        logger.info(f"Collection {collection_name} already exists, continuing...")
    else:
        logger.error(f"Error creating collection: {e}")
        exit(1)
# Create Qdrant vector store
db = Qdrant(
    client=client,
    collection_name=collection_name,
    embeddings=embeddings,
)

# Create retriever that returns the 3 most similar chunks
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)
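# Illustrative use (hypothetical query): the retriever is a Runnable that
# returns the top-k matching Documents.
#   docs = retriever.invoke("How do I approve a leave request?")
#   for doc in docs:
#       print(doc.page_content[:80])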
# Set up the LLM
# Alternative: HuggingFace Inference API via its OpenAI-compatible endpoint
# llm = ChatOpenAI(
#     base_url="https://api-inference.huggingface.co/v1/",
#     temperature=0,
#     api_key=HF_TOKEN,
#     model="mistralai/Mistral-Nemo-Instruct-2407",
#     max_tokens=None,
#     timeout=None
# )

# Alternative: OpenRouter
# llm = ChatOpenAI(
#     base_url="https://openrouter.ai/api/v1",
#     temperature=0,
#     api_key=OPENAPI_KEY,
#     model="google/gemini-2.0-flash-thinking-exp:free",
#     max_tokens=None,
#     timeout=None,
#     streaming=True
# )

llm = ChatCerebras(
    model="llama-3.3-70b",
    api_key=C_apikey,
    streaming=True,
)
# Alternative: local model with 4-bit quantization (requires a GPU and flash-attn)
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True
# )
# model_id = "unsloth/phi-4"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     device_map="cuda",
#     attn_implementation="flash_attention_2",
#     quantization_config=quantization_config
# )
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8192)
# llm = HuggingFacePipeline(pipeline=pipe)
# Create prompt template with chat history
template = """
You are an expert assistant specializing in the Mawared HR System.
Your task is to provide accurate and contextually relevant answers based on the provided context and chat history.
If you need more information, ask targeted clarifying questions.
Ensure you provide a detailed, numbered, step-by-step answer to the user and be very accurate.

Previous Conversation:
{chat_history}

Current Context:
{context}

Current Question:
{question}

Ask follow-up questions based on your answer to create a conversational flow. Only answer from the provided context and chat history; do not make up any information.
Answer only and exclusively from the given context and knowledge base.
Ensure you do not mention where you got the information from, and do not mention any pages or documents.
Ensure your answers are detailed and expressive.

Answer:
"""
prompt = ChatPromptTemplate.from_template(template)
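# ChatPromptTemplate.from_template infers the input variables ({chat_history},
# {context}, {question}) from the braces; the chain below supplies all three.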
# Create the RAG chain with chat history
def create_rag_chain(chat_history: str):
    chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
            "chat_history": lambda x: chat_history,
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain
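# Illustrative one-off call (hypothetical question, empty history):
#   chain = create_rag_chain("")
#   for chunk in chain.stream("How do I add a new employee?"):
#       print(chunk, end="", flush=True)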
# Initialize chat history
chat_history = ChatHistory()
# Gradio function
# @spaces.GPU()
def ask_question_gradio(question, history):
    try:
        # Add user question to chat history
        chat_history.add_message("user", question)

        # Get formatted history
        formatted_history = chat_history.get_formatted_history()

        # Create chain with current chat history
        rag_chain = create_rag_chain(formatted_history)

        # Generate response by consuming the stream
        response = ""
        for chunk in rag_chain.stream(question):
            response += chunk

        # Add assistant response to chat history
        chat_history.add_message("assistant", response)

        # Update Gradio chat history
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": response})

        return "", history
    except Exception as e:
        logger.error(f"Error during question processing: {e}")
        return "", history + [{"role": "assistant", "content": "An error occurred. Please try again later."}]
def clear_chat():
    chat_history.clear()
    return [], ""
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Image("Image.jpg", width=1200, height=300, show_label=False, show_download_button=False)
    gr.Markdown("# Mawared HR Assistant")
    gr.Markdown("### Instructions")
    gr.Markdown("The first question will always return an error in the chat; try again and the flow should continue normally. This is an API issue that we are working on.")

    chatbot = gr.Chatbot(
        height=400,
        show_label=False,
        type="messages",  # use the new messages format (list of role/content dicts)
    )

    with gr.Row():
        question_input = gr.Textbox(
            label="Ask a question:",
            placeholder="Type your question here...",
            scale=25,
        )
        clear_button = gr.Button("Clear Chat", scale=1)

    question_input.submit(
        ask_question_gradio,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot],
    )
    clear_button.click(
        clear_chat,
        outputs=[chatbot, question_input],
    )
# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()