Valeriy Sinyukov commited on
Commit
d8f96d2
·
1 Parent(s): a1ad5de

Factor-out pipeline

Browse files
category_classification/models/HibiscusMaximus__scibert_paper_classification/model.py CHANGED
@@ -1,55 +1,7 @@
1
- import typing as tp
2
-
3
- import torch
4
- from transformers import pipeline, Pipeline, AutoModelForSequenceClassification
5
- from transformers.pipelines import PIPELINE_REGISTRY
6
 
7
  name = "HibiscusMaximus/scibert_paper_classification"
8
 
9
-
10
- class SciBertPaperClassifierPipeline(Pipeline):
11
- def _sanitize_parameters(self, **kwargs):
12
- return {}, {}, {}
13
-
14
- def preprocess(self, inputs):
15
- if not isinstance(inputs, tp.Iterable):
16
- inputs = [inputs]
17
- texts = [
18
- f"AUTHORS: {' '.join(paper.authors) if isinstance(paper.authors, list) else paper.authors} "
19
- f"TITLE: {paper.title} ABSTRACT: {paper.abstract}"
20
- for paper in inputs
21
- ]
22
- inputs = self.tokenizer(
23
- texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
24
- ).to(self.device)
25
- return inputs
26
-
27
- def _forward(self, model_inputs):
28
- with torch.no_grad():
29
- outputs = self.model(**model_inputs)
30
- return outputs
31
-
32
- def postprocess(self, model_outputs):
33
- probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
34
- results = []
35
- for prob in probs:
36
- result = [
37
- {"label": self.model.config.id2label[label_idx], "score": score.item()}
38
- for label_idx, score in enumerate(prob)
39
- ]
40
- results.append(result)
41
- if 1 == len(results):
42
- return results[0]
43
- return results
44
-
45
-
46
- PIPELINE_REGISTRY.register_pipeline(
47
- "paper-classification",
48
- pipeline_class=SciBertPaperClassifierPipeline,
49
- pt_model=AutoModelForSequenceClassification,
50
- )
51
-
52
-
53
  class SciBertPaperClassifier:
54
  def __init__(self):
55
  self.pipeline = pipeline("paper-classification", model=name)
 
1
+ from transformers import pipeline
 
 
 
 
2
 
3
  name = "HibiscusMaximus/scibert_paper_classification"
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class SciBertPaperClassifier:
6
  def __init__(self):
7
  self.pipeline = pipeline("paper-classification", model=name)
category_classification/models/models.py CHANGED
@@ -5,6 +5,7 @@ import typing as tp
5
  import warnings
6
  from pathlib import Path
7
 
 
8
 
9
  def import_model_module(file_path: os.PathLike):
10
  module_name = str(Path(file_path).relative_to(os.getcwd())).replace(
 
5
  import warnings
6
  from pathlib import Path
7
 
8
+ from . import pipeline
9
 
10
  def import_model_module(file_path: os.PathLike):
11
  module_name = str(Path(file_path).relative_to(os.getcwd())).replace(
category_classification/models/pipeline.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from transformers import Pipeline, AutoModelForSequenceClassification
6
+ from transformers.pipelines import PIPELINE_REGISTRY
7
+
8
+ class PapersClassificationPipeline(Pipeline):
9
+ def _sanitize_parameters(self, **kwargs):
10
+ return {}, {}, {}
11
+
12
+ def preprocess(self, inputs):
13
+ if not isinstance(inputs, tp.Iterable):
14
+ inputs = [inputs]
15
+ texts = [
16
+ f"AUTHORS: {' '.join(paper.authors) if isinstance(paper.authors, list) else paper.authors} "
17
+ f"TITLE: {paper.title} ABSTRACT: {paper.abstract}"
18
+ for paper in inputs
19
+ ]
20
+ inputs = self.tokenizer(
21
+ texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
22
+ ).to(self.device)
23
+ return inputs
24
+
25
+ def _forward(self, model_inputs):
26
+ with torch.no_grad():
27
+ outputs = self.model(**model_inputs)
28
+ return outputs
29
+
30
+ def postprocess(self, model_outputs):
31
+ probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
32
+ results = []
33
+ for prob in probs:
34
+ result = [
35
+ {"label": self.model.config.id2label[label_idx], "score": score.item()}
36
+ for label_idx, score in enumerate(prob)
37
+ ]
38
+ results.append(result)
39
+ if 1 == len(results):
40
+ return results[0]
41
+ return results
42
+
43
+
44
+ PIPELINE_REGISTRY.register_pipeline(
45
+ "paper-classification",
46
+ pipeline_class=PapersClassificationPipeline,
47
+ pt_model=AutoModelForSequenceClassification,
48
+ )