import os
from threading import Thread

import gradio as gr
import spaces
import torch
from datasets import Dataset, load_dataset
from ragatouille import RAGPretrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token from the Space secrets, required for the gated Gemma weights.
token = os.environ["HF_TOKEN"]

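# Load Gemma 7B-it in float16 and move it to the GPU for generation.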
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# ColBERT late-interaction retriever used to index and search the passages.
RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-large-v1")

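# Stream the English Wikipedia dump (no full download) and materialize the
# first 3,001 articles into an in-memory Dataset.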
dataset = load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
)

data = Dataset.from_dict({})
for i, entry in enumerate(dataset):
    data = data.add_item(entry)
    if i == 3000:
        break

del dataset

# Index the article bodies with ColBERT; use_faiss speeds up index construction.
documents = data["text"]
RAG.index(documents, index_name="wikipedia", use_faiss=True)

del documents

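# Retrieve the top-k passages for a query. The code assumes each result's
# passage_id maps back to a row index in `data`.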
def search(query, k: int = 5):
    results = RAG.search(query, k=k)
    return [result["passage_id"] for result in results]

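# Build the augmented prompt: the user query followed by the retrieved
# articles, plus (titles, urls) metadata for the resources footer.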
def prepare_prompt(query, indexes, data=data):
    prompt = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    titles = []
    urls = []
    for i in indexes:
        # Look up each retrieved article by its row index in the dataset.
        title = data[i]["title"]
        text = data[i]["text"]
        url = data[i]["url"]
        titles.append(title)
        urls.append(url)
        prompt += f"Title: {title}, Text: {text}\n"
    return prompt, (titles, urls)

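# Chat handler: retrieve, build the grounded prompt, then stream Gemma's reply.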
@spaces.GPU
def talk(message, history):
    indexes = search(message)
    message, metadata = prepare_prompt(message, indexes)
    resources = "\nRESOURCES:\n"
    # metadata is a (titles, urls) pair; zip it back into per-article pairs.
    for title, url in zip(*metadata):
        resources += f"[{title}]({url}), "
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # Drop the resources footer from past answers before re-feeding them.
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # Generate in a background thread so tokens can be yielded as they stream in.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # Append the source links once generation has finished.
    partial_text += resources
    yield partial_text

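# Gradio UI: ChatInterface wires the streaming `talk` generator to a chatbot widget.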
TITLE = "RAG"

DESCRIPTION = """
## Resources used to build this project

* https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
* me

## Models

The models used in this Space are:

* google/gemma-7b-it
* mixedbread-ai/mxbai-colbert-large-v1
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()