Rathapoom commited on
Commit
09ec353
·
verified ·
1 Parent(s): ffbf6d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
  import torch
3
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
  from PyPDF2 import PdfReader
5
  import gradio as gr
6
- from datasets import Dataset
 
7
 
8
  # Extract text from PDF
9
  def extract_text_from_pdf(pdf_path):
@@ -16,7 +17,8 @@ def extract_text_from_pdf(pdf_path):
16
 
17
  # Load model and tokenizer
18
  model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
19
- tokenizer = RagTokenizer.from_pretrained(model_name)
 
20
 
21
  # Extract text from the provided PDF
22
  pdf_text = extract_text_from_pdf("TOPF 2564.pdf") # Updated path
@@ -29,23 +31,27 @@ dataset = Dataset.from_list(passages)
29
  dataset_path = "./rag_document_dataset"
30
  index_path = "./rag_document_index"
31
 
 
 
 
 
32
  # Save the dataset to disk and create an index
33
  dataset.save_to_disk(dataset_path)
34
  dataset.load_from_disk(dataset_path).add_faiss_index(column="text").save(index_path)
35
 
36
- # Load the retriever with the custom dataset and index
37
- retriever = RagRetriever.from_pretrained(
38
- model_name,
39
- index_name="custom",
40
- passages_path=dataset_path,
41
- index_path=index_path
42
- )
43
-
44
- model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
45
 
46
  # Define the chat function
47
  def answer_question(question, context):
48
- inputs = tokenizer(question, return_tensors="pt")
 
49
  input_ids = inputs["input_ids"]
50
  attention_mask = inputs["attention_mask"]
51
 
 
1
  import os
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from PyPDF2 import PdfReader
5
  import gradio as gr
6
+ from datasets import Dataset, load_from_disk, save_to_disk
7
+ import faiss
8
 
9
  # Extract text from PDF
10
  def extract_text_from_pdf(pdf_path):
 
17
 
18
  # Load model and tokenizer
19
  model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ model = AutoModelForCausalLM.from_pretrained(model_name)
22
 
23
  # Extract text from the provided PDF
24
  pdf_text = extract_text_from_pdf("TOPF 2564.pdf") # Updated path
 
31
  dataset_path = "./rag_document_dataset"
32
  index_path = "./rag_document_index"
33
 
34
+ # Ensure the directory exists
35
+ os.makedirs(dataset_path, exist_ok=True)
36
+ os.makedirs(index_path, exist_ok=True)
37
+
38
  # Save the dataset to disk and create an index
39
  dataset.save_to_disk(dataset_path)
40
  dataset.load_from_disk(dataset_path).add_faiss_index(column="text").save(index_path)
41
 
42
+ # Custom retriever
43
+ def retrieve(query):
44
+ # Use FAISS index to retrieve relevant passages
45
+ query_embedding = tokenizer(query, return_tensors="pt")["input_ids"]
46
+ # Perform retrieval (this is a placeholder, actual retrieval code will be more complex)
47
+ # retrieved_passages = faiss_search(query_embedding)
48
+ retrieved_passages = " ".join([passage['text'] for passage in passages]) # Simplified for demo
49
+ return retrieved_passages
 
50
 
51
  # Define the chat function
52
  def answer_question(question, context):
53
+ retrieved_context = retrieve(question)
54
+ inputs = tokenizer(question + " " + retrieved_context, return_tensors="pt")
55
  input_ids = inputs["input_ids"]
56
  attention_mask = inputs["attention_mask"]
57