zerostratos committed on
Commit
5f49389
·
verified ·
1 Parent(s): 1373cd8

Update streamlitapp.py

Browse files
Files changed (1) hide show
  1. streamlitapp.py +195 -0
streamlitapp.py CHANGED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import transformers
5
+ from transformers import AutoTokenizer,AutoModel
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import matplotlib.pyplot as plt
9
+ import torch.nn.functional as F
10
+
11
class BCNN(nn.Module):
    """PhoBERT encoder followed by a BiLSTM and two parallel 1-D convolutions.

    The per-token output of the pretrained ``vinai/phobert-base-v2`` encoder
    is fed through a bidirectional LSTM and then through two convolutions
    (kernel sizes 4 and 5). Each convolution output is max-pooled over the
    sequence axis; the pooled vectors are concatenated, passed through
    dropout and mapped to ``output_dim`` logits by a linear layer.

    NOTE: the attribute names (``bert``, ``bidirectional_lstm``, ``conv1``,
    ``conv2``, ``fc``, ``dropout``) are deliberately unchanged so that
    checkpoints saved against this class keep loading.
    """

    def __init__(self, embedding_dim, output_dim,
                 dropout, bidirectional_units, conv_filters):
        """
        Build the network layers.

        Args:
            embedding_dim (int): Input feature size of the LSTM; it has to
                equal the last dimension of the encoder output.
            output_dim (int): Number of logits produced per example.
            dropout (float): Dropout probability applied to the pooled features.
            bidirectional_units (int): Hidden size of each LSTM direction.
            conv_filters (sequence of int): Output channels of the kernel-4
                and kernel-5 convolutions, in that order.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained('vinai/phobert-base-v2')
        self.bidirectional_lstm = nn.LSTM(
            embedding_dim,
            bidirectional_units,
            bidirectional=True,
            batch_first=True,
        )

        # A bidirectional LSTM emits forward and backward states concatenated.
        lstm_features = 2 * bidirectional_units
        self.conv1 = nn.Conv1d(
            in_channels=lstm_features, out_channels=conv_filters[0], kernel_size=4
        )
        self.conv2 = nn.Conv1d(
            in_channels=lstm_features, out_channels=conv_filters[1], kernel_size=5
        )

        # The classifier consumes both pooled conv outputs side by side, so its
        # input width is the sum of the filter counts (64 for [32, 32]) rather
        # than a hard-coded constant.
        self.fc = nn.Linear(conv_filters[0] + conv_filters[1], output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, b_input_ids, b_input_mask):
        """
        Compute class logits for a batch of token ids.

        The sequence length must be at least 5, the largest convolution
        kernel, because the convolutions use no padding.

        Args:
            b_input_ids (torch.Tensor): Token ids, passed to the encoder.
            b_input_mask (torch.Tensor): Attention mask for ``b_input_ids``.

        Returns:
            torch.Tensor: Logits of shape [batch size, output_dim].
        """
        # First element of the encoder output: one vector per token.
        encoded = self.bert(b_input_ids, attention_mask=b_input_mask)[0]
        lstm_out, _ = self.bidirectional_lstm(encoded)

        # Conv1d expects [batch size, channels, sent len].
        lstm_out = lstm_out.permute(0, 2, 1)

        # conved_n = [batch size, n_filters, sent len - kernel_size + 1]
        conved_1 = F.relu(self.conv1(lstm_out))
        conved_2 = F.relu(self.conv2(lstm_out))

        # Max over the whole sequence axis: pooled_n = [batch size, n_filters]
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)

        # cat = [batch size, n_filters_1 + n_filters_2]
        cat = self.dropout(torch.cat((pooled_1, pooled_2), dim=1))

        return self.fc(cat)
47
+
48
class TextClassificationApp:
    """Streamlit front end that classifies free text with a saved BCNN model."""

    # Checkpoint file used when ``model_path`` does not point at a file.
    DEFAULT_MODEL_FILE = 'toxic.pt'

    def __init__(self, model_path, class_names, model_name='vinai/phobert-base-v2'):
        """
        Initialize Streamlit Text Classification App

        Args:
            model_path (str): Path to the pre-trained .pt model file. When it
                does not exist, ``toxic.pt`` in the working directory is used.
            class_names (list): List of classification labels
            model_name (str): Hugging Face model name for tokenization
        """
        from pathlib import Path  # local import: only needed while resolving the path

        # Page configuration is issued before any other Streamlit call.
        st.set_page_config(
            page_title="Text Classification",
            page_icon="📝",
            layout="wide"
        )

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Honour the caller's path, falling back to the bundled file name so a
        # stale path does not break an otherwise working deployment.
        if model_path and Path(model_path).is_file():
            checkpoint = model_path
        else:
            checkpoint = self.DEFAULT_MODEL_FILE
        self.model = self._load_model(checkpoint)

        # Store class names
        self.class_names = class_names

        # Maximum sequence length
        self.max_length = 128

    def _load_model(self, checkpoint):
        """
        Load the classifier from ``checkpoint`` onto ``self.device``.

        Args:
            checkpoint (str): Path of the file handed to ``torch.load``.

        Returns:
            torch.nn.Module: The model in evaluation mode on ``self.device``.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # NOTE(review): torch releases that default to ``weights_only=True``
        # refuse fully pickled modules - confirm against the deployed torch
        # version and pass ``weights_only=False`` there if needed.
        loaded = torch.load(checkpoint, map_location=self.device)

        if isinstance(loaded, torch.nn.Module):
            # The file holds a whole pickled model: use it directly instead of
            # first building (and downloading weights for) a throwaway BCNN.
            model = loaded
        else:
            # Otherwise treat the file as a state dict for the architecture
            # this app was written for.
            embedding_dim = 768
            output_dim = 2
            dropout = 0.1
            conv_filters = [32, 32]  # Number of filters for each kernel size (4 and 5)
            bidirectional_units = 128
            model = BCNN(embedding_dim, output_dim, dropout,
                         bidirectional_units, conv_filters)
            model.load_state_dict(loaded)

        # Inputs are moved to self.device, so the model has to live there too.
        model.to(self.device)
        model.eval()  # Set to evaluation mode
        return model

    def preprocess_text(self, text):
        """
        Preprocess input text for model prediction

        Args:
            text (str): Input text to classify

        Returns:
            tuple: (input ids, attention mask) tensors on ``self.device``,
            padded/truncated to ``self.max_length`` tokens
        """
        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        return input_ids, attention_mask

    def predict(self, text):
        """
        Make prediction on the input text

        Args:
            text (str): Input text to classify

        Returns:
            tuple: (predicted class indices, their probabilities), each a
            NumPy array holding the single top prediction
        """
        inputs, mask = self.preprocess_text(text)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model(inputs, mask)

            # Turn logits into probabilities over the classes.
            probabilities = torch.softmax(outputs, dim=1)

            # Keep the most likely class and its probability.
            top_probs, top_classes = torch.topk(probabilities, k=1)

        return top_classes[0].cpu().numpy(), top_probs[0].cpu().numpy()

    def run(self):
        """
        Main Streamlit app runner
        """
        # Title and description
        st.title("📄 Text Classification")
        st.write("Enter text to classify")

        # Text input
        text_input = st.text_area(
            "Paste your text here",
            height=250,
            placeholder="Enter the text you want to classify..."
        )

        # Nothing to do until the button is pressed.
        if not st.button("Classify Text"):
            return

        if not text_input.strip():
            st.warning("Please enter some text to classify")
            return

        # Make prediction
        top_classes, top_probs = self.predict(text_input)

        # Display results
        st.subheader("Classification Results")

        # One column per displayed prediction (only the first is used for k=1).
        cols = st.columns(3)
        for i, (cls, prob) in enumerate(zip(top_classes, top_probs)):
            with cols[i]:
                st.metric(
                    label=f"Top {i+1} Prediction",
                    value=f"{self.class_names[cls]}",
                    delta=f"{prob:.2%}"
                )

        # Show input text details
        with st.expander("Input Text Details"):
            st.write("**Original Text:**")
            st.write(text_input)
            st.write(f"**Text Length:** {len(text_input)} characters")
181
+
182
def main():
    """Create the toxic / non-toxic classification app and run it."""
    from pathlib import Path  # local import: only used to locate the checkpoint

    # Look for the checkpoint next to this script instead of relying on an
    # absolute path that only exists on one development machine.
    model_path = str(Path(__file__).resolve().parent / 'toxic.pt')

    # Index i of the model output corresponds to class_names[i].
    class_names = [
        'Non-toxic',
        'Toxic',
    ]

    # Initialize and run the app
    app = TextClassificationApp(model_path, class_names)
    app.run()


if __name__ == "__main__":
    main()