tonyhui2234 committed on
Commit
833a88b
·
verified ·
1 Parent(s): b1e9f49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -38
app.py CHANGED
@@ -1,24 +1,22 @@
1
- import streamlit as st
2
- import random
3
- import pandas as pd
4
- import requests
5
- from io import BytesIO
6
- from PIL import Image
7
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
- import re
 
 
 
 
9
 
10
  # Define maximum dimensions for the fortune image (in pixels)
11
  MAX_SIZE = (400, 400)
12
 
13
- # Initialize button click count in session state
14
  if "button_count_temp" not in st.session_state:
15
  st.session_state.button_count_temp = 0
16
-
17
- # Set page configuration
18
- st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
19
- st.title("Fortuen Stick Enquiry")
20
-
21
- # Initialize session state variables
22
  if "submitted_text" not in st.session_state:
23
  st.session_state.submitted_text = False
24
  if "fortune_number" not in st.session_state:
@@ -32,6 +30,7 @@ if "cfu_explain_text" not in st.session_state:
32
  if "stick_clicked" not in st.session_state:
33
  st.session_state.stick_clicked = False
34
 
 
35
  if "fortune_data" not in st.session_state:
36
  try:
37
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
@@ -39,21 +38,69 @@ if "fortune_data" not in st.session_state:
39
  st.error(f"Error loading CSV: {e}")
40
  st.session_state.fortune_data = None
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def load_finetuned_classifier_model(question):
 
 
 
 
43
  label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
44
- # Create a mapping dictionary to convert the default "LABEL_x" output.
45
  mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
46
-
47
- pipe = pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
48
- prediction = pipe(question)[0]['label']
49
  predicted_label = mapping.get(prediction, prediction)
50
  print(predicted_label)
51
  return predicted_label
52
 
53
- # Define your inference function
54
  def generate_answer(question, fortune):
55
- tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
56
- model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen")
 
 
 
57
  input_text = "Question: " + question + " Fortune: " + fortune
