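"""AI-powered web search and PDF chat assistant (Gradio app).

Indexes uploaded PDFs into a local FAISS vector store and answers
questions with Mistral-7B-Instruct via the Hugging Face Hub, grounding
each response in either the PDF index or live DuckDuckGo results.

Expects HUGGINGFACE_TOKEN and LLAMA_CLOUD_API_KEY in the environment.
"""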
import os
import shutil
import gradio as gr
from tempfile import NamedTemporaryFile
from typing import List
from duckduckgo_search import DDGS
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_core.documents import Document
from llama_parse import LlamaParse
# Environment variables and configuration
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
# Initialize LlamaParse for markdown extraction from PDFs
llama_parser = LlamaParse(
    api_key=llama_cloud_api_key,
    result_type="markdown",
    num_workers=4,
    verbose=True,
    language="en",
)
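
# Load a PDF into LangChain Documents with the chosen parser,
# falling back to PyPDF when LlamaParse fails.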
def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
    if parser == "pypdf":
        loader = PyPDFLoader(file.name)
        return loader.load_and_split()
    elif parser == "llamaparse":
        try:
            documents = llama_parser.load_data(file.name)
            return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
        except Exception as e:
            print(f"Error using LlamaParse: {str(e)}")
            print("Falling back to PyPDF parser")
            loader = PyPDFLoader(file.name)
            return loader.load_and_split()
    else:
        raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
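
# Embed uploaded PDFs and add them to a persistent local FAISS index.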
def update_vectors(files, parser):
    if not files:
        return "Please upload at least one PDF file."
    embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    total_chunks = 0
    all_data = []
    for file in files:
        data = load_document(file, parser)
        all_data.extend(data)
        total_chunks += len(data)
    if os.path.exists("faiss_database"):
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
        database.add_documents(all_data)
    else:
        database = FAISS.from_documents(all_data, embed)
    database.save_local("faiss_database")
    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."
def clear_cache():
    if os.path.exists("faiss_database"):
        # FAISS.save_local writes a directory (index.faiss + index.pkl),
        # so remove the whole tree rather than a single file.
        shutil.rmtree("faiss_database")
        return "Cache cleared successfully."
    else:
        return "No cache to clear."
def get_model(temperature, top_p, repetition_penalty):
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        model_kwargs={
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "max_length": 1000
        },
        huggingfacehub_api_token=huggingface_token
    )
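
# Fetch the top web results for a query via DuckDuckGo.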
def duckduckgo_search(query):
    with DDGS() as ddgs:
        results = ddgs.text(query, max_results=5)
    return results
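
# Answer a query with the LLM, grounding it in either the local PDF
# index or fresh web search results.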
def get_response_with_search(query, temperature, top_p, repetition_penalty, use_pdf=False):
    model = get_model(temperature, top_p, repetition_penalty)
    if use_pdf and os.path.exists("faiss_database"):
        # Only load the embedding model when the PDF index is actually used.
        embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
        retriever = database.as_retriever()
        relevant_docs = retriever.get_relevant_documents(query)
        context = "\n".join([f"Content: {doc.page_content}\nSource: {doc.metadata['source']}\n" for doc in relevant_docs])
    else:
        search_results = duckduckgo_search(query)
        context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
                            for result in search_results if 'body' in result)
    prompt = f"""<s>[INST] Using the following context:
{context}
Write a detailed and complete research document that fulfills the following user request: '{query}'
After writing the document, please provide a list of sources used in your response. [/INST]"""
    response = model.invoke(prompt)
    main_content, sources = split_response(response)
    return main_content, sources
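
# Split the model output into the main document and its trailing "Sources:" list.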
def split_response(response):
    parts = response.split("Sources:", 1)
    main_content = parts[0].strip()
    sources = parts[1].strip() if len(parts) > 1 else ""
    return main_content, sources
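
# Format one chat turn: run retrieval-augmented generation and append sources.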
def chatbot_interface(message, history, temperature, top_p, repetition_penalty, use_pdf):
    main_content, sources = get_response_with_search(message, temperature, top_p, repetition_penalty, use_pdf)
    formatted_response = f"{main_content}\n\nSources:\n{sources}"
    return formatted_response

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")
    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
        parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
        update_button = gr.Button("Upload PDF")
    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(label="Ask a question")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=1):
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.1)
            top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
            repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1)
            use_pdf = gr.Checkbox(label="Use PDF Documents", value=False)
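
    # Callback for the tuples-style Chatbot: clear the textbox and
    # append the (user, bot) pair to the history.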
    def respond(message, chat_history, temperature, top_p, repetition_penalty, use_pdf):
        bot_message = chatbot_interface(message, chat_history, temperature, top_p, repetition_penalty, use_pdf)
        chat_history.append((message, bot_message))
        return "", chat_history

    submit_button.click(respond, inputs=[msg, chatbot, temperature, top_p, repetition_penalty, use_pdf], outputs=[msg, chatbot])
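    # Optional: also submit on Enter in the textbox.
    msg.submit(respond, inputs=[msg, chatbot, temperature, top_p, repetition_penalty, use_pdf], outputs=[msg, chatbot])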
    clear_button = gr.Button("Clear Cache")
    clear_output = gr.Textbox(label="Cache Status")
    clear_button.click(clear_cache, inputs=[], outputs=clear_output)

if __name__ == "__main__":
    demo.launch()