import spaces
import torch  # needed for the device check in retriever_qa
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# You can use this section to suppress warnings generated by your code:
import warnings

def warn(*args, **kwargs):
    pass

warnings.warn = warn
warnings.filterwarnings('ignore')
def get_llm():
    model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
    # device_map='auto' lets accelerate place the model on the available
    # GPU(s); an explicit .to('cuda') on a dispatched model is redundant
    # and can raise an error, so it is omitted here.
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
    return model
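
# A minimal sketch of the `retriever(file)` helper that the commented-out
# RetrievalQA chain in retriever_qa below expects. This is an assumption
# about the intended design, not part of the original app: the loader,
# splitter settings, embedding model, and FAISS index are all illustrative
# choices, using LangChain's community packages.
def retriever(file):
    # Imports are local so the app still runs if LangChain is not installed.
    from langchain_community.document_loaders import TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    # Load the uploaded .txt file and split it into overlapping chunks.
    documents = TextLoader(file).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50
    ).split_documents(documents)
    # all-MiniLM-L6-v2 is a hypothetical embedding-model choice.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # Index the chunks in memory and expose them as a retriever.
    return FAISS.from_documents(chunks, embeddings).as_retriever()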
@spaces.GPU
def retriever_qa(file, query):
    llm = get_llm()
    # Intended retrieval pipeline, not yet wired up:
    # retriever_obj = retriever(file)
    # qa = RetrievalQA.from_chain_type(llm=llm,
    #                                  chain_type="stuff",
    #                                  retriever=retriever_obj,
    #                                  return_source_documents=False)
    # response = qa.invoke(query)

    # Placeholder behavior: send only the file's first line to the model.
    # Note that `query` is unused until the retrieval chain above is enabled.
    with open(file, 'r') as f:
        first_line = f.readline()
    messages = [
        {"role": "user", "content": first_line}
    ]
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
    generated_ids = llm.generate(model_inputs, max_new_tokens=100, do_sample=True)
    response = tokenizer.batch_decode(generated_ids)[0]

    # Report which device was used alongside the model output.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    response = response + f". Using device: {device}"
    return response
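
# Optional refinement, shown as a sketch and not used by the app above:
# batch_decode on the full output echoes the prompt and special tokens back
# to the user, since generate() returns prompt + continuation for
# decoder-only models. Slicing off the prompt tokens and decoding with
# skip_special_tokens=True returns only the model's answer.
def decode_new_tokens(tokenizer, model_inputs, generated_ids):
    # Keep only the tokens generated after the prompt.
    new_tokens = generated_ids[:, model_inputs.shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]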
rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        # gr.File(label="Upload PDF File", file_count="single", file_types=['.pdf'], type="filepath"),  # Drag-and-drop file upload
        gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # Drag-and-drop file upload
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")
    ],
    outputs=gr.Textbox(label="Output"),
    title="RAG Chatbot",
    description="Upload a .txt document and ask any question. The chatbot will try to answer using the provided document."
)

rag_application.launch(share=True)