import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once at startup so each request does not reload the weights.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Suppress warnings generated by the libraries to keep the Space logs clean.
import warnings
warnings.filterwarnings('ignore')

def get_llm():
    # Reuse the model loaded at startup; reloading the 7B weights on every request
    # is slow, and calling .to('cuda') on a model dispatched with device_map='auto'
    # can raise an error.
    return model

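# On Hugging Face ZeroGPU Spaces, a function decorated with @spaces.GPU is allocated
# a GPU for the duration of the call; outside Spaces the decorator is effectively a no-op.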
@spaces.GPU
def retriever_qa(file, query):
    llm = get_llm()
    # Original retrieval chain, kept for reference:
    # retriever_obj = retriever(file)
    # qa = RetrievalQA.from_chain_type(llm=llm,
    #                                  chain_type="stuff",
    #                                  retriever=retriever_obj,
    #                                  return_source_documents=False)
    # response = qa.invoke(query)

    # Read the uploaded document and combine it with the user's query,
    # so the question is actually answered against the file contents.
    with open(file, 'r') as f:
        document = f.read()

    messages = [
        {"role": "user",
         "content": f"Answer the question using the document below.\n\nDocument:\n{document}\n\nQuestion: {query}"}
    ]

    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

    generated_ids = llm.generate(model_inputs, max_new_tokens=100, do_sample=True)

    # Decode only the newly generated tokens, skipping the echoed prompt and special tokens.
    response = tokenizer.batch_decode(
        generated_ids[:, model_inputs.shape[1]:], skip_special_tokens=True
    )[0]

    # Report which device was available inside the GPU-decorated function.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    response = response + f"\n\nUsing device: {device}"
    
    return response

rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        # gr.File(label="Upload PDF File", file_count="single", file_types=['.pdf'], type="filepath"),  # Drag and drop file upload
        gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # Drag and drop file upload
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")
    ],
    outputs=gr.Textbox(label="Output"),
    title="RAG Chatbot",
    description="Upload a .txt document and ask any question. The chatbot will try to answer using the provided document."
)
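# Gradio passes the two input components positionally to retriever_qa(file, query).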

rag_application.launch(share=True)
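
# Note: on Hugging Face Spaces the app is served automatically; share=True only has
# an effect when running locally, where Gradio creates a temporary public link.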