# Produce Movie Reviews with Positive Sentiment

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_sentiments.ipynb)

#### Optimize GPT-2 to write positive movie reviews, training on a corpus of IMDB reviews and scoring sentiment with DistilBERT.

Notebook by [@zswitten](https://github.com/zswitten)

---

Execute the cells below to install [trlX](https://github.com/CarperAI/trlx) in a Colab environment.

In [None]:
!git clone https://github.com/CarperAI/trlx.git
!git config --global --add safe.directory /content/trlx && cd /content/trlx && pip install -e .

Cloning into 'trlx'...
remote: Enumerating objects: 5140, done.
remote: Counting objects: 100% (170/170), done.
remote: Compressing objects: 100% (101/101), done.
remote: Total 5140 (delta 93), reused 124 (delta 69), pack-reused 4970
Receiving objects: 100% (5140/5140), 46.17 MiB | 14.98 MiB/s, done.
Resolving deltas: 100% (3228/3228), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/trlx
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Collecting tabulate>=0.9.0
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Collecting rich
  Downloading rich-13.3.1-py3-none-any.whl (239 kB)

In [None]:
# uninstall scikit_learn + jax to avoid numpy issues
!pip uninstall -y scikit_learn jax

Found existing installation: scikit-learn 1.0.2
Uninstalling scikit-learn-1.0.2:
  Successfully uninstalled scikit-learn-1.0.2
Found existing installation: jax 0.3.25
Uninstalling jax-0.3.25:
  Successfully uninstalled jax-0.3.25
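
As a quick sanity check after the uninstall step, the sketch below reports whether a package can still be imported. It is a minimal illustration using only the standard library; `is_installed` is a hypothetical helper name, not part of trlX. Note that scikit-learn is imported under the module name `sklearn`.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(module_name) is not None


# scikit-learn is imported as `sklearn`; jax keeps its package name.
for name in ("sklearn", "jax"):
    status = "installed" if is_installed(name) else "not installed"
    print(f"{name}: {status}")
```

In a fresh Colab runtime both packages should be reported as not installed once the cell above has run; if either still appears, restarting the runtime may be needed.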


In [None]:
import os

# run within repo
os.chdir('/content/trlx')
print(os.getcwd())

/content/trlx


In [None]:
import yaml
from datasets import load_dataset
from transformers import pipeline
import pathlib
from typing import Dict, List
import trlx
from trlx.data.default_configs import TRLConfig, default_ilql_config

In [None]:
# Start from trlX's default ILQL configuration and override a few training settings
default_config = default_ilql_config().to_dict()
default_config['train']['tracker'] = None  # disable experiment tracking
default_config['train']['batch_size'] = 16
default_config['train']['epochs'] = 10
config = TRLConfig.update(default_config, {})
print(config)

{
    "method": {
        "name": "ilqlconfig",
        "tau": 0.7,
        "gamma": 0.99,
        "cql_scale": 0.1,
        "awac_scale": 1,
        "alpha": 0.001,
        "beta": 0,
        "steps_for_target_q_sync": 5,
        "two_qs": true,
        "gen_kwargs": {
            "max_new_tokens": 56,
            "top_k": 20,
            "beta": 4,
            "temperature": 1.0
        }
    },
    "model": {
        "model_path": "gpt2",
        "model_arch_type": "causal",
        "num_layers_unfrozen": -1,
        "delta_kwargs": null
    },
    "optimizer": {
        "name": "adamw",
        "kwargs": {
            "lr": 5e-05,
            "betas": [
                0.9,
                0.95
            ],
            "eps": 1e-08,
            "weight_decay": 1e-06
        }
    },
    "scheduler": {
        "name": "cosine_annealing",
        "kwargs": {
            "T_max": 1000,
            "eta_min": 5e-05
        }
    },
    "tokenizer": {
        "tokenizer_path": "gpt2",

In [None]:
def get_positive_score(scores):
    """Extract the score associated with the POSITIVE label from the pipeline's output."""
    return {entry["label"]: entry["score"] for entry in scores}["POSITIVE"]

sentiment_fn = pipeline(
    "sentiment-analysis",
    "lvwerra/distilbert-imdb",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=0,
)

def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    sentiments = list(map(get_positive_score, sentiment_fn(samples)))
    return {"sentiments": sentiments}

imdb = load_dataset("imdb", split="train+test")

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.
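
To see what `get_positive_score` and `metric_fn` compute without downloading the classifier, the sketch below applies the same extraction to a hand-written list shaped like the output of a `sentiment-analysis` pipeline called with `top_k=2`. The scores are invented for illustration rather than real model output, and `fake_pipeline_output` is a hypothetical stand-in for `sentiment_fn(samples)`.

```python
from typing import Dict, List


def get_positive_score(scores: List[Dict[str, object]]) -> float:
    """Return the score attached to the POSITIVE label."""
    return {entry["label"]: entry["score"] for entry in scores}["POSITIVE"]


# Hypothetical stand-in for sentiment_fn(samples): one list of
# {label, score} dicts per input sample (values invented for illustration).
fake_pipeline_output = [
    [{"label": "POSITIVE", "score": 0.9}, {"label": "NEGATIVE", "score": 0.1}],
    [{"label": "NEGATIVE", "score": 0.75}, {"label": "POSITIVE", "score": 0.25}],
]

# Same shape as the dictionary returned by metric_fn above.
sentiments = [get_positive_score(sample) for sample in fake_pipeline_output]
metrics = {"sentiments": sentiments}
print(metrics)  # {'sentiments': [0.9, 0.25]}
```

Looking the label up by name, rather than relying on the order of the entries, keeps the result correct whichever label the classifier ranks first.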


In [None]:
trainer = trlx.train(
    samples=imdb["text"], 
    rewards=imdb["label"],
    eval_prompts=[
        "I don't know much about Hungarian underground",
        "What made this movie so distinctly",
        "Like the sandwich I just bought at the grocery store,",
        "I cannot believe how much this movie made me want to"
    ] * 20,
    metric_fn=metric_fn,
    config=config,
)

[RANK 0] Initializing model: gpt2


[RANK 0] Collecting rollouts
Token indices sequence length is longer than the specified maximum sequence length for this model (1169 > 1024). Running this sequence through the model will result in indexing errors
[RANK 0] Logging sample example


[RANK 0] Logging experience string statistics


[RANK 0] Starting training
[RANK 0] Evaluating model


[generation sweep 0/1 | eval batch 0/5]:   0%|          | 0/5 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[RANK 0] Computing metrics
[RANK 0] Summarizing evaluation


  0%|          | 0/1000 [00:00<?, ?it/s]
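
The `eval batch 0/5` counter in the log above is consistent with simple arithmetic on the arguments passed to `trlx.train`. The sketch below reproduces that count, assuming evaluation prompts are batched with the same `batch_size` of 16 set in the training config; that assumption matches the log but is not confirmed from the trlX source here.

```python
import math

# The four evaluation prompts from the training call, repeated 20 times.
eval_prompts = [
    "I don't know much about Hungarian underground",
    "What made this movie so distinctly",
    "Like the sandwich I just bought at the grocery store,",
    "I cannot believe how much this movie made me want to",
] * 20

# Assumption: evaluation uses the training batch size from the config.
batch_size = 16

num_prompts = len(eval_prompts)
num_eval_batches = math.ceil(num_prompts / batch_size)
print(num_prompts, num_eval_batches)  # 80 5
```

Repeating each prompt 20 times means every evaluation pass averages the sentiment metric over 20 sampled continuations per prompt, which gives a less noisy estimate than a single sample.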



In [None]:
# Generate a continuation for a sample prompt with the trained model
input_str = 'One thing you should know about When Sally Met Harry is that'
trainer_output = trainer.generate_eval(
    **trainer.tokenizer(input_str, return_tensors='pt'))[0]
print(trainer.tokenizer.decode(trainer_output))

One thing you should know about When Sally Met Harry is that there's so much more she's done than in this one. She's got more to offer and more to say for a long, long time. She's still very much a mystery in her own right and the movie does a good job of the plot.<
