tonyhui2234 commited on
Commit
542042e
·
verified ·
1 Parent(s): 6bbfa6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -9
app.py CHANGED
@@ -1,19 +1,226 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
 
 
 
 
2
 
3
- # Load the saved model and tokenizer
4
- tokenizer = AutoTokenizer.from_pretrained("/home/user/app/my_finetuned_model_2/")
5
- model = AutoModelForSeq2SeqLM.from_pretrained("/home/user/app/my_finetuned_model_2/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Define your inference function
8
  def generate_answer(question, fortune):
 
 
 
 
9
  input_text = "Question: " + question + " Fortune: " + fortune
10
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
11
  outputs = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True, repetition_penalty=2.0, no_repeat_ngram_size=3)
12
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
13
  return answer
14
 
15
- # Test the model with a sample input
16
- sample_question = "Should I start my own business now?"
17
- sample_fortune = "absence of rain causes worry."
18
- print("Generated Answer:")
19
- print(generate_answer(sample_question, sample_fortune))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import random
3
+ import pandas as pd
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
+ import re
9
 
10
+ # Define maximum dimensions for the fortune image (in pixels)
11
+ MAX_SIZE = (400, 400)
12
+
13
+ # Initialize button click count in session state
14
+ if "button_count_temp" not in st.session_state:
15
+ st.session_state.button_count_temp = 0
16
+
17
+ # Set page configuration
18
+ st.set_page_config(page_title="Fortuen Stick Enquiry", layout="wide")
19
+ st.title("Fortuen Stick Enquiry")
20
+
21
+ # Initialize session state variables
22
+ if "submitted_text" not in st.session_state:
23
+ st.session_state.submitted_text = False
24
+ if "fortune_number" not in st.session_state:
25
+ st.session_state.fortune_number = None
26
+ if "fortune_row" not in st.session_state:
27
+ st.session_state.fortune_row = None
28
+ if "error_message" not in st.session_state:
29
+ st.session_state.error_message = ""
30
+ if "cfu_explain_text" not in st.session_state:
31
+ st.session_state.cfu_explain_text = ""
32
+
33
+ if "fortune_data" not in st.session_state:
34
+ try:
35
+ st.session_state.fortune_data = pd.read_csv("detail.csv")
36
+ except Exception as e:
37
+ st.error(f"Error loading CSV: {e}")
38
+ st.session_state.fortune_data = None
39
+
40
+ if "stick_clicked" not in st.session_state:
41
+ st.session_state.stick_clicked = False
42
+
43
+ def load_finetuned_classifier_model(question):
44
+ label_list = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
45
+ # Create a mapping dictionary to convert the default "LABEL_x" output.
46
+ mapping = {f"LABEL_{i}": label for i, label in enumerate(label_list)}
47
+
48
+ pipe = pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")
49
+ prediction = pipe(question)[0]['label']
50
+ predicted_label = mapping.get(prediction, prediction)
51
+ print(predicted_label)
52
+ return predicted_label
53
 
54
  # Define your inference function
55
  def generate_answer(question, fortune):
56
+ # Load the saved model and tokenizer
57
+ tokenizer = AutoTokenizer.from_pretrained("/home/user/app/my_finetuned_model_2/")
58
+ model = AutoModelForSeq2SeqLM.from_pretrained("/home/user/app/my_finetuned_model_2/")
59
+
60
  input_text = "Question: " + question + " Fortune: " + fortune
61
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
62
  outputs = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True, repetition_penalty=2.0, no_repeat_ngram_size=3)
63
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
  return answer
65
 
