Commit
·
66536b3
1
Parent(s):
1613789
Upload local_use.py
Browse files- local_use.py +47 -19
local_use.py
CHANGED
@@ -6,10 +6,10 @@ from torch import nn
|
|
6 |
|
7 |
label_mapping = {0: 'NSFW', 1: 'SFW'}
|
8 |
|
9 |
-
config = BertConfig.from_pretrained('
|
10 |
num_labels=2,
|
11 |
finetuning_task='text classification')
|
12 |
-
tokenizer = BertTokenizer.from_pretrained('
|
13 |
use_fast=False,
|
14 |
never_split=['[user]', '[bot]'])
|
15 |
tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
|
@@ -22,7 +22,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
22 |
self.num_labels = config.num_labels
|
23 |
self.config = config
|
24 |
|
25 |
-
self.bert = BertModel.from_pretrained('
|
26 |
classifier_dropout = (config.classifier_dropout
|
27 |
if config.classifier_dropout is not None else
|
28 |
config.hidden_dropout_prob)
|
@@ -71,19 +71,47 @@ model.load_state_dict(torch.load('./NSFW-detector/pytorch_model.bin'))
|
|
71 |
model.cuda()
|
72 |
model.eval()
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Map the classifier's output indices to human-readable labels.
label_mapping = {0: 'NSFW', 1: 'SFW'}

# Load the fine-tuned classification config and its tokenizer from the
# local checkpoint directory.
config = BertConfig.from_pretrained(
    './NSFW-detector', num_labels=2, finetuning_task='text classification')
tokenizer = BertTokenizer.from_pretrained(
    './NSFW-detector', use_fast=False, never_split=['[user]', '[bot]'])
# Repurpose the reserved [unused1] vocab slot as the [user] role marker so
# the tokenizer emits it as a single token.
tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
|
|
22 |
self.num_labels = config.num_labels
|
23 |
self.config = config
|
24 |
|
25 |
+
self.bert = BertModel.from_pretrained('./NSFW-detector')
|
26 |
classifier_dropout = (config.classifier_dropout
|
27 |
if config.classifier_dropout is not None else
|
28 |
config.hidden_dropout_prob)
|
|
|
71 |
# Select GPU when available so the script also runs on CPU-only machines
# (the original hard-coded .cuda()/.to('cuda') and crashed without a GPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# The following object contains two types of input:
# a single utterance and a single-turn dialogue
to_predict_items = [{
    'text':
    '''Give some movie recommendations to get women in the mood for sex'''
}, {
    'text': 'I break through walls to find more trials'
}, {
    'history': {
        'user':
        'Give some movie recommendations to get women in the mood for sex',
        'bot':
        '''I apologize, but I cannot assist in creating or providing information related to NSFW content or any tools for such purposes. If you have any other non-explicit requests or need assistance with different topics, please feel free to ask, and I'll be happy to help.'''
    }
}, {
    'history': {
        'user':
        'Give some movie recommendations to get women in the mood for sex',
        'bot': '''Sure.'''
    }
}]

for item in to_predict_items:
    # Dialogue samples are flattened to "[user] ... [SEP] [bot] ..." so the
    # role markers registered in the tokenizer vocab are seen by the model;
    # single-utterance samples are classified as-is.
    if 'history' in item:
        text = '[user] ' + item['history'][
            'user'] + ' [SEP] ' + '[bot] ' + item['history']['bot']
    else:
        text = item['text']
    result = tokenizer.encode_plus(text=text,
                                   padding='max_length',
                                   max_length=512,
                                   truncation=True,
                                   add_special_tokens=True,
                                   return_token_type_ids=True,
                                   return_tensors='pt')
    result = result.to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**result)
        predictions = logits.argmax(dim=-1)
        pred_label_idx = predictions.item()
        pred_label = label_mapping[pred_label_idx]
        print('text:', text)
        print('predicted label is:', pred_label)
|