Joshua Lochner committed
Commit 18c7914 · 1 Parent(s): 787a8df

Delete moderate.py

Files changed (1)
  1. src/moderate.py +0 -104
src/moderate.py DELETED
@@ -1,104 +0,0 @@
-import torch
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    HfArgumentParser
-)
-
-from train_classifier import ClassifierModelArguments
-from shared import CATEGORIES, DatasetArguments
-from tqdm import tqdm
-
-from preprocess import get_words, clean_text
-from segment import extract_segment
-import os
-import json
-import numpy as np
-
-
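-# Numerically stable softmax: subtract the row-wise max before exponentiating.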
-def softmax(_outputs):
-    maxes = np.max(_outputs, axis=-1, keepdims=True)
-    shifted_exp = np.exp(_outputs - maxes)
-    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
-
-
-def main():
-    # See all possible arguments in src/transformers/training_args.py
-    # or by passing the --help flag to this script.
-    # We now keep distinct sets of args, for a cleaner separation of concerns.
-
-    parser = HfArgumentParser((ClassifierModelArguments, DatasetArguments))
-    model_args, dataset_args = parser.parse_args_into_dataclasses()
-
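-    # Load the fine-tuned classifier and its tokenizer from the same checkpoint.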
-    model = AutoModelForSequenceClassification.from_pretrained(
-        model_args.model_name_or_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
-
-    processed_db_path = os.path.join(
-        dataset_args.data_dir, dataset_args.processed_database)
-    with open(processed_db_path) as fp:
-        data = json.load(fp)
-
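-    # Map each category name (lowercased) to its index in CATEGORIES.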
-    mapped_categories = {
-        str(v).lower(): k for k, v in enumerate(CATEGORIES)
-    }
-
-    for video_id, segments in tqdm(data.items()):
-
-        words = get_words(video_id)
-
-        if not words:
-            continue  # No/empty transcript for video_id
-
-        valid_segments = []
-        texts = []
-        for segment in segments:
-            segment_words = extract_segment(
-                words, segment['start'], segment['end'])
-            text = clean_text(' '.join(x['text'] for x in segment_words))
-
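-            # Skip segments averaging fewer than 1.5 words per second.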
-            duration = segment['end'] - segment['start']
-            wps = len(segment_words) / duration if duration > 0 else 0
-            if wps < 1.5:
-                continue
-
-            # Skip segments that are locked (or, optionally, those that
-            # already have enough votes).
-            if segment['locked']:  # or segment['votes'] > 5:
-                continue
-
-            texts.append(text)
-            valid_segments.append(segment)
-
-        if not texts:
-            continue  # No valid segments
-
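-        # Batch-tokenize all candidate texts and run a single forward pass.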
-        model_inputs = tokenizer(
-            texts, return_tensors='pt', padding=True, truncation=True)
-
-        with torch.no_grad():
-            model_outputs = model(**model_inputs)
-            outputs = list(map(lambda x: x.numpy(), model_outputs['logits']))
-
-        scores = softmax(outputs)
-
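-        # Report segments whose predicted category disagrees with the
-        # category that was submitted.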
-        for segment, text, score in zip(valid_segments, texts, scores):
-            predicted_index = score.argmax().item()
-
-            if predicted_index == mapped_categories[segment['category']]:
-                continue  # Ignore correct segments
-
-            category_scores = {k: round(float(score[i]), 3)
-                               for i, k in enumerate(CATEGORIES)}
-
-            del segment['submission_time']
-            segment.update({
-                'predicted': str(CATEGORIES[predicted_index]).lower(),
-                'text': text,
-                'scores': category_scores
-            })
-
-            print(json.dumps(segment))
-
-
-if __name__ == "__main__":
-    main()
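
For reference, a sketch of how the deleted script was likely invoked. This assumes HfArgumentParser exposes one flag per dataclass field (model_name_or_path from ClassifierModelArguments; data_dir and processed_database from DatasetArguments); the angle-bracketed values are placeholders, not defaults from this repo:

    python src/moderate.py \
        --model_name_or_path <path-to-classifier-checkpoint> \
        --data_dir <data-dir> \
        --processed_database <processed-db.json>

Each flagged segment is printed as one JSON object per line, so the output can be redirected to a file and reviewed later.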