ArneBinder committed on
Commit
12537b9
·
verified ·
1 Parent(s): 7dda167

delete outdated files

Browse files
Files changed (1) hide show
  1. embedding.py +0 -157
embedding.py DELETED
@@ -1,157 +0,0 @@
1
- import abc
2
- import logging
3
- from typing import Dict, Union
4
-
5
- import torch
6
- from datasets import Dataset
7
- from pie_modules.document.processing import tokenize_document
8
- from pie_modules.documents import (
9
- TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
10
- TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
11
- )
12
- from pytorch_ie.annotations import LabeledSpan, MultiSpan, Span
13
- from pytorch_ie.documents import TextBasedDocument
14
- from torch import FloatTensor, Tensor
15
- from torch.utils.data import DataLoader
16
- from transformers import AutoModel, AutoTokenizer
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
class EmbeddingModel(abc.ABC):
    """Interface for callables that embed the span annotations of a document.

    Concrete subclasses must implement ``__call__``; the base class itself cannot be
    instantiated.
    """

    @abc.abstractmethod
    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Union[Span, MultiSpan], FloatTensor]:
        """Embed text annotations from a document.

        Args:
            document: The document to embed.
            span_layer_name: The name of the annotation layer in the document that contains the
                text span annotations to embed.

        Returns:
            A dictionary mapping text annotations to their embeddings.
        """
