import os

import spaces
import torch
import gradio as gr
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Log in to the Hugging Face Hub so gated models (e.g. Mistral) can be downloaded.
api_token = os.environ.get("HF_API_TOKEN")
if api_token:
    login(api_token)


def get_llm(model_id):
    # Load in half precision: a 7B model in fp32 needs ~28 GB of VRAM and
    # would not fit on most single GPUs.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    model.to("cuda")
    return model


@spaces.GPU
def retriever_qa(file, query):
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    llm = get_llm(model_id)

    # The original LangChain RetrievalQA pipeline is stubbed out (see the
    # commented sketch at the bottom of this file). Instead, the document text
    # is prepended to the query and sent straight to the model. Note that a
    # very long document may exceed the model's context window.
    with open(file, "r") as f:
        document = f.read()

    messages = [
        {"role": "user", "content": document + "\n\n" + query}
    ]

    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
    generated_ids = llm.generate(model_inputs, max_new_tokens=50, do_sample=True)
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.batch_decode(
        generated_ids[:, model_inputs.shape[1]:], skip_special_tokens=True
    )[0]
    return response


rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        # Drag-and-drop file upload. A PDF variant was sketched originally:
        # gr.File(label="Upload PDF File", file_count="single", file_types=[".pdf"], type="filepath"),
        gr.File(label="Upload txt File", file_count="single", file_types=[".txt"], type="filepath"),
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here..."),
    ],
    outputs=gr.Textbox(label="Output"),
    title="RAG Chatbot",
    description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document.",
)

rag_application.launch(share=True)
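
# Quick local smoke test (hypothetical file path; assumes a visible CUDA GPU).
# Useful for checking retriever_qa before wiring it into the Gradio UI:
# if __name__ == "__main__":
#     print(retriever_qa("sample.txt", "What is this document about?"))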
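
# ---------------------------------------------------------------------------
# Optional: restoring the stubbed-out LangChain RetrievalQA path referenced in
# the comments above. This is a minimal, untested sketch, left commented out.
# It assumes the langchain, langchain-community, sentence-transformers, and
# faiss-cpu packages are installed; the chunk sizes and embedding model are
# illustrative choices, not values from the original script.
# ---------------------------------------------------------------------------
# from langchain_community.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.vectorstores import FAISS
# from langchain_community.llms import HuggingFacePipeline
# from langchain.chains import RetrievalQA
# from transformers import pipeline
#
# def build_retriever(file):
#     # Split the uploaded document into chunks and index them in FAISS.
#     docs = TextLoader(file).load()
#     chunks = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=100
#     ).split_documents(docs)
#     embeddings = HuggingFaceEmbeddings(
#         model_name="sentence-transformers/all-MiniLM-L6-v2"
#     )
#     return FAISS.from_documents(chunks, embeddings).as_retriever()
#
# def retriever_qa_langchain(file, query):
#     # Wrap the transformers pipeline so LangChain can use it as an LLM,
#     # then answer the query from the retrieved chunks ("stuff" chain).
#     hf_pipe = pipeline(
#         "text-generation",
#         model="mistralai/Mistral-7B-Instruct-v0.2",
#         max_new_tokens=256,
#     )
#     qa = RetrievalQA.from_chain_type(
#         llm=HuggingFacePipeline(pipeline=hf_pipe),
#         chain_type="stuff",
#         retriever=build_retriever(file),
#         return_source_documents=False,
#     )
#     return qa.invoke(query)["result"]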