tonyhui2234 committed on
Commit
424c696
·
verified ·
1 Parent(s): 833a88b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -111
app.py CHANGED
@@ -1,22 +1,24 @@
1
- import streamlit as st # For creating the web app interface
2
- import random # For generating random fortune numbers
3
- import pandas as pd # For handling CSV data
4
- import requests # For downloading images from URLs
5
- from io import BytesIO # For handling image bytes
6
- from PIL import Image # For image processing
7
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM # For NLP models
8
- import re # For regex operations
9
-
10
- # This script implements a Fortune Stick Enquiry app.
11
- # Users enter a question, which is validated and processed.
12
- # A random fortune is chosen from a CSV, and NLP models classify and generate custom answers.
13
 
14
  # Define maximum dimensions for the fortune image (in pixels)
15
  MAX_SIZE = (400, 400)
16
 
17
- # Initialize session state variables for button clicks, fortune details, etc.
18
  if "button_count_temp" not in st.session_state:
19
  st.session_state.button_count_temp = 0
 
 
 
 
 
 
20
  if "submitted_text" not in st.session_state:
21
  st.session_state.submitted_text = False
22
  if "fortune_number" not in st.session_state:
@@ -30,7 +32,6 @@ if "cfu_explain_text" not in st.session_state:
30
  if "stick_clicked" not in st.session_state:
31
  st.session_state.stick_clicked = False
32
 
33
- # Load fortune data from CSV file
34
  if "fortune_data" not in st.session_state:
35
  try:
36
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
@@ -38,69 +39,25 @@ if "fortune_data" not in st.session_state:
38
  st.error(f"Error loading CSV: {e}")
39
  st.session_state.fortune_data = None
40
 
41
- # ----------------------------------------------------
42
- # CACHED MODEL LOADING FUNCTIONS
43
- # ----------------------------------------------------
44
-
45
- @st.cache_resource
46
- def load_classifier_pipeline():
47
- """
48
- Load and cache the finetuned classifier pipeline.
49
- This model classifies the input question into one of the fortune categories.
50
- """
51
- return pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
52
-
53
- @st.cache_resource
54
- def load_tokenizer_and_model():
55
- """
56
- Load and cache the tokenizer and model for generating custom answers.
57
- Uses a finetuned sequence-to-sequence model from Hugging Face.
58
- """
59
- tokenizer = AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen")
60
- model = AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen")
61
- return tokenizer, model
62
-
63
- @st.cache_resource
64
- def load_english_detection_pipeline():
65
- """
66
- Load and cache the English language detection pipeline.
67
- This ensures that the user's question is in English.
68
- """
69
- return pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
70
-
71
- @st.cache_resource
72
- def load_question_detection_pipeline():
73
- """
74
- Load and cache the question vs. statement detection pipeline.
75
- This checks if the input text is a question.
76
- """
77
- return pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
78
-
79
- # ----------------------------------------------------
80
- # FUNCTION DEFINITIONS
81
- # ----------------------------------------------------
82
-
83
  def load_finetuned_classifier_model(question):
84
- """
85
- Classify the input question into a specific fortune category.
86
- Maps the classifier's output label to a human-readable format.
87
- """
88
  label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
89
- # Mapping dictionary to convert the default "LABEL_x" output.
90
  mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
91
-
92
- classifier_pipe = load_classifier_pipeline()
93
- prediction = classifier_pipe(question)[0]['label']
94
  predicted_label = mapping.get(prediction, prediction)
95
  print(predicted_label)
96
  return predicted_label
97
 
 
 
 
 
 
 
98
  def generate_answer(question, fortune):
99
- """
100
- Generate a custom answer using a finetuned sequence-to-sequence model.
101
- Combines the user's question with the fortune message to produce a response.
102
- """
103
- tokenizer, model = load_tokenizer_and_model()
104
  input_text = "Question: " + question + " Fortune: " + fortune