36
-
37
-
38
class HuggingfaceEmbeddingModel(EmbeddingModel):
    """Embed span annotations with a Hugging Face ``AutoModel`` encoder.

    The document text is tokenized into windows, each window is passed through the model, and
    every span annotation is represented by max-pooling the last hidden states of the tokens
    it covers.
    """

    def __init__(
        self,
        model_name_or_path: str,
        revision: Union[str, None] = None,
        device: str = "cpu",
        max_length: int = 512,
        batch_size: int = 16,
    ):
        """Load model and tokenizer and store the tokenization / batching settings.

        Args:
            model_name_or_path: Passed to ``AutoModel`` / ``AutoTokenizer.from_pretrained``.
            revision: Optional model revision, passed through to ``from_pretrained``.
            device: Device the model is moved to.
            max_length: Maximum number of tokens per tokenized window.
            batch_size: Number of tokenized windows per forward pass.
        """
        self.load(model_name_or_path, revision, device)
        self.max_length = max_length
        self.batch_size = batch_size

    def load(
        self,
        model_name_or_path: str,
        revision: Union[str, None] = None,
        device: str = "cpu",
    ) -> None:
        """(Re-)load the model onto ``device`` together with its matching tokenizer."""
        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)

    @staticmethod
    def _pool_annotation(
        last_hidden_state: Tensor, tok_ann: Union[LabeledSpan, MultiSpan]
    ) -> Union[Tensor, None]:
        """Max-pool the hidden states of the tokens covered by a token-level annotation.

        Args:
            last_hidden_state: Hidden states of a single tokenized window, indexed by token
                position along the first dimension.
            tok_ann: Token-level annotation whose offsets index into ``last_hidden_state``.

        Returns:
            The pooled embedding (detached, on CPU), or ``None`` if the annotation is "empty",
            i.e. covers no tokens.

        Raises:
            ValueError: If ``tok_ann`` is neither a ``LabeledSpan`` nor a ``MultiSpan``.
        """
        if isinstance(tok_ann, LabeledSpan):
            if tok_ann.start == tok_ann.end:
                return None
            embedded_tokens = last_hidden_state[tok_ann.start : tok_ann.end]
        elif isinstance(tok_ann, MultiSpan):
            non_empty_slices = [(start, end) for start, end in tok_ann.slices if start != end]
            if not non_empty_slices:
                return None
            # concatenate the embeddings of the tokens that make up the multi-span
            embedded_tokens = torch.concat(
                [last_hidden_state[start:end] for start, end in non_empty_slices],
                dim=0,
            )
        else:
            raise ValueError(f"Unsupported annotation type: {type(tok_ann)}")
        # use the max pooling strategy to get a single embedding for the annotation text
        return embedded_tokens.max(dim=0)[0].detach().cpu()

    def __call__(
        self, document: TextBasedDocument, span_layer_name: str
    ) -> Dict[Union[Span, MultiSpan], FloatTensor]:
        """Embed the annotations of layer ``span_layer_name`` of ``document``.

        Args:
            document: The document to embed. It is copied first, so the passed instance is
                not modified.
            span_layer_name: Name of the layer holding the ``Span`` or ``MultiSpan``
                annotations to embed. Predictions of that layer are embedded as well.

        Returns:
            A dictionary mapping the text-level annotations to their max-pooled embeddings.
            Annotations that cover no tokens after tokenization are omitted.

        Raises:
            ValueError: If the annotation type of the layer is not supported.
        """
        # to not modify the original document
        document = document.copy()
        # tokenize_document does not yet consider predictions, so we need to add them manually
        # NOTE(review): relies on predictions.clear() returning the removed annotations - confirm
        document[span_layer_name].extend(document[span_layer_name].predictions.clear())
        added_annotations = []
        tokenizer_kwargs = {
            "max_length": self.max_length,
            "stride": self.max_length // 8,
            "truncation": True,
            "padding": True,
            "return_overflowing_tokens": True,
        }
        # tokenize once to get the tokenized documents with mapped annotations
        span_annotation_type = document.annotation_types()[span_layer_name]
        if issubclass(span_annotation_type, Span):
            result_document_type = TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
            tokenized_span_layer_name = "labeled_spans"
        elif issubclass(span_annotation_type, MultiSpan):
            result_document_type = (
                TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
            )
            tokenized_span_layer_name = "labeled_multi_spans"
        else:
            raise ValueError(f"Unsupported annotation type: {span_annotation_type}")
        tokenized_documents = tokenize_document(
            document,
            tokenizer=self._tokenizer,
            result_document_type=result_document_type,
            partition_layer="labeled_partitions",
            added_annotations=added_annotations,
            strict_span_conversion=False,
            **tokenizer_kwargs,
        )

        # just tokenize again to get tensors in the correct format for the model
        # NOTE(review): the loop below assumes this second pass yields the same windows, in the
        # same order, as tokenize_document above - confirm this also holds when the document
        # has entries in "labeled_partitions"
        dataset = Dataset.from_dict({"text": [document.text]})

        def tokenize_function(examples):
            return self._tokenizer(examples["text"], **tokenizer_kwargs)

        # Tokenize the texts. Note that we remove the text column directly in the map call,
        # otherwise the map would fail because we may produce multiple new rows
        # (tokenization results) for each input row (text).
        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        # remove the overflow_to_sample_mapping column
        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
        tokenized_dataset.set_format(type="torch")

        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)

        embeddings = {}
        # running index over all tokenized windows, across batches
        example_idx = 0
        for batch in dataloader:
            batch_at_device = {
                k: v.to(self._model.device) if isinstance(v, Tensor) else v
                for k, v in batch.items()
            }
            with torch.no_grad():
                model_output = self._model(**batch_at_device)

            for last_hidden_state in model_output.last_hidden_state:
                text2tok_ann = added_annotations[example_idx][span_layer_name]
                # invert the mapping to get from token-level back to text-level annotations
                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
                for tok_ann in tokenized_documents[example_idx][tokenized_span_layer_name]:
                    embedding = self._pool_annotation(last_hidden_state, tok_ann)
                    if embedding is None:
                        # skip "empty" annotations
                        continue
                    # An annotation that shows up in several windows (striding) is embedded
                    # once per window; the embedding from the last window wins.
                    embeddings[tok2text_ann[tok_ann]] = embedding
                example_idx += 1

        return embeddings