58
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
59
  outputs = model.generate(
@@ -68,27 +115,38 @@ def generate_answer(question, fortune):
68
  return answer
69
 
70
  def analysis(row_detail, classifiy, question):
71
- # Use the classifier's output (e.g. "Personal Well-Being") in the regex.
 
 
 
72
  pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
73
  match = pattern.search(row_detail)
74
  if match:
75
  result = match.group(1)
76
- # If you want to generate a custom answer, you can call generate_answer()
77
  return generate_answer(question, result)
78
  else:
79
  return "Heaven's secret cannot be revealed."
80
 
81
  def check_sentence_is_english_model(question):
82
- pipe_english = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
 
 
 
83
  return pipe_english(question)[0]['label'] == 'en'
84
 
85
  def check_sentence_is_question_model(question):
86
- pipe_question = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
 
 
 
87
  return pipe_question(question)[0]['label'] == 'LABEL_1'
88
 
89
  def submit_text_callback():
 
 
 
 
90
  question = st.session_state.get("user_sentence", "")
91
- # Clear any previous error message
92
  st.session_state.error_message = ""
93
 
94
  if not check_sentence_is_english_model(question):
@@ -107,12 +165,12 @@ def submit_text_callback():
107
  return
108
 
109
  st.session_state.submitted_text = True
110
- st.session_state.button_count_temp = 0 # Reset the counter once submission is accepted
111
 
112
- # Randomly generate a number from 1 to 100
113
  st.session_state.fortune_number = random.randint(1, 100)
114
 
115
- # Look up the row in the CSV where CNumber matches the generated fortune number.
116
  df = st.session_state.fortune_data
117
  row_detail = ''
118
  if df is not None:
@@ -138,6 +196,9 @@ def submit_text_callback():
138
  print(row_detail)
139
 
140
  def load_and_resize_image(path, max_size=MAX_SIZE):
 
 
 
141
  try:
142
  img = Image.open(path)
143
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
@@ -147,6 +208,9 @@ def load_and_resize_image(path, max_size=MAX_SIZE):
147
  return None
148
 
149
  def download_and_resize_image(url, max_size=MAX_SIZE):
 
 
 
150
  try:
151
  response = requests.get(url)
152
  response.raise_for_status()
@@ -159,24 +223,32 @@ def download_and_resize_image(url, max_size=MAX_SIZE):
159
  return None
160
 
161
  def stick_enquiry_callback():
162
- # Retrieve the user's question and the fortune detail
 
 
 
163
  question = st.session_state.get("user_sentence", "")
164
  if not st.session_state.fortune_row:
165
  st.error("Fortune data is not available. Please submit your question first.")
166
  return
167
  row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
168
- # Run the classifier model after the image has loaded
169
  classifiy = load_finetuned_classifier_model(question)
170
- # Generate the explanation using the analysis function
171
  cfu_explain = analysis(row_detail, classifiy, question)
172
- # Save the returned value in session state for later display
173
  st.session_state.cfu_explain_text = cfu_explain
174
  st.session_state.stick_clicked = True
175
 
176
- # Main layout: Left (input) and Right (fortune display)
 
 
 
 
 
 
 
 
177
  left_col, _, right_col = st.columns([3, 1, 5])
178
 
179
- # ---- Left Column ----
180
  with left_col:
181
  left_top = st.container()
182
  left_bottom = st.container()
@@ -187,16 +259,17 @@ with left_col:
187
  st.error(st.session_state.error_message)
188
  if st.session_state.submitted_text:
189
  with left_bottom:
 
190
  for _ in range(5):
191
  st.write("")
192
  col1, col2, col3 = st.columns(3)
193
  with col2:
194
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
195
  if st.session_state.stick_clicked:
196
- # Display the explanation text saved from analysis()
197
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
198
 
199
- # ---- Right Column ----
200
  with right_col:
201
  with st.container():
202
  col_left, col_center, col_right = st.columns([1, 2, 1])
 
1
+ import streamlit as st # For creating the web app interface
2
+ import random # For generating random fortune numbers
3
+ import pandas as pd # For handling CSV data
4
+ import requests # For downloading images from URLs
5
+ from io import BytesIO # For handling image bytes
6
+ from PIL import Image # For image processing
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM # For NLP models
8
+ import re # For regex operations
9
+
10
+ # This script implements a Fortune Stick Enquiry app.
11
+ # Users enter a question, which is validated and processed.
12
+ # A random fortune is chosen from a CSV, and NLP models classify and generate custom answers.
13
 
14
  # Define maximum dimensions for the fortune image (in pixels)
15
  MAX_SIZE = (400, 400)
16
 
17
+ # Initialize session state variables for button clicks, fortune details, etc.
18
  if "button_count_temp" not in st.session_state:
19
  st.session_state.button_count_temp = 0
 
 
 
 
 
 
20
  if "submitted_text" not in st.session_state:
21
  st.session_state.submitted_text = False
22
  if "fortune_number" not in st.session_state:
 
30
  if "stick_clicked" not in st.session_state:
31
  st.session_state.stick_clicked = False
32
 
33
+ # Load fortune data from CSV file
34
  if "fortune_data" not in st.session_state:
35
  try:
36
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
 
38
  st.error(f"Error loading CSV: {e}")
39
  st.session_state.fortune_data = None
40
 
41
+ # ----------------------------------------------------
42
+ # CACHED MODEL LOADING FUNCTIONS
43
+ # ----------------------------------------------------
44
+
45
@st.cache_resource
def load_classifier_pipeline():
    """Build the fine-tuned fortune-category classifier pipeline.

    Returns:
        A ``text-classification`` pipeline created from the
        ``tonyhui2234/CustomModel_classifier_model_10`` checkpoint. The
        function is wrapped in ``st.cache_resource`` so the returned object
        is cached by Streamlit.
    """
    classifier = pipeline(
        "text-classification",
        model="tonyhui2234/CustomModel_classifier_model_10",
    )
    return classifier
52
+
53
@st.cache_resource
def load_tokenizer_and_model():
    """Load the tokenizer and seq2seq model used to generate answers.

    Both objects come from the ``tonyhui2234/finetuned_model_text_gen``
    checkpoint; the result is cached through ``st.cache_resource``.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    checkpoint = "tonyhui2234/finetuned_model_text_gen"
    # Tokenizer first, then model: same load order as before.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
62
+
63
@st.cache_resource
def load_english_detection_pipeline():
    """Build the language-detection pipeline used to screen questions.

    Returns:
        A ``text-classification`` pipeline created from
        ``papluca/xlm-roberta-base-language-detection``, cached through
        ``st.cache_resource``.
    """
    language_detector = pipeline(
        "text-classification",
        model="papluca/xlm-roberta-base-language-detection",
    )
    return language_detector
70
+
71
@st.cache_resource
def load_question_detection_pipeline():
    """Build the question-vs-statement classification pipeline.

    Returns:
        A ``text-classification`` pipeline created from
        ``shahrukhx01/question-vs-statement-classifier``, cached through
        ``st.cache_resource``.
    """
    question_detector = pipeline(
        "text-classification",
        model="shahrukhx01/question-vs-statement-classifier",
    )
    return question_detector
78
+
79
+ # ----------------------------------------------------
80
+ # FUNCTION DEFINITIONS
81
+ # ----------------------------------------------------
82
+
83
  def load_finetuned_classifier_model(question):
84
+ """
85
+ Classify the input question into a specific fortune category.
86
+ Maps the classifier's output label to a human-readable format.
87
+ """
88
  label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
89
+ # Mapping dictionary to convert the default "LABEL_x" output.
90
  mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
91
+
92
+ classifier_pipe = load_classifier_pipeline()
93
+ prediction = classifier_pipe(question)[0]['label']
94
  predicted_label = mapping.get(prediction, prediction)
95
  print(predicted_label)
96
  return predicted_label
97
 
 
98
  def generate_answer(question, fortune):
99
+ """
100
+ Generate a custom answer using a finetuned sequence-to-sequence model.
101
+ Combines the user's question with the fortune message to produce a response.
102
+ """
103
+ tokenizer, model = load_tokenizer_and_model()
104
  input_text = "Question: " + question + " Fortune: " + fortune
105
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
106
  outputs = model.generate(
 
115
  return answer
116
 
117
  def analysis(row_detail, classifiy, question):
118
+ """
119
+ Analyze the fortune detail based on the classifier's output.
120
+ Extracts the specific fortune message using regex and generates an answer.
121
+ """
122
  pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
123
  match = pattern.search(row_detail)
124
  if match:
125
  result = match.group(1)
 
126
  return generate_answer(question, result)
127
  else:
128
  return "Heaven's secret cannot be revealed."
129
 
130
  def check_sentence_is_english_model(question):
131
+ """
132
+ Check if the input question is in English using a language detection model.
133
+ """
134
+ pipe_english = load_english_detection_pipeline()
135
  return pipe_english(question)[0]['label'] == 'en'
136
 
137
  def check_sentence_is_question_model(question):
138
+ """
139
+ Check if the input text is a question using a question vs. statement classifier.
140
+ """
141
+ pipe_question = load_question_detection_pipeline()
142
  return pipe_question(question)[0]['label'] == 'LABEL_1'
143
 
144
  def submit_text_callback():
145
+ """
146
+ Callback function executed when the user submits their question.
147
+ Validates the input and retrieves a corresponding fortune based on a random number.
148
+ """
149
  question = st.session_state.get("user_sentence", "")
 
150
  st.session_state.error_message = ""
151
 
152
  if not check_sentence_is_english_model(question):
 
165
  return
166
 
167
  st.session_state.submitted_text = True
168
+ st.session_state.button_count_temp = 0 # Reset the counter after submission
169
 
170
+ # Randomly generate a fortune stick number between 1 and 100
171
  st.session_state.fortune_number = random.randint(1, 100)
172
 
173
+ # Retrieve fortune details from CSV data
174
  df = st.session_state.fortune_data
175
  row_detail = ''
176
  if df is not None:
 
196
  print(row_detail)
197
 
198
  def load_and_resize_image(path, max_size=MAX_SIZE):
199
+ """
200
+ Load an image from a local path and resize it to fit within MAX_SIZE.
201
+ """
202
  try:
203
  img = Image.open(path)
204
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
 
208
  return None
209
 
210
  def download_and_resize_image(url, max_size=MAX_SIZE):
211
+ """
212
+ Download an image from a URL and resize it to fit within MAX_SIZE.
213
+ """
214
  try:
215
  response = requests.get(url)
216
  response.raise_for_status()
 
223
  return None
224
 
225
  def stick_enquiry_callback():
226
+ """
227
+ Callback function executed when the user clicks "Cfu Explain".
228
+ Uses the classifier to analyze the fortune details and generate a custom answer.
229
+ """
230
  question = st.session_state.get("user_sentence", "")
231
  if not st.session_state.fortune_row:
232
  st.error("Fortune data is not available. Please submit your question first.")
233
  return
234
  row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
 
235
  classifiy = load_finetuned_classifier_model(question)
 
236
  cfu_explain = analysis(row_detail, classifiy, question)
 
237
  st.session_state.cfu_explain_text = cfu_explain
238
  st.session_state.stick_clicked = True
239
 
240
+ # ----------------------------------------------------
241
+ # STREAMLIT APP LAYOUT
242
+ # ----------------------------------------------------
243
+
244
# Set the browser-tab title and wide layout, then render the page heading.
# NOTE(review): Streamlit documents set_page_config as needing to be the first
# Streamlit command on a page; the st.error() in the CSV-loading fallback
# earlier in this file can run before it. Confirm against the deployed
# Streamlit version.
st.set_page_config(page_title="Fortune Stick Enquiry", layout="wide")
st.title("Fortune Stick Enquiry")
247
+
248
+ # Define the main layout columns: Left for user input, Right for fortune display
249
  left_col, _, right_col = st.columns([3, 1, 5])
250
 
251
+ # ---- Left Column: User Input and Interaction ----
252
  with left_col:
253
  left_top = st.container()
254
  left_bottom = st.container()
 
259
  st.error(st.session_state.error_message)
260
  if st.session_state.submitted_text:
261
  with left_bottom:
262
+ # Add spacing
263
  for _ in range(5):
264
  st.write("")
265
  col1, col2, col3 = st.columns(3)
266
  with col2:
267
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
268
  if st.session_state.stick_clicked:
269
+ # Display the generated explanation text
270
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
271
 
272
+ # ---- Right Column: Fortune Display and Images ----
273
  with right_col:
274
  with st.container():
275
  col_left, col_center, col_right = st.columns([1, 2, 1])