# Refine txt2img Prompts with Human Feedback


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)


#### Optimize a GPT-2-based txt2img prompt generator to produce aesthetic prompts, using the (prompt, rating) pairs from the [simulacra-aesthetic-captions](https://github.com/JD-P/simulacra-aesthetic-captions) dataset

Notebook by [@smellslikeml](https://github.com/smellslikeml)

---

Run the cell below to install [trlX](https://github.com/CarperAI/trlx) in a Colab environment.

In [1]:
!pip install git+https://github.com/CarperAI/trlx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/CarperAI/trlx
  Cloning https://github.com/CarperAI/trlx to /tmp/pip-req-build-rx6dz42z
  Running command git clone --filter=blob:none --quiet https://github.com/CarperAI/trlx /tmp/pip-req-build-rx6dz42z
  Resolved https://github.com/CarperAI/trlx to commit a66a7da90d3b9d4b74cf968139896d6797a17286
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Collecting ray@ https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl
  Using cached https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manyli

In [2]:
import os
import sqlite3
from urllib.request import urlretrieve

import trlx

url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"

# Download the SQLite database once; later runs reuse the local copy
if not os.path.exists(dbpath):
    print(f"fetching {dbpath}")
    urlretrieve(url, dbpath)

# Join ratings -> images -> generations to pair each rating with its prompt
conn = sqlite3.connect(dbpath)
c = conn.cursor()
c.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL;"
)

# Unzip the (prompt, rating) rows into two parallel lists
prompts, ratings = tuple(map(list, zip(*c.fetchall())))
conn.close()
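
The query and the `zip(*rows)` unpacking above can be tried without downloading the dataset. The sketch below is only an illustration: it builds an in-memory database whose table and column names mirror the ones used in the query, and the rows themselves are hypothetical toy values, not taken from the real data.

```python
import sqlite3
from collections import Counter

# Hypothetical toy rows standing in for the real dataset; only the table
# and column names referenced by the notebook's query are reproduced.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE generations (id INTEGER PRIMARY KEY, prompt TEXT);
    CREATE TABLE images (id INTEGER PRIMARY KEY, gid INTEGER);
    CREATE TABLE ratings (iid INTEGER, rating INTEGER);
    INSERT INTO generations VALUES (1, 'a castle at dusk'), (2, 'a red fox');
    INSERT INTO images VALUES (10, 1), (11, 1), (20, 2);
    INSERT INTO ratings VALUES (10, 7), (11, NULL), (20, 4), (20, 9);
    """
)

# Same joins as the notebook query; ORDER BY is added here only to make
# the toy output deterministic.
rows = conn.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL "
    "ORDER BY prompt, rating;"
).fetchall()
conn.close()

# zip(*rows) transposes the list of (prompt, rating) tuples into two columns.
toy_prompts, toy_ratings = map(list, zip(*rows))

# A quick look at how often each rating occurs.
rating_counts = Counter(toy_ratings)
print(toy_prompts, toy_ratings, rating_counts)
```

A prompt appears once per rating it received, so the two lists stay aligned index by index, which is the shape passed as `samples` and `rewards` later in the notebook.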

In [3]:
from trlx.data.default_configs import default_ilql_config
config = default_ilql_config().evolve(train=dict(batch_size=32, total_steps=300))

trlX logs results to [wandb](https://wandb.ai/). Create an account beforehand, then paste your API key when prompted after running the cell below.
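
If you would rather not authenticate interactively, one possible workaround is to put the wandb client into offline mode before training starts. This is a sketch, not part of the original notebook: it assumes the wandb client reads the `WANDB_MODE` environment variable and that the variable is set before the training cell runs.

```python
import os

# Assumption: wandb honours WANDB_MODE; "offline" keeps run data on local
# disk instead of uploading it, so no API key prompt is needed.
os.environ["WANDB_MODE"] = "offline"
```

Runs recorded this way stay on the local machine until you decide to sync them.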

In [4]:
model = trlx.train(
    "gpt2",
    config=config,
    samples=prompts,
    rewards=ratings,
    eval_prompts=["<|endoftext|>"] * 64
).model

[RANK 0] Initializing model: gpt2


Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[RANK 0] Collecting rollouts
[RANK 0] Logging sample example


[RANK 0] Logging experience string statistics


[RANK 0] Starting training
[RANK 0] Evaluating model


[generation sweep 0/1 | eval batch 0/2]:   0%|          | 0/2 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[RANK 0] Summarizing evaluation


  0%|          | 0/300 [00:00<?, ?it/s]



In [5]:
# Generate prompt continuations with the trained model
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Tokenize 16 copies of the same seed prompt; .to(0) moves the inputs to GPU 0 (assumes a GPU runtime)
output = model.generate(**tokenizer(["An astronaut riding a horse"] * 16, return_tensors="pt").to(0))
tokenizer.batch_decode(output, skip_special_tokens=True)

['An astronaut riding a horse, an astronaut on a mountain',
 'An astronaut riding a horse, beautiful portrait, detailed digital art by Artgerm, wlop, Anton Fadeev, rossdraws, Ruan Jia, John',
 'An astronaut riding a horse, character design, realism, artstation, in the style sci-fi illustration, inspired by anime.',
 'An astronaut riding a horse by oskar schlemmer',
 'An astronaut riding a horse through the night on the cover of her book',
 'An astronaut riding a horse, beautiful portrait, detailed digital art by Artgerm, wlop, Anton Fadeev, Ruan Jia, John Conrad Berkey and Robert',
 'An astronaut riding a horse, a space cowboy riding a rocket, an astronaut, an astronaut wearing a sports gear, and a beautiful portrait of an astronaut.',
 'An astronaut riding a horse, painting by Anton Fadeev, featured on artstation',
 'An astronaut riding a horse, 4K digital illustration by John Berkey and James Gurney with Peter Morhbacher, John Berkey, Ross Tran, Artstation,',
 'An astronaut riding a 

In [6]:
# Save the model locally
model.save_pretrained("gpt2-simulacra")

In [7]:
# To upload the model to Hugging Face, log in first
from huggingface_hub import notebook_login
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [9]:
# Upload the model to <your_name>/gpt2-simulacra
from huggingface_hub import create_repo, HfApi

repo_id = create_repo("gpt2-simulacra", private=False, exist_ok=True).repo_id
HfApi().upload_folder(folder_path="gpt2-simulacra", repo_id=repo_id)

pytorch_model.bin:   0%|          | 0.00/1.77G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

'https://huggingface.co/reciprocate/gpt2-simulacra/tree/main/'

In [10]:
# Load the same model now stored on Hugging Face
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
hf_model = AutoModelForCausalLMWithILQLHeads.from_pretrained(repo_id)

Downloading (…)lve/main/config.json:   0%|          | 0.00/907 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.77G [00:00<?, ?B/s]

Some weights of the model checkpoint at reciprocate/gpt2-simulacra were not used when initializing GPT2LMHeadModel: ['ilql_heads.target_q_heads.1.2.bias', 'ilql_heads.q_heads.0.2.bias', 'ilql_heads.target_q_heads.0.2.bias', 'ilql_heads.q_heads.0.0.bias', 'ilql_heads.v_head.2.weight', 'ilql_heads.q_heads.1.0.bias', 'ilql_heads.v_head.0.bias', 'ilql_heads.target_q_heads.0.0.bias', 'ilql_heads.target_q_heads.1.0.bias', 'ilql_heads.v_head.0.weight', 'ilql_heads.v_head.2.bias', 'ilql_heads.q_heads.0.2.weight', 'ilql_heads.target_q_heads.1.2.weight', 'ilql_heads.q_heads.1.2.weight', 'ilql_heads.target_q_heads.0.0.weight', 'ilql_heads.target_q_heads.1.0.weight', 'ilql_heads.q_heads.0.0.weight', 'ilql_heads.q_heads.1.2.bias', 'ilql_heads.target_q_heads.0.2.weight', 'ilql_heads.q_heads.1.0.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model

Downloading (…)neration_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]