---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{}
---

### Model Description

This model takes in the text of a news article and outputs an embedding representing that article. The embeddings are trained so that the cosine similarity between two articles' embeddings reflects the overall similarity of the articles.
The model was trained on data from the 2022 SemEval Task 8 news article similarity challenge, and it achieves the second-highest score when evaluated on the challenge's test set.
Designed for speed and scalability, the model is ideal for embedding many news articles (or similar text) and using fast cosine similarity calculations for pairwise similarity over very large corpora.

- **Developed by:** Ben Litterer, David Jurgens, Dallas Card
- **Finetuned from model:** all-mpnet-base-v2

## Uses

This model is ideal for embedding large corpora of text and calculating pairwise similarity scores.
Note that during training, each article's headline was first concatenated to the full article text; the first 288 tokens and the last 96 tokens of the result were then concatenated so that the input fits in the 384-token context window of all-mpnet-base-v2. A minimal sketch of this preprocessing is shown below.
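
The preprocessing code itself is not included in this repository; the following is a minimal sketch of the scheme described above, assuming the all-mpnet-base-v2 tokenizer, a hypothetical `prepare_article` helper, and a single-space join between headline and body:

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

def prepare_article(headline, body, head_len=288, tail_len=96):
    # concatenate the headline to the article text, then keep the first
    # 288 and last 96 tokens so the input fits in the context window
    tokens = tokenizer.tokenize(headline + " " + body)
    if len(tokens) > head_len + tail_len:
        tokens = tokens[:head_len] + tokens[-tail_len:]
    return tokenizer.convert_tokens_to_string(tokens)
```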

## How to Get Started with the Model

Use the code below to get started with the model. All you need are the weights in `state_dict.tar`.

```
import torch
from transformers import AutoTokenizer, AutoModel

MODEL_PATH = "/my/path/to/state_dict.tar"

# set device as needed
device = torch.device("cpu")

# declare model class, inheriting from torch.nn.Module
class BiModel(torch.nn.Module):
    def __init__(self):
        super(BiModel, self).__init__()
        self.model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(device).train()
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    # pool token-level embeddings into one embedding, ignoring padding tokens
    def mean_pooling(self, token_embeddings, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # note that here we expect only one batch of input ids and attention masks
    def encode(self, input_ids, attention_mask):
        encoding = self.model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))[0]
        meanPooled = self.mean_pooling(encoding, attention_mask.squeeze(1))
        return meanPooled

    # NOTE: here we expect a list of two batches that we then unpack
    def forward(self, input_ids, attention_mask):
        input_ids_a = input_ids[0].to(device)
        input_ids_b = input_ids[1].to(device)
        attention_a = attention_mask[0].to(device)
        attention_b = attention_mask[1].to(device)

        # encode each article and get its mean-pooled representation
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]  # all token embeddings
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        pred = self.cos(meanPooled1, meanPooled2)
        return pred

# initialize model and load weights
trainedModel = BiModel()
sDict = torch.load(MODEL_PATH, map_location=device)

# may be needed depending on the PyTorch / transformers version
if "model.embeddings.position_ids" in sDict:
    del sDict["model.embeddings.position_ids"]

# initialize tokenizer for all-mpnet-base-v2
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# load the fine-tuned weights; trainedModel is now ready to use
trainedModel.load_state_dict(sDict)
trainedModel.eval()
```
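
Once the weights are loaded, `trainedModel.encode` produces one embedding per input, and embeddings can be compared with cosine similarity. A minimal usage sketch (the example strings and the `max_length` of 384 are illustrative assumptions):

```
articles = ["Headline A. Full text of the first article ...",
            "Headline B. Full text of the second article ..."]

# tokenize both articles; inputs longer than the 384-token context window are truncated
tokens = tokenizer(articles, padding=True, truncation=True, max_length=384, return_tensors="pt")

# embed both articles without tracking gradients
with torch.no_grad():
    embeddings = trainedModel.encode(tokens["input_ids"], tokens["attention_mask"])

# cosine similarity between the two article embeddings
similarity = torch.nn.functional.cosine_similarity(embeddings[0:1], embeddings[1:2])
print(similarity.item())
```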
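
For pairwise similarity over a very large corpus, it is usually faster to L2-normalize the embedding matrix once and take a matrix product, which is equivalent to computing all pairwise cosine similarities at once. A minimal sketch, assuming `embeddings` is a `(num_articles, dim)` tensor produced as above:

```
import torch.nn.functional as F

# L2-normalize each embedding, then a matrix product yields cosine similarities
normalized = F.normalize(embeddings, p=2, dim=1)
similarity_matrix = normalized @ normalized.T  # (num_articles, num_articles)
```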