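"""Mini RAG Wikipedia demo for Hugging Face Spaces.

Pipeline: embed the rag-mini-wikipedia corpus with a SentenceTransformer,
retrieve candidate passages with a FAISS index, rerank them with a
cross-encoder, and generate an answer with Mistral-7B-Instruct.
"""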
import os

import faiss
import gradio as gr
import numpy as np
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
@spaces.GPU
def claim_gpu():
    # No-op decorated with @spaces.GPU so Spaces (ZeroGPU) detects that the app needs a GPU
    pass

claim_gpu()
# Log in automatically if HF_TOKEN is present (needed for gated models such as Mistral)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    from huggingface_hub import login
    login(token=hf_token)
# Load corpus
print("Loading dataset...")
dataset = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus")
corpus = [item["passage"] for item in dataset["passages"]]  # keep only the raw passage text
# Embedding model
print("Encoding corpus...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True, device="cpu")
corpus_embeddings_np = corpus_embeddings.numpy()
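# The resulting embeddings are float32, which is the dtype FAISS expects.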
# FAISS index
index = faiss.IndexFlatL2(corpus_embeddings_np.shape[1])
index.add(corpus_embeddings_np)
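# Note: IndexFlatL2 does exact brute-force L2 search, which is fine for a corpus this small.
# For cosine similarity one could instead normalize the embeddings and use faiss.IndexFlatIP.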
# Reranker model
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
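# The cross-encoder scores each (query, passage) pair jointly, which is slower than the
# bi-encoder retrieval step but typically gives a more accurate relevance ordering.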
# Generator: a local instruct model served through a transformers pipeline
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    device_map="auto",
    torch_dtype=torch.float16,
)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)
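# Note: by default the text-generation pipeline returns the prompt plus the completion
# (return_full_text=True), which is why rag_pipeline splits on "Answer:" below.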
def rag_pipeline(query):
    # Embed the query with the same bi-encoder used for the corpus
    query_embedding = embedder.encode([query], convert_to_tensor=True, device="cpu").numpy()

    # Retrieve top-k candidate passages from FAISS
    D, I = index.search(query_embedding, 5)
    retrieved_docs = [corpus[idx] for idx in I[0]]

    # Rerank the candidates with the cross-encoder
    rerank_pairs = [[query, doc] for doc in retrieved_docs]
    scores = reranker.predict(rerank_pairs)
    reranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), key=lambda pair: pair[0], reverse=True)]

    # Use the two highest-scoring passages as context
    context = "\n\n".join(reranked_docs[:2])
    prompt = f"""Answer the following question using the provided context.\n\nContext:\n{context}\n\nQuestion: {query}\nAnswer:"""

    # Generate and return only the text after the final "Answer:" marker
    response = generator(prompt)[0]["generated_text"]
    return response.split("Answer:")[-1].strip()
# Gradio UI
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
    outputs="text",
    title="Mini RAG Wikipedia Demo",
    description="Retrieval-Augmented Generation on a small Wikipedia subset.",
)

iface.launch()