wakeupmh committed on
Commit
d3e32db
·
1 Parent(s): 59cbae1

fix: try to lightweight it

Browse files
Files changed (1) hide show
  1. app.py +46 -45
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
- import faiss
4
  import os
5
  from datasets import load_from_disk
6
  import torch
@@ -13,51 +12,37 @@ logging.basicConfig(level=logging.INFO)
13
  DATA_DIR = "/data" if os.path.exists("/data") else "."
14
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
15
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
16
- INDEX_PATH = os.path.join(DATASET_DIR, "embeddings.faiss")
17
 
18
  # Cache models and dataset
19
- @st.cache_resource # Cache models in memory
20
  def load_models():
21
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
22
- retriever = RagRetriever.from_pretrained(
23
- "facebook/rag-sequence-nq",
24
- index_name="custom",
25
- passages_path=DATASET_PATH,
26
- index_path=INDEX_PATH
27
- )
28
- model = RagSequenceForGeneration.from_pretrained(
29
- "facebook/rag-sequence-nq",
30
- retriever=retriever
31
- )
32
- # Move to CPU (since we're in a CPU environment)
33
- model = model.cpu()
34
- return tokenizer, retriever, model
35
 
36
- @st.cache_data # Cache dataset on disk
37
- def load_dataset():
38
- # Create initial dataset if it doesn't exist
39
- if not os.path.exists(DATASET_PATH):
40
- with st.spinner("Building initial dataset from autism research papers..."):
41
- import faiss_index.index as idx
42
- papers = idx.fetch_arxiv_papers("autism research", max_results=100)
43
- idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
44
- return load_from_disk(DATASET_PATH)
45
-
46
- # RAG Pipeline
47
- def rag_pipeline(query, dataset, index):
48
- tokenizer, retriever, model = load_models()
49
- inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
50
  with torch.no_grad():
51
- outputs = model.generate(
52
- inputs["input_ids"],
53
- max_length=200,
54
- min_length=50,
55
- num_beams=5,
56
- early_stopping=True,
57
- no_repeat_ngram_size=3
58
- )
59
- answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
60
- return answer
61
 
62
  # Streamlit App
63
  st.title("🧩 AMA Autism")
@@ -65,10 +50,26 @@ query = st.text_input("Please ask me anything about autism ✨")
65
 
66
  if query:
67
  with st.status("Searching for answers..."):
 
68
  dataset = load_dataset()
69
- answer = rag_pipeline(query, dataset, index=None)
70
- if answer:
 
 
 
 
 
 
 
 
 
71
  st.success("Answer found!")
72
  st.write(answer)
 
 
 
 
 
 
73
  else:
74
- st.error("Failed to generate an answer.")
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
  import os
4
  from datasets import load_from_disk
5
  import torch
 
12
  DATA_DIR = "/data" if os.path.exists("/data") else "."
13
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
14
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
 
15
 
16
  # Cache models and dataset
17
@st.cache_resource
def load_models():
    """Load the tokenizer and seq2seq model used for answer generation.

    Both objects are created from the "t5-base" checkpoint. The function is
    wrapped in ``st.cache_resource`` so the load is not repeated on every
    Streamlit rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "t5-base"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
 
 
 
 
 
 
 
 
 
 
23
 
24
def generate_answer(question, context, max_length=200):
    """Generate an answer to ``question`` grounded in ``context``.

    The question and context are packed into a single
    ``"question: ... context: ..."`` prompt, tokenized (truncated to 512
    tokens) and decoded with the cached seq2seq model.

    Args:
        question: The user's question.
        context: Text the answer should be drawn from.
        max_length: Upper bound, in tokens, on the generated answer; passed
            to ``model.generate``.

    Returns:
        The decoded answer, or a fixed fallback sentence when the model
        produces empty or whitespace-only output.
    """
    tokenizer, model = load_models()

    # Encode the question and context
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Decode autoregressively with generate(). A bare forward pass on an
    # encoder-decoder model needs decoder inputs, and an argmax over its
    # logits is not a decoded answer.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)

    # Convert generated token ids to text
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if answer.strip():
        return answer
    return "I cannot find a specific answer to this question in the provided context."
 
 
 
46
 
47
  # Streamlit App
48
  st.title("🧩 AMA Autism")
 
50
 
51
if query:
    # NOTE(review): load_dataset() is not defined in the visible part of this
    # file - confirm it still exists after this change.
    # NOTE(review): the nesting of the result display under st.status is
    # assumed - confirm against the intended layout.
    with st.status("Searching for answers..."):
        # Load dataset
        dataset = load_dataset()

        # Fetch the first rows by integer index. Slicing (dataset[:3]) on a
        # Hugging Face Dataset returns a dict of columns, not a list of rows,
        # so iterating over the slice would yield column names.
        top_papers = [dataset[i] for i in range(min(3, len(dataset)))]

        # Get relevant context: up to 1000 characters from each paper
        context = "\n".join(paper["text"][:1000] for paper in top_papers)

        # Generate answer
        answer = generate_answer(query, context)

        if answer and not answer.isspace():
            st.success("Answer found!")
            st.write(answer)

            st.write("### Sources Used:")
            for paper in top_papers:
                st.write(f"**Title:** {paper['title']}")
                st.write(f"**Summary:** {paper['text'][:200]}...")
                st.write("---")
        else:
            # generate_answer substitutes a fallback sentence for blank
            # output, so this branch is only a safety net.
            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")