Spaces:
Runtime error

import os
import torch
import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

# Set Streamlit page configuration
st.set_page_config(page_title="Chat with Notes and AI", page_icon=":books:", layout="wide")

# Load environment variables
load_dotenv()

# Optimized pipeline setup, cached so Streamlit reruns do not reload the model
@st.cache_resource
def load_pipeline():
    # Use a smaller model for faster performance
    model_name = "databricks/dolly-v2-1b"  # Switch to a lighter model

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for CPU compatibility
        device_map="auto",          # Automatically map devices
        trust_remote_code=True
    )

    # Return a text-generation pipeline with full-text output;
    # dtype and device placement are inherited from the already-loaded model
    return pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=True,
        max_new_tokens=100  # Limit response length
    )
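
# A minimal usage sketch (illustrative comments only, not run at import time) of
# calling the raw transformers pipeline directly, before it is wrapped for
# LangChain below; the question string is made up:
#
#   pipe = load_pipeline()
#   out = pipe("What is retrieval-augmented generation?")
#   print(out[0]["generated_text"])  # return_full_text=True keeps the prompt in the output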

# Initialize pipeline
generate_text = load_pipeline()

# LangChain integration: wrap the transformers pipeline as a LangChain LLM
hf_pipeline = HuggingFacePipeline(pipeline=generate_text)

# Templates for prompts
prompt = PromptTemplate(input_variables=["instruction"], template="{instruction}")
prompt_with_context = PromptTemplate(
    input_variables=["instruction", "context"],
    template="{instruction}\n\nInput:\n{context}"
)

# LangChain LLM chains
llm_chain = LLMChain(llm=hf_pipeline, prompt=prompt)
llm_context_chain = LLMChain(llm=hf_pipeline, prompt=prompt_with_context)
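
# A minimal sketch (illustrative comments only) of how a query flows through the
# context-aware chain; the instruction and context strings below are made up:
#
#   answer = llm_context_chain.predict(
#       instruction="Summarise the passage in one sentence.",
#       context="The FAISS index stores one embedding per text chunk.",
#   )
#
# prompt_with_context renders the prompt sent to the model as:
#   "Summarise the passage in one sentence.\n\nInput:\nThe FAISS index stores one embedding per text chunk."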

# Extract content from .txt files in a folder
def get_text_files_content(folder):
    text = ""
    for filename in os.listdir(folder):
        if filename.endswith('.txt'):
            with open(os.path.join(folder, filename), 'r', encoding='utf-8') as file:
                text += file.read() + "\n"
    return text

# Convert text into chunks for vectorization
def get_chunks(raw_text):
    from langchain.text_splitter import CharacterTextSplitter
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=500,   # Smaller chunks for faster processing
        chunk_overlap=50  # Minimal overlap
    )
    return text_splitter.split_text(raw_text)

# Create FAISS vectorstore for embeddings
def get_vectorstore(chunks):
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # Lightweight embeddings
        model_kwargs={'device': 'cpu'}  # Ensure embeddings run on CPU
    )
    return FAISS.from_texts(texts=chunks, embedding=embeddings)
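
# A minimal sketch (illustrative comments only) of how the chunking and retrieval
# helpers fit together; the sample text and question are made up:
#
#   sample_chunks = get_chunks("Paragraph one.\nParagraph two.\nParagraph three.")
#   sample_store = get_vectorstore(sample_chunks)
#   hits = sample_store.similarity_search("What does paragraph two say?", k=1)
#   print(hits[0].page_content)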

# Handle user queries
def handle_question(question, vectorstore=None):
    if vectorstore:
        # Retrieve the most relevant chunk
        documents = vectorstore.similarity_search(question, k=1)  # Retrieve fewer chunks
        context = "\n".join([doc.page_content for doc in documents])[:500]  # Short context for efficiency
        if context:
            return llm_context_chain.predict(instruction=question, context=context).strip()
    # Fallback to instruction-only chain if no context
    return llm_chain.predict(instruction=question).strip()
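
# The two paths through handle_question, sketched as comments (example strings
# are made up):
#
#   handle_question("Define inflation.")                    # no vectorstore -> instruction-only chain
#   handle_question("Define inflation.", some_faiss_store)  # with vectorstore -> retrieval + context chain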

def main():
    st.title("Chat with Notes :books:")

    # Session state for the vectorstore and the content it was built from
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None
    if "content_key" not in st.session_state:
        st.session_state.content_key = None

    # Data folders
    data_folder = "data"     # Folder for Current Affairs
    essay_folder = "essays"  # Folder for Essays

    # Content type selection
    content_type = st.sidebar.radio("Select Content Type:", ["Current Affairs", "Essays"])

    # Subjects based on content type
    if content_type == "Current Affairs":
        subjects = [f for f in os.listdir(data_folder) if os.path.isdir(os.path.join(data_folder, f))] if os.path.exists(data_folder) else []
    else:
        subjects = [f.replace(".txt", "") for f in os.listdir(essay_folder) if f.endswith('.txt')] if os.path.exists(essay_folder) else []

    # Subject selection
    selected_subject = st.sidebar.selectbox("Select a Subject:", subjects)

    # Load content based on selection
    raw_text = ""
    if content_type == "Current Affairs" and selected_subject:
        subject_folder = os.path.join(data_folder, selected_subject)
        raw_text = get_text_files_content(subject_folder)
    elif content_type == "Essays" and selected_subject:
        subject_file = os.path.join(essay_folder, f"{selected_subject}.txt")
        if os.path.exists(subject_file):
            with open(subject_file, "r", encoding="utf-8") as file:
                raw_text = file.read()

    # Display a preview and (re)build the vectorstore whenever the selection changes
    if raw_text:
        st.subheader("Preview of Notes")
        st.text_area("Preview Content:", value=raw_text[:1000], height=300, disabled=True)
        content_key = f"{content_type}:{selected_subject}"
        if st.session_state.vectorstore is None or st.session_state.content_key != content_key:
            chunks = get_chunks(raw_text)
            st.session_state.vectorstore = get_vectorstore(chunks)
            st.session_state.content_key = content_key
    else:
        st.warning("No content available for the selected subject.")

    # Question and response
    st.subheader("Ask Your Question")
    question = st.text_input("Ask a question about your selected subject:")
    if question:
        if st.session_state.vectorstore:
            response = handle_question(question, st.session_state.vectorstore)
            st.subheader("Answer:")
            st.write(response or "No response found.")
        else:
            st.warning("Please load the content for the selected subject before asking a question.")

if __name__ == "__main__":
    main()
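
# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py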