Rathapoom committed on
Commit
ffbf6d9
·
verified ·
1 Parent(s): be7ae1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -14
app.py CHANGED
@@ -1,21 +1,70 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import os
 
 
 
4
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- model_name = "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq"
7
- token = os.getenv("HF_TOKEN")
 
 
 
8
 
9
- # Remove these lines
10
- # device = torch.device("cuda")
11
- # torch.cuda.set_device(0)
 
12
 
13
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
14
- model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
 
15
 
16
- def generate_text(prompt):
17
- inputs = tokenizer(prompt, return_tensors="pt")
18
- outputs = model.generate(inputs.input_ids, max_length=50)
19
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
20
 
21
- gr.Interface(fn=generate_text, inputs="text", outputs="text").launch()
 
 
 
 
1
  import os
2
+ import torch
3
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
+ from PyPDF2 import PdfReader
5
  import gradio as gr
6
+ from datasets import Dataset
7
+
8
+ # Extract text from PDF
9
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path* as one string.

    Pages are joined with a newline so that the last line of one page does
    not fuse with the first line of the next when the result is later split
    into lines. A page for which no text is extracted contributes an empty
    string instead of breaking the concatenation.

    Args:
        pdf_path: Filesystem path of the PDF file to read.

    Returns:
        The text of all pages, separated by newlines; an empty string when
        the document has no pages.
    """
    with open(pdf_path, "rb") as pdf_file:
        reader = PdfReader(pdf_file)
        # `or ""` guards against a falsy/None result for pages without a
        # text layer; join() avoids rebuilding the string once per page.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
16
+
17
# Load model and tokenizer.
# NOTE(review): this checkpoint looks like a Llama-3 causal-LM checkpoint,
# whereas RagTokenizer / RagRetriever / RagSequenceForGeneration expect a RAG
# checkpoint (question encoder + seq2seq generator). Loading it through the
# Rag* classes is expected to fail -- confirm, then either switch to a real
# RAG checkpoint or build a retrieve-then-prompt pipeline around a causal LM.
model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
tokenizer = RagTokenizer.from_pretrained(model_name)

# Extract text from the bundled PDF; every non-blank line becomes one passage.
pdf_text = extract_text_from_pdf("TOPF 2564.pdf")
passages = [
    {"title": "", "text": line}
    for line in pdf_text.split("\n")
    if line.strip()
]

# Create a Dataset from the passages.
dataset = Dataset.from_list(passages)

# Passages and index are written to the current working directory.
dataset_path = "./rag_document_dataset"
index_path = "./rag_document_index"

# Persist the passages before attaching an index, then build and save the
# index from the in-memory dataset (no need to reload it from disk).
dataset.save_to_disk(dataset_path)
# NOTE(review): add_faiss_index indexes a column of float vectors, and the
# "custom" RagRetriever index expects an `embeddings` column next to
# `title`/`text`. The passages still need to be embedded before this call
# can succeed -- TODO: add an embedding step and index that column instead.
dataset.add_faiss_index(column="text")
# Dataset has no `.save()`; the index is written with save_faiss_index().
dataset.save_faiss_index("text", index_path)

# Load the retriever with the custom dataset and index.
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
)

model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
45
 
46
+ # Define the chat function
47
def answer_question(question, context):
    """Generate an answer to *question* with the module-level model.

    Args:
        question: The user's question as plain text.
        context: Accepted for interface compatibility but not read by this
            function; presumably the retriever attached to ``model``
            supplies the supporting passages -- verify against the setup.

    Returns:
        The decoded first generated sequence with special tokens removed,
        or an empty string when *question* is None, empty, or whitespace.
    """
    # Nothing to answer: skip tokenisation and generation for blank input.
    if not question or not question.strip():
        return ""

    inputs = tokenizer(question, return_tensors="pt")

    # Generate the answer from the encoded question.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
56
 
57
+ # Gradio interface setup
58
def ask(question):
    """Gradio callback: answer *question*, passing the module-level PDF text as context."""
    reply = answer_question(question, pdf_text)
    return reply
60
 
61
# Single-textbox UI that forwards the question to `ask` and shows the reply.
demo = gr.Interface(
    fn=ask,
    # Use the top-level component: the `gr.inputs` namespace is deprecated
    # and no longer available in current Gradio releases.
    inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
    outputs="text",
    title="Document QA with RAG",
    description="Ask questions based on the provided document.",
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()