wakeupmh commited on
Commit
8903db2
·
1 Parent(s): f944585

fix: dataframes

Browse files
Files changed (1) hide show
  1. app.py +24 -14
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import streamlit as st
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import os
4
- from datasets import load_from_disk
5
  import torch
6
  import logging
 
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO)
@@ -29,7 +30,13 @@ def load_dataset():
29
  import faiss_index.index as idx
30
  papers = idx.fetch_arxiv_papers("autism research", max_results=100)
31
  idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
32
- return load_from_disk(DATASET_PATH)
 
 
 
 
 
 
33
 
34
  def generate_answer(question, context, max_length=200):
35
  tokenizer, model = load_models()
@@ -46,12 +53,16 @@ def generate_answer(question, context, max_length=200):
46
 
47
  # Get model predictions
48
  with torch.no_grad():
49
- outputs = model(**inputs)
50
- answer_ids = torch.argmax(outputs.logits, dim=-1)
51
-
52
- # Convert token positions to text
53
- answer = tokenizer.decode(answer_ids[0], skip_special_tokens=True)
54
-
 
 
 
 
55
  return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
56
 
57
  # Streamlit App
@@ -61,12 +72,11 @@ query = st.text_input("Please ask me anything about autism ✨")
61
  if query:
62
  with st.status("Searching for answers..."):
63
  # Load dataset
64
- dataset = load_dataset()
65
 
66
  # Get relevant context
67
  context = "\n".join([
68
- f"{paper['text'][:1000]}" # Use more context for better answers
69
- for paper in dataset[:3]
70
  ])
71
 
72
  # Generate answer
@@ -77,9 +87,9 @@ if query:
77
  st.write(answer)
78
 
79
  st.write("### Sources Used:")
80
- for i in range(min(3, len(dataset))):
81
- st.write(f"**Title:** {dataset[i]['title']}")
82
- st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
83
  st.write("---")
84
  else:
85
  st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
 
1
  import streamlit as st
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import os
4
+ from datasets import load_from_disk, Dataset
5
  import torch
6
  import logging
7
+ import pandas as pd
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
 
30
  import faiss_index.index as idx
31
  papers = idx.fetch_arxiv_papers("autism research", max_results=100)
32
  idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
33
+
34
+ # Load and convert to pandas for easier handling
35
+ dataset = load_from_disk(DATASET_PATH)
36
+ return pd.DataFrame({
37
+ 'title': dataset['title'],
38
+ 'text': dataset['text']
39
+ })
40
 
41
  def generate_answer(question, context, max_length=200):
42
  tokenizer, model = load_models()
 
53
 
54
  # Get model predictions
55
  with torch.no_grad():
56
+ outputs = model.generate(
57
+ inputs["input_ids"],
58
+ max_length=max_length,
59
+ min_length=30,
60
+ num_beams=4,
61
+ length_penalty=2.0,
62
+ early_stopping=True
63
+ )
64
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
65
+
66
  return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
67
 
68
  # Streamlit App
 
72
  if query:
73
  with st.status("Searching for answers..."):
74
  # Load dataset
75
+ df = load_dataset()
76
 
77
  # Get relevant context
78
  context = "\n".join([
79
+ f"{text[:1000]}" for text in df['text'].head(3)
 
80
  ])
81
 
82
  # Generate answer
 
87
  st.write(answer)
88
 
89
  st.write("### Sources Used:")
90
+ for _, row in df.head(3).iterrows():
91
+ st.write(f"**Title:** {row['title']}")
92
+ st.write(f"**Summary:** {row['text'][:200]}...")
93
  st.write("---")
94
  else:
95
  st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")