Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
-
from transformers import
|
4 |
from PyPDF2 import PdfReader
|
5 |
import gradio as gr
|
6 |
-
from datasets import Dataset
|
|
|
7 |
|
8 |
# Extract text from PDF
|
9 |
def extract_text_from_pdf(pdf_path):
|
@@ -16,7 +17,8 @@ def extract_text_from_pdf(pdf_path):
|
|
16 |
|
17 |
# Load model and tokenizer
|
18 |
model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
|
19 |
-
tokenizer =
|
|
|
20 |
|
21 |
# Extract text from the provided PDF
|
22 |
pdf_text = extract_text_from_pdf("TOPF 2564.pdf") # Updated path
|
@@ -29,23 +31,27 @@ dataset = Dataset.from_list(passages)
|
|
29 |
dataset_path = "./rag_document_dataset"
|
30 |
index_path = "./rag_document_index"
|
31 |
|
|
|
|
|
|
|
|
|
32 |
# Save the dataset to disk and create an index
|
33 |
dataset.save_to_disk(dataset_path)
|
34 |
dataset.load_from_disk(dataset_path).add_faiss_index(column="text").save(index_path)
|
35 |
|
36 |
-
#
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
)
|
43 |
-
|
44 |
-
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
|
45 |
|
46 |
# Define the chat function
|
47 |
def answer_question(question, context):
|
48 |
-
|
|
|
49 |
input_ids = inputs["input_ids"]
|
50 |
attention_mask = inputs["attention_mask"]
|
51 |
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
from PyPDF2 import PdfReader
|
5 |
import gradio as gr
|
6 |
+
from datasets import Dataset, load_from_disk, save_to_disk
|
7 |
+
import faiss
|
8 |
|
9 |
# Extract text from PDF
|
10 |
def extract_text_from_pdf(pdf_path):
|
|
|
17 |
|
18 |
# Load model and tokenizer
|
19 |
model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
21 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
22 |
|
23 |
# Extract text from the provided PDF
|
24 |
pdf_text = extract_text_from_pdf("TOPF 2564.pdf") # Updated path
|
|
|
31 |
dataset_path = "./rag_document_dataset"
|
32 |
index_path = "./rag_document_index"
|
33 |
|
34 |
+
# Ensure the directory exists
|
35 |
+
os.makedirs(dataset_path, exist_ok=True)
|
36 |
+
os.makedirs(index_path, exist_ok=True)
|
37 |
+
|
38 |
# Save the dataset to disk and create an index
|
39 |
dataset.save_to_disk(dataset_path)
|
40 |
dataset.load_from_disk(dataset_path).add_faiss_index(column="text").save(index_path)
|
41 |
|
42 |
+
# Custom retriever
|
43 |
+
def retrieve(query):
|
44 |
+
# Use FAISS index to retrieve relevant passages
|
45 |
+
query_embedding = tokenizer(query, return_tensors="pt")["input_ids"]
|
46 |
+
# Perform retrieval (this is a placeholder, actual retrieval code will be more complex)
|
47 |
+
# retrieved_passages = faiss_search(query_embedding)
|
48 |
+
retrieved_passages = " ".join([passage['text'] for passage in passages]) # Simplified for demo
|
49 |
+
return retrieved_passages
|
|
|
50 |
|
51 |
# Define the chat function
|
52 |
def answer_question(question, context):
|
53 |
+
retrieved_context = retrieve(question)
|
54 |
+
inputs = tokenizer(question + " " + retrieved_context, return_tensors="pt")
|
55 |
input_ids = inputs["input_ids"]
|
56 |
attention_mask = inputs["attention_mask"]
|
57 |
|