File size: 4,548 Bytes
c4ebaf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# import os 

# os.chdir('naacl-2021-fudge-controlled-generation/')

import gradio as gr
from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
from datasets import load_dataset,DatasetDict,Dataset
# from datasets import 
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
import pandas as pd 
from fudge.model import Model
import os
from argparse import ArgumentParser
from collections import namedtuple
import mock

from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from fudge.data import Dataset
from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from fudge.constants import *


# All models in this demo are placed on this device (the UI text states the
# demo runs CPU-only).
device = 'cpu'

# Seq2seq model that produces the headline text.
pretrained_model = "checkpoint-150/"
generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)


pad_id = 0

# Inference only: switch the generator to eval mode.
generation_model.eval()

# `Model` reads its configuration from an args object; a Mock stands in for the
# argparse namespace the FUDGE scripts would normally build. Attributes not set
# here are auto-created Mock objects if `Model` happens to read them.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = 'checkpoint-1464/'

# FUDGE conditioning model. vocab_size=None: per the original author's note, the
# embeddings are restored from the model checkpoint, so no GloVe vocabulary is
# needed when reloading (TODO confirm against fudge.model.Model).
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# Decoding settings forwarded to generate_clickbait by clickbait_generator.
# condition_lambda here is only a module-level default; clickbait_generator
# takes its own condition_lambda argument (default 5.0) from the UI slider.
condition_lambda = 5.0
length_cutoff = 50
precondition_topk = 200

# Rebinds this module's `classifier_tokenizer` (originally imported from
# fudge.predict_clickbait) to the tokenizer stored with the classifier checkpoint.
classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint)


def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score a single headline with a classification model.

    Args:
        input_text: mapping with a 'postText' entry (text to tokenize) and a
            'truthClass' entry (label); both are read by
            preprocess_function_title_only_classification.
        model: model whose call result exposes `.logits`; presumably a single
            score per headline (TODO confirm).
        tokenizer: tokenizer used to encode 'postText'.
        device: device the input tensors are moved to. Note the default is
            'cuda' while this demo runs on CPU, so callers here should pass it.

    Returns:
        dict with key 'predicted_class' holding the model's logits as a float.
    """
    tokenized_input = preprocess_function_title_only_classification(input_text, tokenizer=tokenizer)
    # Add a batch dimension of 1; 'labels' is excluded from the model inputs.
    model_inputs = {
        name: torch.tensor([values]).to(device)
        for name, values in tokenized_input.items()
        if name != 'labels'
    }
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**model_inputs).logits
    # float() requires the logits tensor to contain exactly one element.
    return {'predicted_class': float(logits)}

def preprocess_function_title_only_classification(examples, tokenizer=None, max_length=25):
    """Tokenize the headline text of `examples` and attach its label.

    Args:
        examples: mapping with 'postText' (passed to the tokenizer) and
            'truthClass' (copied unchanged into the 'labels' entry).
        tokenizer: callable tokenizer. Required despite the None default, which
            is kept for call-compatibility.
        max_length: truncation length forwarded to the tokenizer; defaults to 25,
            the value that used to be hard-coded.

    Returns:
        The tokenizer's output mapping with a 'labels' entry added.

    Raises:
        TypeError: if no tokenizer is given.
    """
    if tokenizer is None:
        raise TypeError("preprocess_function_title_only_classification requires a tokenizer")
    model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=max_length)
    model_inputs['labels'] = examples['truthClass']
    return model_inputs



def clickbait_generator(article_content, condition_lambda=5.0):
    """Generate one headline for `article_content` via generate_clickbait.

    Relies on the module-level generation model, tokenizer, conditioning model
    and decoding settings (precondition_topk, length_cutoff, device).

    Args:
        article_content: article text entered in the UI.
        condition_lambda: weight forwarded to generate_clickbait; supplied by the
            UI slider.

    Returns:
        The first generated string with every '</s>' and then every '<pad>'
        marker removed.
    """
    generation_kwargs = dict(
        model=generation_model,
        tokenizer=tokenizer,
        conditioning_model=conditioning_model,
        input_text=[None],
        dataset_info=None,
        precondition_topk=precondition_topk,
        length_cutoff=length_cutoff,
        condition_lambda=condition_lambda,
        article_content=article_content,
        device=device,
    )
    headline = generate_clickbait(**generation_kwargs)[0]
    # Same removal order as before: '</s>' first, then '<pad>'.
    for marker in ('</s>', '<pad>'):
        headline = headline.replace(marker, '')
    return headline

# Text passed to gr.Interface below: page title, Markdown description, and the
# Markdown "article" section.
title = "Clickbaitinator - Controllable Clickbait generator"
description = """
Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine-tuned for our purposes to try and create news headline you are looking for! Use condition_lambda to steer your clickbaitiness higher (by increasing the slider value) or lower (by decreasing the slider value). <br/>
Note that this is using two Transformers and is executed with CPU-only, so it will take a minute or two to finish generating a title.
"""

# Link to the project repository shown in the article section of the UI.
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based of. You need collaborator access, which you have been probably invited for."


# Inputs map positionally onto clickbait_generator(article_content, condition_lambda).
# gr.Interface has no `label` parameter, so the input label is set on the
# Textbox component itself; otherwise it would be ignored (Gradio 3) or rejected
# with a TypeError (Gradio 4).
app = gr.Interface(
    title = title,
    description = description,
    fn = clickbait_generator,
    inputs=[gr.Textbox(label='Article content or paragraph'), gr.Slider(0, 15, step=0.1, value=5.0)],
    outputs="text",
    article=article,
    )
app.launch()