Rick and Morty Story Generator

Photo by Benigno Hoyuela (@benignohoyuela) on Unsplash.com

Introduction

With the rapid progress in Machine Learning (ML) and Natural Language Processing (NLP), new algorithms are able to generate texts that seem more and more human-produced. One such algorithm, GPT2 [1], has been used in many open-source applications [2]. GPT2 was trained on WebText, which contains text from 45 million outbound Reddit links (i.e. the websites that comments link to). The top 10 outbound domains [3] include Google, Archive, Blogspot, Github, NYTimes, WordPress, Washington Post, Wikia, BBC, and The Guardian. The pre-trained GPT2 model can be fine-tuned on specific datasets, for example, to “acquire” the style of a dataset or learn to classify documents. This is done via Transfer Learning, which can be defined as “a means to extract knowledge from a source setting and apply it to a different target setting” [4]. For a detailed explanation of GPT2 and its architecture see the original paper [5], OpenAI’s blog post [6], or Jay Alammar’s illustrated guide [7].
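
As a quick illustration (a minimal sketch, not part of the application built below), the pre-trained GPT2 model and its tokenizer can be loaded and queried in a few lines with the transformers library; the prompt is just an example:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pre-trained GPT2 weights and tokenizer; fine-tuning starts from these.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Generate a short continuation of a prompt with the (not yet fine-tuned) model.
inputs = tokenizer("Rick and Morty are", return_tensors="pt")
outputs = model.generate(**inputs, max_length=30, do_sample=True, top_k=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))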

Dataset

The dataset used to fine-tune GPT2 consists of transcripts of the first 3 seasons of Rick and Morty. The data was downloaded and stored as raw text. Each line contains either a speaker and their utterance or an action/scene description. The dataset was split into training and test sets, which contain 6905 and 1454 lines, respectively. The raw files can be found here. The training data is used to fine-tune the model, while the test data is used for evaluation.
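
Below is a minimal sketch of one way to produce such a split, assuming the transcripts were saved line-by-line to a file called rick_and_morty.txt (the file name and the exact split used for the published files are assumptions):

# Split the transcript lines into training and test files (roughly 80/20),
# keeping the original order so that dialogue stays contiguous.
with open("rick_and_morty.txt", encoding="utf-8") as f:
    lines = [line.rstrip("\n") for line in f if line.strip()]

split_idx = int(0.83 * len(lines))
train_lines, test_lines = lines[:split_idx], lines[split_idx:]

with open("train.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(train_lines))

with open("test.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(test_lines))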

Training the model

Hugging Face’s Transformers library provides a simple script to fine-tune a custom GPT2 model. You can fine-tune your own model using this Google Colab notebook. Once training has finished, make sure you download the output folder containing all relevant model files (this is essential for loading the model later). You can upload your custom model to Hugging Face’s Model Hub [8] to make it accessible to the public. The model achieves a perplexity of about 17 when evaluated on the test data.
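
Perplexity is simply the exponential of the average cross-entropy loss on the test data, so you can verify that number yourself. Here is a minimal sketch (not the Colab notebook's evaluation code; the file name test.txt is an assumption):

import math
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("e-tony/gpt2-rnm")
model = GPT2LMHeadModel.from_pretrained("e-tony/gpt2-rnm")
model.eval()

# Tokenize the test transcript and score it in fixed-size windows.
encodings = tokenizer(open("test.txt", encoding="utf-8").read(), return_tensors="pt")

max_length = 512
nlls = []
for i in range(0, encodings.input_ids.size(1), max_length):
    input_ids = encodings.input_ids[:, i:i + max_length]
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids)[0]  # average cross-entropy for this window
    nlls.append(loss * input_ids.size(1))

# Perplexity is exp of the token-averaged loss over the whole test set.
perplexity = math.exp(torch.stack(nlls).sum() / encodings.input_ids.size(1))
print(f"Perplexity: {perplexity:.2f}")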

Building the application

To get started, let’s create a new project folder called Story_Generator and a virtual environment for Python 3.7:

mkdir Story_Generator
cd Story_Generator
python3.7 -m venv venv
source venv/bin/activate

Next, we want to install all dependencies for the project:

pip install streamlit-nightly==0.69.3.dev20201025
pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install git+git://github.com/huggingface/transformers@59b5953d89544a66d73

Our entire application will live inside of app.py. Let’s create it and import our newly installed dependencies:

import urllib
import streamlit as st
import torch
from transformers import pipeline

Before we do any processing, we want to load our model. By using the @st.cache decorator, we can execute the load_model() function once and store the result in a local cache, which boosts our application’s performance. We can then use the pipeline() function to simply load a model for text generation (substitute the model path to your custom model or use my pre-trained model from the Model Hub):

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    return pipeline("text-generation", model="e-tony/gpt2-rnm")

model = load_model()

We can use Streamlit’s text_area() function to make a simple text box. We can additionally supply the height and maximum number of allowed characters (since long inputs take longer to process):

textbox = st.text_area('Start your story:', '', height=200, max_chars=1000)

Now that we have our first lines of code, we can see what that looks like by running the application (we can also see live changes by refreshing the page):

streamlit run app.py

Next, we can add a slider widget to allow the user to determine how many characters the model should generate:

slider = st.slider('Max story length (in characters)', 50, 200)

We are now ready to generate text! Let’s create a button that executes the text generation:

button = st.button('Generate')

We want our application to listen for the “button press” action. This can be done with a simple conditional statement. We can then generate text and output it onto the screen:

if button:
    output_text = model(textbox, max_length=slider)[0]['generated_text']
	
    for i, line in enumerate(output_text.split("\n")):
        if ":" in line:
            speaker, speech = line.split(':', 1)
            st.markdown(f'__{speaker}__: {speech}')
        else:
            st.markdown(line)

Let’s input a prompt into the text box and generate a story:

Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge.

Output:

Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge. You don't have to be bad, Morty.
(Rick breaks up)
[Trans. Ext. Mortys home]

Great! The model is outputting new text and it looks decent. We can improve the quality of the output by adjusting the parameters of the decoding method. See Hugging Face’s post on decoding [9] for a detailed overview of different methods. Let’s replace our model() call and pass a few more arguments:

output_text = model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text']

In short, do_sample samples the next word from the probability distribution rather than always taking the most likely one, top_k restricts sampling to the k most likely next words, top_p restricts sampling to the smallest set of words whose cumulative probability exceeds p (so the size of the candidate pool grows and shrinks dynamically), and num_return_sequences controls how many independent samples are returned (in our case just 1) for further filtering or evaluation. You can play around with the values to get different types of outputs.
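
As a small standalone illustration (not part of app.py), here is how different decoding settings change what the pipeline we loaded earlier as model produces for the same prompt; the parameter values are just examples, not recommendations:

prompt = "Rick: Come on, flip the pickle, Morty."

# Greedy decoding: always pick the most likely next word (tends to repeat itself).
greedy_output = model(prompt, do_sample=False, max_length=100)[0]['generated_text']

# Top-k + nucleus (top-p) sampling: sample from a truncated distribution (more varied).
sampled_output = model(prompt, do_sample=True, top_k=50, top_p=0.95, max_length=100)[0]['generated_text']

print(greedy_output)
print(sampled_output)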

Let’s generate another output using this decoding method.

Output:

Rick: Come on, flip the pickle, Morty. You're not gonna regret it. The payoff is huge.
Morty: Ew, no, Rick! Where are you?
Rick: Morty, just do it! [laughing] Just flip the pickle!
Morty: I'm a Morty, okay?
Rick: Come on, Morty. Don't be ashamed to be a Morty. Just flip the pickle.

Our output looks better! The model still generates illogical and nonsensical text, but newer models and decoding methods may solve that problem.

Unfortunately, our model sometimes generates hurtful, vulgar, violent, or discriminatory language, as it was trained on data from the Internet. We can apply a bad-word filter that checks generated text against a list of 451 vulgar words and censors any matches. I urge the reader to consider applying further filters, for example against hate speech. The filter can be implemented as follows:

def load_bad_words() -> list:
    res_list = []

    file = urllib.request.urlopen("https://raw.githubusercontent.com/RobertJGabriel/Google-profanity-words/master/list.txt")
    for line in file:
        dline = line.decode("utf-8")
        res_list.append(dline.split("\n")[0])
    
    return res_list

BAD_WORDS = load_bad_words()
    
def filter_bad_words(text):
    explicit = False
    
    res_text = text.lower()
    for word in BAD_WORDS:
        if word in res_text:
            res_text = res_text.replace(word, word[0]+"*"*len(word[1:]))
            explicit = True

    if not explicit:
        return text

    output_text = ""
    for oword,rword in zip(text.split(" "), res_text.split(" ")):
        if oword.lower() == rword:
            output_text += oword+" "
        else:
            output_text += rword+" "

    return output_text

output_text = filter_bad_words(model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text'])

Our final app.py file now looks like this:

import urllib
import streamlit as st
import torch
from transformers import pipeline

def load_bad_words() -> list:
    res_list = []

    file = urllib.request.urlopen("https://raw.githubusercontent.com/RobertJGabriel/Google-profanity-words/master/list.txt")
    for line in file:
        dline = line.decode("utf-8")
        res_list.append(dline.split("\n")[0])
    
    return res_list

BAD_WORDS = load_bad_words()
    
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    return pipeline("text-generation", model="e-tony/gpt2-rnm")

def filter_bad_words(text):
    explicit = False
    
    res_text = text.lower()
    for word in BAD_WORDS:
        if word in res_text:
            res_text = res_text.replace(word, word[0]+"*"*len(word[1:]))
            explicit = True

    if not explicit:
        return text

    output_text = ""
    for oword,rword in zip(text.split(" "), res_text.split(" ")):
        if oword.lower() == rword:
            output_text += oword+" "
        else:
            output_text += rword+" "

    return output_text

model = load_model()
textbox = st.text_area('Start your story:', '', height=200, max_chars=1000)
slider = st.slider('Max story length (in characters)', 50, 1000)
button = st.button('Generate')

if button:
    output_text = filter_bad_words(model(textbox, do_sample=True, max_length=slider, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text'])
	
    for i, line in enumerate(output_text.split("\n")):
        if ":" in line:
            speaker, speech = line.split(':', 1)
            st.markdown(f'__{speaker}__: {speech}')
        else:
            st.markdown(line)

You can additionally check out my demo’s GitHub repository, which contains useful code for modifying the functionality and look of the application.

It is now ready to go live!

Deploying the application

The application can be deployed using Streamlit Sharing [10]. You simply need a public GitHub repository containing a requirements.txt and an app.py file. Your requirements.txt file should look something like this:

-f https://download.pytorch.org/whl/torch_stable.html
streamlit-nightly==0.69.3.dev20201025
torch==1.6.0+cpu
torchvision==0.7.0+cpu
transformers @ git+git://github.com/huggingface/transformers@59b5953d89544a66d73

On the Streamlit Sharing website you can simply link your repository, and your application will be live shortly!

Ethical considerations

The application introduced in this post is for entertainment purposes only! Applying the GPT2 model in other scenarios should be carefully considered. While certain domains were removed from the original training data, the GPT2 model was pre-trained on largely unfiltered content from the Internet, which contains biased and discriminatory language. OpenAI’s model card points out these considerations:

Here are some secondary use cases we believe are likely:

  • Writing assistance: Grammar assistance, autocompletion (for normal prose or code)
  • Creative writing and art: exploring the generation of creative, fictional texts; aiding creation of poetry and other literary art.
  • Entertainment: Creation of games, chat bots, and amusing generations.

Out-of-scope use cases:

Because large-scale language models like GPT-2 do not distinguish fact from fiction, we don’t support use-cases that require the generated text to be true. Additionally, language models like GPT-2 reflect the biases inherent to the systems they were trained on, so we do not recommend that they be deployed into systems that interact with humans unless the deployers first carry out a study of biases relevant to the intended use-case. We found no statistically significant difference in gender, race, and religious bias probes between 774M and 1.5B, implying all versions of GPT-2 should be approached with similar levels of caution around use cases that are sensitive to biases around human attributes.

The example below shows how the model can produce biased predictions (another example can be found here):

>>> from transformers import pipeline, set_seed
>>> generator = pipeline('text-generation', model='gpt2')
>>> set_seed(42)
>>> generator("The man worked as a", max_length=10, num_return_sequences=5)

[{'generated_text': 'The man worked as a waiter at a Japanese restaurant'},
 {'generated_text': 'The man worked as a bouncer and a boun'},
 {'generated_text': 'The man worked as a lawyer at the local firm'},
 {'generated_text': 'The man worked as a waiter in a cafe near'},
 {'generated_text': 'The man worked as a chef in a strip mall'}]

>>> set_seed(42)
>>> generator("The woman worked as a", max_length=10, num_return_sequences=5)

[{'generated_text': 'The woman worked as a waitress at a Japanese restaurant'},
 {'generated_text': 'The woman worked as a waitress at a local restaurant'},
 {'generated_text': 'The woman worked as a waitress at the local supermarket'},
 {'generated_text': 'The woman worked as a nurse in a health center'},
 {'generated_text': 'The woman worked as a maid in Daphne'}]

I urge the reader to carefully consider the application and use of such models in real-world scenarios. There are many resources (e.g. EML [11], AINow [12]) for learning about ethical ML.

Conclusion

Congratulations! Your application is now live!

By using open-source frameworks, we were able to quickly fine-tune a GPT2 model, prototype a fun application, and deploy it. The generated stories can be further improved by using more advanced pre-trained models, decoding methods, or even structured language prediction.