105
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
106
  outputs = model.generate(
@@ -115,38 +72,27 @@ def generate_answer(question, fortune):
115
  return answer
116
 
117
  def analysis(row_detail, classifiy, question):
118
- """
119
- Analyze the fortune detail based on the classifier's output.
120
- Extracts the specific fortune message using regex and generates an answer.
121
- """
122
  pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
123
  match = pattern.search(row_detail)
124
  if match:
125
  result = match.group(1)
 
126
  return generate_answer(question, result)
127
  else:
128
  return "Heaven's secret cannot be revealed."
129
 
130
  def check_sentence_is_english_model(question):
131
- """
132
- Check if the input question is in English using a language detection model.
133
- """
134
- pipe_english = load_english_detection_pipeline()
135
  return pipe_english(question)[0]['label'] == 'en'
136
 
137
  def check_sentence_is_question_model(question):
138
- """
139
- Check if the input text is a question using a question vs. statement classifier.
140
- """
141
- pipe_question = load_question_detection_pipeline()
142
  return pipe_question(question)[0]['label'] == 'LABEL_1'
143
 
144
  def submit_text_callback():
145
- """
146
- Callback function executed when the user submits their question.
147
- Validates the input and retrieves a corresponding fortune based on a random number.
148
- """
149
  question = st.session_state.get("user_sentence", "")
 
150
  st.session_state.error_message = ""
151
 
152
  if not check_sentence_is_english_model(question):
@@ -165,12 +111,12 @@ def submit_text_callback():
165
  return
166
 
167
  st.session_state.submitted_text = True
168
- st.session_state.button_count_temp = 0 # Reset the counter after submission
169
 
170
- # Randomly generate a fortune stick number between 1 and 100
171
  st.session_state.fortune_number = random.randint(1, 100)
172
 
173
- # Retrieve fortune details from CSV data
174
  df = st.session_state.fortune_data
175
  row_detail = ''
176
  if df is not None:
@@ -196,9 +142,6 @@ def submit_text_callback():
196
  print(row_detail)
197
 
198
  def load_and_resize_image(path, max_size=MAX_SIZE):
199
- """
200
- Load an image from a local path and resize it to fit within MAX_SIZE.
201
- """
202
  try:
203
  img = Image.open(path)
204
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
@@ -208,9 +151,6 @@ def load_and_resize_image(path, max_size=MAX_SIZE):
208
  return None
209
 
210
  def download_and_resize_image(url, max_size=MAX_SIZE):
211
- """
212
- Download an image from a URL and resize it to fit within MAX_SIZE.
213
- """
214
  try:
215
  response = requests.get(url)
216
  response.raise_for_status()
@@ -223,32 +163,24 @@ def download_and_resize_image(url, max_size=MAX_SIZE):
223
  return None
224
 
225
  def stick_enquiry_callback():
226
- """
227
- Callback function executed when the user clicks "Cfu Explain".
228
- Uses the classifier to analyze the fortune details and generate a custom answer.
229
- """
230
  question = st.session_state.get("user_sentence", "")
231
  if not st.session_state.fortune_row:
232
  st.error("Fortune data is not available. Please submit your question first.")
233
  return
234
  row_detail = st.session_state.fortune_row.get("Detail", "No detail available.")
 
235
  classifiy = load_finetuned_classifier_model(question)
 
236
  cfu_explain = analysis(row_detail, classifiy, question)
 
237
  st.session_state.cfu_explain_text = cfu_explain
238
  st.session_state.stick_clicked = True
239
 
240
- # ----------------------------------------------------
241
- # STREAMLIT APP LAYOUT
242
- # ----------------------------------------------------
243
-
244
- # Set page configuration and title
245
- st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
246
- st.title("Fortuen Stick Enquiry")
247
-
248
- # Define the main layout columns: Left for user input, Right for fortune display
249
  left_col, _, right_col = st.columns([3, 1, 5])
250
 
251
- # ---- Left Column: User Input and Interaction ----
252
  with left_col:
253
  left_top = st.container()
254
  left_bottom = st.container()
@@ -259,17 +191,16 @@ with left_col:
259
  st.error(st.session_state.error_message)
260
  if st.session_state.submitted_text:
261
  with left_bottom:
262
- # Add spacing
263
  for _ in range(5):
264
  st.write("")
265
  col1, col2, col3 = st.columns(3)
266
  with col2:
267
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
268
  if st.session_state.stick_clicked:
269
- # Display the generated explanation text
270
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
271
 
272
- # ---- Right Column: Fortune Display and Images ----
273
  with right_col:
274
  with st.container():
275
  col_left, col_center, col_right = st.columns([1, 2, 1])
@@ -308,3 +239,21 @@ with right_col:
308
 
309
  st.text_area("Description", value=description_text, height=150, disabled=True)
310
  st.text_area("Detail", value=detail_text, height=150, disabled=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import random
3
+ import pandas as pd
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
+ import re
 
 
 
 
9
 
10
  # Define maximum dimensions for the fortune image (in pixels)
11
  MAX_SIZE = (400, 400)
12
 
13
+ # Initialize button click count in session state
14
  if "button_count_temp" not in st.session_state:
15
  st.session_state.button_count_temp = 0
16
+
17
+ # Set page configuration
18
+ st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
19
+ st.title("Fortuen Stick Enquiry")
20
+
21
+ # Initialize session state variables
22
  if "submitted_text" not in st.session_state:
23
  st.session_state.submitted_text = False
24
  if "fortune_number" not in st.session_state:
 
32
  if "stick_clicked" not in st.session_state:
33
  st.session_state.stick_clicked = False
34
 
 
35
  if "fortune_data" not in st.session_state:
36
  try:
37
  st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
 
39
  st.error(f"Error loading CSV: {e}")
40
  st.session_state.fortune_data = None
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
@st.cache_resource
def _load_classifier_pipeline():
    """Build the fortune-category classification pipeline once and cache it.

    Without @st.cache_resource, Streamlit re-instantiates (and potentially
    re-downloads) the model on every script rerun, which makes each
    classification call very slow.
    """
    return pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")


def load_finetuned_classifier_model(question):
    """Classify *question* into one of the five fortune categories.

    Args:
        question: The user's question text.

    Returns:
        A human-readable category name (e.g. "Traveling"); if the model emits
        an unknown label it is returned unchanged.
    """
    label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
    # Create a mapping dictionary to convert the default "LABEL_x" output.
    mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}

    # Reuse the cached pipeline instead of rebuilding it on every call.
    pipe = _load_classifier_pipeline()
    prediction = pipe(question)[0]['label']
    predicted_label = mapping.get(prediction, prediction)
    print(predicted_label)
    return predicted_label
52
 
53
@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the seq2seq answer-generation model with its tokenizer.

    Returns:
        A (tokenizer, model) pair; cached so the weights load only once.
    """
    checkpoint = "tonyhui2234/finetuned_model_text_gen"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
58
+
59
  def generate_answer(question, fortune):
60
+ tokenizer, model = load_model_and_tokenizer()
 
 
 
 
61
  input_text = "Question: " + question + " Fortune: " + fortune
62
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
63
  outputs = model.generate(
 
72
  return answer
73
 
74
def analysis(row_detail, classifiy, question):
    """Pull the category-specific sentence out of *row_detail* and answer.

    Searches for "<category>: <sentence>" (case-insensitive, sentence ends at
    the first period or end of string). On a hit, the sentence is passed to
    generate_answer(); otherwise a stock refusal is returned.
    """
    # classifiy is the classifier's label, e.g. "Personal Well-Being".
    category_pattern = re.compile(
        re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE
    )
    found = category_pattern.search(row_detail)
    if not found:
        return "Heaven's secret cannot be revealed."
    # Generate a custom answer from the extracted fortune sentence.
    return generate_answer(question, found.group(1))
84
 
85
@st.cache_resource
def _load_english_detection_pipeline():
    """Build the language-detection pipeline once and cache it.

    Rebuilding the pipeline inside the check function made every validation
    pay the full model-load cost; caching restores one-time loading.
    """
    return pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")


def check_sentence_is_english_model(question):
    """Return True if the language-detection model labels *question* as English ('en')."""
    pipe_english = _load_english_detection_pipeline()
    return pipe_english(question)[0]['label'] == 'en'
88
 
89
@st.cache_resource
def _load_question_detection_pipeline():
    """Build the question-vs-statement pipeline once and cache it.

    Avoids re-instantiating the model on every validation call.
    """
    return pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")


def check_sentence_is_question_model(question):
    """Return True if the classifier labels *question* as a question (LABEL_1)."""
    pipe_question = _load_question_detection_pipeline()
    return pipe_question(question)[0]['label'] == 'LABEL_1'
92
 
93
  def submit_text_callback():
 
 
 
 
94
  question = st.session_state.get("user_sentence", "")
95
+ # Clear any previous error message
96
  st.session_state.error_message = ""
97
 
98
  if not check_sentence_is_english_model(question):
 
111
  return
112
 
113
  st.session_state.submitted_text = True
114
+ st.session_state.button_count_temp = 0 # Reset the counter once submission is accepted
115
 
116
+ # Randomly generate a number from 1 to 100
117
  st.session_state.fortune_number = random.randint(1, 100)
118
 
119
+ # Look up the row in the CSV where CNumber matches the generated fortune number.
120
  df = st.session_state.fortune_data
121
  row_detail = ''
122
  if df is not None:
 
142
  print(row_detail)
143
 
144
  def load_and_resize_image(path, max_size=MAX_SIZE):
 
 
 
145
  try:
146
  img = Image.open(path)
147
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
 
151
  return None
152
 
153
  def download_and_resize_image(url, max_size=MAX_SIZE):
 
 
 
154
  try:
155
  response = requests.get(url)
156
  response.raise_for_status()
 
163
  return None
164
 
165
def stick_enquiry_callback():
    """Handle the "Cfu Explain" click.

    Classifies the stored question, runs analysis() against the fortune's
    Detail text, and stashes the explanation in session state for display.
    """
    question = st.session_state.get("user_sentence", "")

    # A fortune row must already exist from a prior submission.
    if not st.session_state.fortune_row:
        st.error("Fortune data is not available. Please submit your question first.")
        return

    detail = st.session_state.fortune_row.get("Detail", "No detail available.")
    # Classify first, then build the explanation from the matching detail.
    category = load_finetuned_classifier_model(question)
    st.session_state.cfu_explain_text = analysis(detail, category, question)
    st.session_state.stick_clicked = True
179
 
180
+ # Main layout: Left (input) and Right (fortune display)
 
 
 
 
 
 
 
 
181
  left_col, _, right_col = st.columns([3, 1, 5])
182
 
183
+ # ---- Left Column ----
184
  with left_col:
185
  left_top = st.container()
186
  left_bottom = st.container()
 
191
  st.error(st.session_state.error_message)
192
  if st.session_state.submitted_text:
193
  with left_bottom:
 
194
  for _ in range(5):
195
  st.write("")
196
  col1, col2, col3 = st.columns(3)
197
  with col2:
198
  st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
199
  if st.session_state.stick_clicked:
200
+ # Display the explanation text saved from analysis()
201
  st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
202
 
203
+ # ---- Right Column ----
204
  with right_col:
205
  with st.container():
206
  col_left, col_center, col_right = st.columns([1, 2, 1])
 
239
 
240
  st.text_area("Description", value=description_text, height=150, disabled=True)
241
  st.text_area("Detail", value=detail_text, height=150, disabled=True)
242
+
243
+ why does the model load again every time this function is called?
244
# Define your inference function
def generate_answer(question, fortune):
    """Generate a custom answer from the question plus the fortune sentence.

    The tokenizer/model pair is loaded lazily on the first call and memoized
    on the function object — the original version re-downloaded and rebuilt
    both from the Hub on EVERY call, which is why "loading the function" was
    so slow.

    Args:
        question: The user's question text.
        fortune: The fortune sentence extracted for the classified category.

    Returns:
        The decoded model output with special tokens stripped.
    """
    if not hasattr(generate_answer, "_cached"):
        generate_answer._cached = (
            AutoTokenizer.from_pretrained("tonyhui2234/finetuned_model_text_gen"),
            AutoModelForSeq2SeqLM.from_pretrained("tonyhui2234/finetuned_model_text_gen"),
        )
    tokenizer, model = generate_answer._cached
    input_text = "Question: " + question + " Fortune: " + fortune
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    # Beam search with repetition penalties to keep the answer coherent.
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer