Spaces: Build error

import os
import shutil
import time
from typing import List

import docx
import gradio as gr
import nltk
import pdfplumber
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter

#import spacy
#spacy.cli.download("en_core_web_sm")  # Ensure the model is available
#nlp = spacy.load("en_core_web_sm")  # Load the model

# Ensure the nltk sentence tokenizer is available (only needed if a
# sentence-based split strategy is added later)
nltk.download('punkt')

FILES_DIR = './files'

# Supported embedding models (full Hugging Face Hub repo ids; the bare short
# names used originally would not resolve on the Hub)
MODELS = {
    'e5-base': "danielheinz/e5-base-sts-en-de",
    'multilingual-e5-base': "intfloat/multilingual-e5-base",
    'paraphrase-miniLM': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    'paraphrase-mpnet': "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    'gte-large': "thenlper/gte-large",
    'gbert-base': "deepset/gbert-base",
}

class FileHandler:
    """Extracts raw text from .pdf, .docx, and .txt files."""

    @staticmethod
    def extract_text(file_path):
        ext = os.path.splitext(file_path)[-1].lower()
        if ext == '.pdf':
            return FileHandler._extract_from_pdf(file_path)
        elif ext == '.docx':
            return FileHandler._extract_from_docx(file_path)
        elif ext == '.txt':
            return FileHandler._extract_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

    @staticmethod
    def _extract_from_pdf(file_path):
        with pdfplumber.open(file_path) as pdf:
            # extract_text() can return None for image-only pages
            return ' '.join(page.extract_text() or '' for page in pdf.pages)

    @staticmethod
    def _extract_from_docx(file_path):
        doc = docx.Document(file_path)
        return ' '.join(para.text for para in doc.paragraphs)

    @staticmethod
    def _extract_from_txt(file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

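# Usage sketch (the path below is hypothetical):
#   text = FileHandler.extract_text('./files/report.pdf')
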
class EmbeddingModel:
    def __init__(self, model_name, max_tokens=None):
        self.model = HuggingFaceEmbeddings(model_name=model_name)
        self.max_tokens = max_tokens  # currently informational only

    def embed(self, chunks: List[str]):
        # Embed the list of chunks (one vector per chunk)
        return self.model.embed_documents(chunks)

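# Usage sketch (assumes the 'e5-base' entry resolves on the Hub):
#   embedder = EmbeddingModel(MODELS['e5-base'])
#   vectors = embedder.embed(["first chunk", "second chunk"])  # two vectors
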
def process_files(model_name, split_strategy, chunk_size=500, overlap_size=50, max_tokens=None):
    # Gather text from every file in FILES_DIR
    text = ""
    for file in os.listdir(FILES_DIR):
        file_path = os.path.join(FILES_DIR, file)
        if os.path.isfile(file_path):
            text += FileHandler.extract_text(file_path)

    # Split text into chunks
    if split_strategy == 'token':
        splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
    else:
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
    chunks = splitter.split_text(text)

    # Embed the chunks, not the full text
    model = EmbeddingModel(MODELS[model_name], max_tokens=max_tokens)
    embeddings = model.embed(chunks)
    return embeddings, chunks

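# Example call (argument values are illustrative):
#   embeddings, chunks = process_files('e5-base', 'recursive', chunk_size=500, overlap_size=50)
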
def search_embeddings(query, chunks, model_name, top_k):
    # Build a FAISS index over the chunks and return the top-k matching chunks
    embedding_model = HuggingFaceEmbeddings(model_name=MODELS[model_name])
    index = FAISS.from_texts(chunks, embedding_model)
    results = index.similarity_search(query, k=int(top_k))
    return [doc.page_content for doc in results]

def calculate_statistics(embeddings, start_time):
    # Report the number of embedded chunks and the elapsed processing time
    return {"chunks": len(embeddings), "time_taken": round(time.time() - start_time, 2)}

def upload_file(file, model_name, split_strategy, chunk_size, overlap_size, max_tokens, query, top_k):
    # Ensure chunk_size and overlap_size are integers
    try:
        chunk_size = int(chunk_size) if chunk_size else 100
        overlap_size = int(overlap_size) if overlap_size else 0
    except ValueError:
        return {"error": "Chunk size and overlap size must be valid integers."}

    # Copy the uploaded file (a Gradio file object) into FILES_DIR
    os.makedirs(FILES_DIR, exist_ok=True)
    file_path = file.name
    destination_path = os.path.join(FILES_DIR, os.path.basename(file_path))
    shutil.copyfile(file_path, destination_path)

    # Process the files and embed the resulting chunks
    start_time = time.time()
    embeddings, chunks = process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens)

    # Search the chunks for the query
    results = search_embeddings(query, chunks, model_name, top_k)

    # Gather run statistics
    stats = calculate_statistics(embeddings, start_time)
    return {"results": results, "stats": stats}

# Gradio interface: the input order must match upload_file's parameter order
iface = gr.Interface(
    fn=upload_file,
    inputs=[
        gr.File(label="Upload File"),
        gr.Dropdown(choices=list(MODELS.keys()), value='e5-base', label="Embedding Model"),
        gr.Radio(choices=["token", "recursive"], value="recursive", label="Split Strategy"),
        gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
        gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
        gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),
        gr.Textbox(label="Search Query"),
        gr.Slider(1, 10, step=1, value=5, label="Top K"),
    ],
    outputs="json",
)

iface.launch()
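
# A Spaces "Build error" most often means unresolved dependencies. A sketch of
# a requirements.txt covering the imports above (usual PyPI names, versions
# unpinned; adjust as needed):
#
#   gradio
#   pdfplumber
#   python-docx
#   nltk
#   langchain
#   langchain-text-splitters
#   sentence-transformers
#   faiss-cpu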