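# Moderation/evaluation helper (description inferred from the code below):
# runs a trained sequence classifier over the segments in the processed
# SponsorBlock database and prints, as JSON lines, every segment whose
# predicted category disagrees with its labelled category, together with the
# segment text and per-category scores.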
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser
)
from train_classifier import ClassifierModelArguments
from shared import CATEGORIES, DatasetArguments
from tqdm import tqdm
from preprocess import get_words, clean_text
from segment import extract_segment
import os
import json
import numpy as np


def softmax(_outputs):
    maxes = np.max(_outputs, axis=-1, keepdims=True)
    shifted_exp = np.exp(_outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
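# Illustrative note (not executed): softmax(np.array([[2.0, 1.0, 0.5]])) is
# approximately [[0.629, 0.231, 0.140]]. Subtracting the per-row maximum before
# exponentiating avoids overflow on large logits without changing the result.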

def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ClassifierModelArguments, DatasetArguments))
    model_args, dataset_args = parser.parse_args_into_dataclasses()

    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    processed_db_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_database)
    with open(processed_db_path) as fp:
        data = json.load(fp)

    mapped_categories = {
        str(v).lower(): k for k, v in enumerate(CATEGORIES)
    }
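    # Illustration only (hypothetical contents): if CATEGORIES were
    # ['SPONSOR', 'SELFPROMO', 'INTERACTION'], mapped_categories would be
    # {'sponsor': 0, 'selfpromo': 1, 'interaction': 2}.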

    for video_id, segments in tqdm(data.items()):
        words = get_words(video_id)
        if not words:
            continue  # No/empty transcript for video_id

        valid_segments = []
        texts = []
        for segment in segments:
            segment_words = extract_segment(
                words, segment['start'], segment['end'])
            text = clean_text(' '.join(x['text'] for x in segment_words))

            duration = segment['end'] - segment['start']
            wps = len(segment_words) / duration if duration > 0 else 0
            if wps < 1.5:
                continue  # Skip segments with fewer than 1.5 words per second

            # Do not worry about those that are locked or have enough votes
            if segment['locked']:  # or segment['votes'] > 5:
                continue

            texts.append(text)
            valid_segments.append(segment)

        if not texts:
            continue  # No valid segments

        # Classify all of this video's segment texts in one batch
        model_inputs = tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            model_outputs = model(**model_inputs)
            outputs = list(map(lambda x: x.numpy(), model_outputs['logits']))
            scores = softmax(outputs)
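        # scores has one row per valid segment and, assuming the classifier
        # was trained with len(CATEGORIES) labels, one column per category;
        # each row sums to 1.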

        for segment, text, score in zip(valid_segments, texts, scores):
            predicted_index = score.argmax().item()
            if predicted_index == mapped_categories[segment['category']]:
                continue  # Ignore segments whose label matches the prediction

            category_scores = {k: round(float(score[i]), 3)
                               for i, k in enumerate(CATEGORIES)}

            del segment['submission_time']
            segment.update({
                'predicted': str(CATEGORIES[predicted_index]).lower(),
                'text': text,
                'scores': category_scores
            })

            # Emit the mismatching segment as a JSON line for manual review
            print(json.dumps(segment))


if __name__ == "__main__":
    main()
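# Example invocation (sketch): HfArgumentParser exposes each dataclass field as
# a CLI flag, so the exact flags depend on ClassifierModelArguments and
# DatasetArguments. Given the fields accessed above, something like:
#   python moderate.py --model_name_or_path <model dir> --data_dir <data dir>
# where "moderate.py" is a placeholder for this file's name.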