perman2011 commited on
Commit
cae091c
·
1 Parent(s): ceefc9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -20,7 +20,7 @@ def sentiment_analysis_DB(input):
20
  input,
21
  None,
22
  add_special_tokens=True,
23
- max_length=MAX_LEN,
24
  pad_to_max_length=True,
25
  return_token_type_ids=True
26
  )
@@ -30,13 +30,13 @@ def sentiment_analysis_DB(input):
30
 
31
  # Assuming model_DB is a PyTorch model
32
  output = model_DB(ids, mask, token_type_ids)
 
33
 
34
- final_outputs = output[0].item() # Extract the scalar value
 
35
 
36
- if final_outputs == True:
37
- result = 1
38
- else:
39
- result = 0
40
 
41
  return result
42
 
 
20
  input,
21
  None,
22
  add_special_tokens=True,
23
+ max_length=100,
24
  pad_to_max_length=True,
25
  return_token_type_ids=True
26
  )
 
30
 
31
  # Assuming model_DB is a PyTorch model
32
  output = model_DB(ids, mask, token_type_ids)
33
+ print('Raw output is ', output)
34
 
35
+ sigmoid_output = torch.sigmoid(output)
36
+ print('Sigmoid output is ', sigmoid_output)
37
 
38
+ # Assuming you want to use a threshold of 0.5
39
+ result = 1 if sigmoid_output.item() > 0.5 else 0
 
 
40
 
41
  return result
42