Commit
·
4a4406e
1
Parent(s):
edec6bb
Update app.py
Browse files
app.py
CHANGED
@@ -24,16 +24,20 @@ def sentiment_analysis_DB(input):
|
|
24 |
pad_to_max_length=True,
|
25 |
return_token_type_ids=True
|
26 |
)
|
27 |
-
ids = inputs['input_ids']
|
28 |
-
mask = inputs['attention_mask']
|
29 |
-
token_type_ids = inputs["token_type_ids"]
|
|
|
|
|
30 |
output = model_DB(ids, mask, token_type_ids)
|
31 |
-
|
32 |
-
final_outputs =
|
|
|
33 |
if final_outputs == True:
|
34 |
result = 1
|
35 |
else:
|
36 |
result = 0
|
|
|
37 |
return result
|
38 |
|
39 |
# Streamlit app
|
|
|
24 |
pad_to_max_length=True,
|
25 |
return_token_type_ids=True
|
26 |
)
|
27 |
+
ids = torch.tensor([inputs['input_ids']]) # Convert to PyTorch tensor
|
28 |
+
mask = torch.tensor([inputs['attention_mask']]) # Convert to PyTorch tensor
|
29 |
+
token_type_ids = torch.tensor([inputs["token_type_ids"]]) # Convert to PyTorch tensor
|
30 |
+
|
31 |
+
# Assuming model_DB is a PyTorch model
|
32 |
output = model_DB(ids, mask, token_type_ids)
|
33 |
+
|
34 |
+
final_outputs = output[0].item() # Extract the scalar value
|
35 |
+
|
36 |
if final_outputs == True:
|
37 |
result = 1
|
38 |
else:
|
39 |
result = 0
|
40 |
+
|
41 |
return result
|
42 |
|
43 |
# Streamlit app
|