import os
from threading import Thread

import faiss
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from usearch.index import Index

# Generation model: Gemma 7B (instruction-tuned), loaded in float16 on the GPU.
token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Corpus: English Wikipedia passages; keep only the columns needed for display.
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])
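# Note: the row order of this dataset must line up with the ids stored in the
# usearch/faiss indexes below -- the ids returned by a search are used directly
# as row indices into `title_text_dataset`.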

# Pre-built 50M-passage indexes: an int8 usearch index (opened as a
# memory-mapped view) for rescoring, plus exact and approximate (IVF) binary
# faiss indexes for the first-stage search.
int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "wikipedia_ubinary_faiss_50m.index"
)
binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary(
    "wikipedia_ubinary_ivf_faiss_50m.index"
)

# Embedding model for queries. Bound to its own name so it does not shadow the
# Gemma generation model loaded above.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    # 1. Embed the query as float32.
    query_embedding = embedding_model.encode(query)

    # 2. Quantize the query to ubinary for the binary index.
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), precision="ubinary"
    )
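    # "ubinary" thresholds each embedding dimension at zero and packs the
    # resulting bits into uint8 bytes (8 dimensions per byte), which is the
    # format faiss binary indexes expect for Hamming-distance search.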

    # 3. Search the binary index, oversampling by `rescore_multiplier` so the
    #    rescoring step has more candidates to choose from.
    index = binary_ivf if use_approx else binary_index
    _scores, binary_ids = index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]

    # 4. Load the int8 embeddings of the candidates from the memory-mapped view.
    int8_embeddings = int8_view[binary_ids].astype(int)

    # 5. Rescore: dot product of the float32 query with the int8 candidates.
    scores = query_embedding @ int8_embeddings.T
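    # This float32 x int8 dot product closely tracks the full-precision
    # similarity ranking while only touching top_k * rescore_multiplier
    # candidate vectors, so it stays cheap compared to a full float search.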

    # 6. Sort by rescored similarity and keep the top_k results.
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]

    # 7. Fetch the matching titles and texts from the dataset.
    top_k_titles, top_k_texts = zip(
        *[
            (title_text_dataset[idx]["title"], title_text_dataset[idx]["text"])
            for idx in top_k_indices.tolist()
        ]
    )
    df = {
        "Score": [round(value, 2) for value in top_k_scores],
        "Title": top_k_titles,
        "Text": top_k_texts,
    }
    return df
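# A quick standalone check (hypothetical query; not part of the app flow):
#
#   results = search("Who painted the Mona Lisa?", top_k=3)
#   print(results["Title"])   # three Wikipedia titles, best match first
#   print(results["Score"])   # their rescored similarity scores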


def prepare_prompt(query, df):
    # Turn the search results into a single prompt that asks the model to
    # answer the query from the retrieved passages.
    prompt = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    for title, text in zip(df["Title"], df["Text"]):
        prompt += f"Title: {title}, Text: {text}\n"
    return prompt
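# For a query "Q" and two results, the prompt produced above looks like:
#
#   Query: Q
#   Continue to answer the query by using the Search Results:
#   Title: <title 1>, Text: <text 1>
#   Title: <title 2>, Text: <text 2>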


@spaces.GPU
def talk(message, history):
    # Retrieve supporting passages and fold them into the user message.
    df = search(message)
    message = prepare_prompt(message, df)

    # Markdown links for the top 3 source titles, appended after generation.
    # (All links currently point at the Space itself rather than the articles.)
    resources = "\nRESOURCES:\n"
    for title in df["Title"][:3]:
        resources += f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "

    # Rebuild the conversation for the chat template, stripping the RESOURCES
    # footer from past assistant turns so it is not fed back to the model.
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    model_inputs = tok([messages], return_tensors="pt").to(device)
    # Stream decoded tokens as they are generated; skip the prompt and special
    # tokens so only the new answer text is yielded.
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
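    # Sampling setup: nucleus sampling (top_p=0.95) restricted to the 1000
    # most likely tokens, with temperature=0.75 to soften the distribution;
    # num_beams=1 (no beam search) is required for token-by-token streaming.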
    # Run generation on a background thread and yield partial text as the
    # streamer produces it, so the UI updates token by token.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # Append the source links once the answer is complete.
    partial_text += resources
    yield partial_text


TITLE = "RAG"

DESCRIPTION = """
## Resources used to build this project

* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval

## Retrieval parameters

```python
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
```
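
* `top_k`: number of passages returned by `search`.
* `rescore_multiplier`: oversampling factor; the binary index retrieves `top_k * rescore_multiplier` candidates before int8 rescoring.
* `use_approx`: search the approximate IVF binary index instead of the exact one.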

## Models

The models used in this Space are:

* google/gemma-7b-it (answer generation)
* mixedbread-ai/mxbai-embed-large-v1 (query embeddings)

Retrieval runs over the mixedbread-ai/wikipedia-data-en-2023-11 dataset.
"""


demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["Write me a poem about Machine Learning."]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()