66
+ def analysis(row_detail, classifiy, question):
67
+ # Use the classifier's output (e.g. "Personal Well-Being") in the regex.
68
+ pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
69
+ match = pattern.search(row_detail)
70
+ if match:
71
+ result = match.group(1)
72
+ # If you want to generate a custom answer, you can call generate_answer()
73
+ return generate_answer(question, result)
74
+ # return result
75
+ else:
76
+ return "Heaven's secret cannot be revealed."
77
+
78
+ def check_sentence_is_english_model(question):
79
+ pipe_english = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
80
+ return pipe_english(question)[0]['label'] == 'en'
81
+
82
+ def check_sentence_is_question_model(question):
83
+ pipe_question = pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")
84
+ return pipe_question(question)[0]['label'] == 'LABEL_1'
85
+
86
+ def submit_text_callback():
87
+ question = st.session_state.get("user_sentence", "")
88
+ # Clear any previous error message
89
+ st.session_state.error_message = ""
90
+
91
+ if not check_sentence_is_english_model(question):
92
+ st.session_state.error_message = "Please enter in English!"
93
+ st.session_state.button_count_temp = 0
94
+ return
95
+
96
+ if not check_sentence_is_question_model(question):
97
+ st.session_state.error_message = "This is not a question. Please enter again!"
98
+ st.session_state.button_count_temp = 0
99
+ return
100
+
101
+ if st.session_state.button_count_temp == 0:
102
+ st.session_state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit."
103
+ st.session_state.button_count_temp = 1
104
+ return
105
+
106
+ st.session_state.submitted_text = True
107
+ st.session_state.button_count_temp = 0 # Reset the counter once submission is accepted
108
+
109
+ # Randomly generate a number from 1 to 100
110
+ st.session_state.fortune_number = random.randint(1, 100)
111
+
112
+ # Look up the row in the CSV where CNumber matches the generated fortune number.
113
+ df = st.session_state.fortune_data
114
+ row_detail = ''
115
+ if df is not None:
116
+ matching_row = df[df['CNumber'] == st.session_state.fortune_number]
117
+ if not matching_row.empty:
118
+ row = matching_row.iloc[0]
119
+ row_detail = row.get("Detail", "No detail available.")
120
+ st.session_state.fortune_row = {
121
+ "Header": row.get("Header", "N/A"),
122
+ "Luck": row.get("Luck", "N/A"),
123
+ "Description": row.get("Description", "No description available."),
124
+ "Detail": row_detail,
125
+ "HeaderLink": row.get("link", None)
126
+ }
127
+ else:
128
+ st.session_state.fortune_row = {
129
+ "Header": "N/A",
130
+ "Luck": "N/A",
131
+ "Description": "No description available.",
132
+ "Detail": "No detail available.",
133
+ "HeaderLink": None
134
+ }
135
+ print(row_detail)
136
+ classifiy = load_finetuned_classifier_model(question)
137
+ cfu_explain = analysis(row_detail, classifiy, question)
138
+ # Save the returned value in session state for later display
139
+ st.session_state.cfu_explain_text = cfu_explain
140
+
141
+ def load_and_resize_image(path, max_size=MAX_SIZE):
142
+ try:
143
+ img = Image.open(path)
144
+ img.thumbnail(max_size, Image.Resampling.LANCZOS)
145
+ return img
146
+ except Exception as e:
147
+ st.error(f"Error loading image: {e}")
148
+ return None
149
+
150
+ def download_and_resize_image(url, max_size=MAX_SIZE):
151
+ try:
152
+ response = requests.get(url)
153
+ response.raise_for_status()
154
+ image_bytes = BytesIO(response.content)
155
+ img = Image.open(image_bytes)
156
+ img.thumbnail(max_size, Image.Resampling.LANCZOS)
157
+ return img
158
+ except Exception as e:
159
+ st.error(f"Error loading image from URL: {e}")
160
+ return None
161
+
162
+ def stick_enquiry_callback():
163
+ st.session_state.stick_clicked = True
164
+
165
+ # Main layout: Left (input) and Right (fortune display)
166
+ left_col, _, right_col = st.columns([3, 1, 5])
167
+
168
+ # ---- Left Column ----
169
+ with left_col:
170
+ left_top = st.container()
171
+ left_bottom = st.container()
172
+ with left_top:
173
+ st.text_area("Enter your question in English", key="user_sentence", height=150)
174
+ st.button("submit", key="submit_button", on_click=submit_text_callback)
175
+ if st.session_state.error_message:
176
+ st.error(st.session_state.error_message)
177
+ if st.session_state.submitted_text:
178
+ with left_bottom:
179
+ for _ in range(5):
180
+ st.write("")
181
+ col1, col2, col3 = st.columns(3)
182
+ with col2:
183
+ st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
184
+ if st.session_state.stick_clicked:
185
+ # Display the explanation text saved from analysis()
186
+ st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
187
+
188
+ # ---- Right Column ----
189
+ with right_col:
190
+ with st.container():
191
+ col_left, col_center, col_right = st.columns([1, 2, 1])
192
+ with col_center:
193
+ if st.session_state.submitted_text and st.session_state.fortune_row:
194
+ header_link = st.session_state.fortune_row.get("HeaderLink")
195
+ if header_link:
196
+ img_from_url = download_and_resize_image(header_link)
197
+ if img_from_url:
198
+ st.image(img_from_url, use_container_width=False)
199
+ else:
200
+ img = load_and_resize_image("/home/user/app/error.png")
201
+ if img:
202
+ st.image(img, use_container_width=False)
203
+ else:
204
+ img = load_and_resize_image("/home/user/app/error.png")
205
+ if img:
206
+ st.image(img, use_container_width=False)
207
+ else:
208
+ img = load_and_resize_image("/home/user/app/fortune.png")
209
+ if img:
210
+ st.image(img, caption="Your Fortune", use_container_width=False)
211
+ with st.container():
212
+ if st.session_state.fortune_row:
213
+ luck_text = st.session_state.fortune_row.get("Luck", "N/A")
214
+ description_text = st.session_state.fortune_row.get("Description", "No description available.")
215
+ detail_text = st.session_state.fortune_row.get("Detail", "No detail available.")
216
+
217
+ summary = f"""
218
+ <div style="font-size: 28px; font-weight: bold;">
219
+ Fortune stick number: {st.session_state.fortune_number}<br>
220
+ Luck: {luck_text}
221
+ </div>
222
+ """
223
+ st.markdown(summary, unsafe_allow_html=True)
224
+
225
+ st.text_area("Description", value=description_text, height=150, disabled=True)
226
+ st.text_area("Detail", value=detail_text, height=150, disabled=True)