nishan-chatterjee committed
Commit 468c17d · 1 Parent(s): 933881f
additional comments

inference.py +18 -2
inference.py CHANGED

@@ -3,6 +3,7 @@ import numpy as np
 import networkx as nx
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
+# Function to make logits consistent based on the hierarchy matrix R
 def _make_logits_consistent(x, R):
     c_out = x.unsqueeze(1) + 10
     c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
@@ -10,8 +11,9 @@ def _make_logits_consistent(x, R):
     final_out, _ = torch.max(R_batch * c_out, dim=2)
     return final_out - 10
 
+# Function to initialize the model, tokenizer, and hierarchy matrix
 def initialize_model():
-    model_dir = "."
+    # Define the hierarchy graph
     G = nx.DiGraph()
     edges = [
         ("ROOT", "Logos"),
@@ -29,12 +31,17 @@ def initialize_model():
     ]
     G.add_edges_from(edges)
 
+    # The model and tokenizer are saved in the current directory
+    model_dir = "."
+    # Load the model and tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     model = AutoModelForSequenceClassification.from_pretrained(model_dir)
 
+    # Set device to GPU if available
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
+    # Create the hierarchy matrix R based on the graph structure
     A = nx.to_numpy_array(G).transpose()
     R = np.zeros(A.shape)
     np.fill_diagonal(R, 1)
@@ -47,7 +54,9 @@ def initialize_model():
 
     return tokenizer, model, R, G, device
 
+# Function to predict persuasion labels for a given text
 def predict_persuasion_labels(text, tokenizer, model, R, G, device):
+    # Tokenize and encode the input text
     encoding = tokenizer.encode_plus(
         text,
         add_special_tokens=True,
@@ -58,17 +67,23 @@ def predict_persuasion_labels(text, tokenizer, model, R, G, device):
         return_attention_mask=True,
         return_tensors="pt",
     )
-
+
+    # Forward pass through the model
     with torch.no_grad():
         outputs = model(
             input_ids=encoding["input_ids"].to(device),
             attention_mask=encoding["attention_mask"].to(device),
         )
+
+    # Make logits consistent based on the hierarchy matrix R
     logits = _make_logits_consistent(outputs.logits, R)
     logits[:, 0] = -1.0
     logits = logits > 0.0
+
+    # Get the complete predicted hierarchy of labels
     complete_predicted_hierarchy = np.array(G.nodes)[logits[0].cpu().nonzero()].flatten().tolist()
 
+    # Get the child-only labels (labels without any successors)
     child_only_labels = []
     for label in complete_predicted_hierarchy:
         if not list(G.successors(label)):
@@ -78,6 +93,7 @@ def predict_persuasion_labels(text, tokenizer, model, R, G, device):
 
 tokenizer, model, R, G, device = initialize_model()
 
+# Main inference function
 def inference(text):
     return predict_persuasion_labels(text, tokenizer, model, R, G, device)
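Note on the `_make_logits_consistent` hunks above: this is the max-constraint used in coherent hierarchical multi-label classification (as in C-HMCNN). Each label's logit is replaced by the maximum over itself and every label that R marks as its descendant, so an ancestor can never score below its children; the +10/-10 shift keeps the compared values positive so that the zeros in R never win the max. A minimal runnable sketch follows; the three-label hierarchy, the "Leaf" label, and the R_batch line are illustrative assumptions (only the ("ROOT", "Logos") edge is visible above, and the line defining R_batch falls outside the hunks):

import numpy as np
import torch

# Hypothetical three-label slice of the hierarchy: ROOT -> Logos -> Leaf
# (indices 0, 1, 2). R[i, j] = 1 when j is i itself or a descendant of i.
R = np.array([
    [1.0, 1.0, 1.0],  # ROOT: every label sits below it
    [0.0, 1.0, 1.0],  # Logos: itself and Leaf
    [0.0, 0.0, 1.0],  # Leaf: itself only
])

def make_logits_consistent(x, R):
    # Same max-constraint as _make_logits_consistent in inference.py; the
    # R_batch construction here is an assumption, since that line is not shown.
    n = R.shape[1]
    c_out = (x.unsqueeze(1) + 10).expand(len(x), n, n)
    R_batch = torch.tensor(R, dtype=x.dtype).expand(len(x), n, n)
    final_out, _ = torch.max(R_batch * c_out, dim=2)
    return final_out - 10

x = torch.tensor([[-5.0, -2.0, 3.0]])  # raw logits: only the leaf is positive
print(make_logits_consistent(x, R))   # tensor([[3., 3., 3.]]): ancestors pulled up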
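The hunk above initializes R with ones on the diagonal but elides the lines that fill in the rest (old lines 41-46 / new lines 48-53). For orientation, here is a sketch of one standard way to complete such a reachability matrix, consistent with how `_make_logits_consistent` consumes R but not necessarily the repo's exact code:

import networkx as nx
import numpy as np

def build_hierarchy_matrix(G):
    # R[i, j] = 1 when node j is node i itself or reachable below it in G.
    nodes = list(G.nodes)
    index = {n: i for i, n in enumerate(nodes)}
    R = np.zeros((len(nodes), len(nodes)))
    np.fill_diagonal(R, 1)
    for n in nodes:
        for d in nx.descendants(G, n):
            R[index[n], index[d]] = 1
    return R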
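Side note for a future cleanup (not part of this commit): `tokenizer.encode_plus` is deprecated in recent transformers releases; calling the tokenizer directly returns the same BatchEncoding. The hunk elides some of the keyword arguments (old lines 54-57), so only the visible ones are carried over here:

# Equivalent modern form of the encode_plus call above.
encoding = tokenizer(
    text,
    add_special_tokens=True,
    return_attention_mask=True,
    return_tensors="pt",
)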
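Two small reading aids for the prediction hunks: `logits[:, 0] = -1.0` forces the ROOT node below the decision threshold so it is never emitted as a label, and `logits > 0.0` on the constrained logits is equivalent to thresholding sigmoid probabilities at 0.5. Hypothetical usage of the module's entry point (the example sentence is illustrative; the exact return value depends on code elided from the hunks and on the trained checkpoint):

# Hypothetical call into inference.py's entry point.
labels = inference("Everyone knows this is the only sensible choice.")
print(labels)  # persuasion-technique labels predicted for the input text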