File size: 5,092 Bytes
44466c7
 
 
 
 
 
 
 
 
7d9af7f
44466c7
 
7d9af7f
44466c7
 
 
 
 
 
 
7d9af7f
 
44466c7
7d9af7f
 
 
44466c7
7d9af7f
 
 
 
525f2d6
7d9af7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bf893a
 
 
 
7d9af7f
525f2d6
 
 
7d9af7f
525f2d6
7bcf67d
525f2d6
7d9af7f
525f2d6
7d9af7f
525f2d6
7d9af7f
7bcf67d
525f2d6
7d9af7f
dc084c4
525f2d6
7d9af7f
 
525f2d6
7d9af7f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import hmac
import os
from random import sample

import pandas as pd
import streamlit as st
from datasets import load_dataset

from utils.metric import Regard
from utils.model import gpt2

# Set up the Streamlit interface. Keep set_page_config ahead of any other
# Streamlit call: older Streamlit versions require it to be the first one.
# The icon was previously stored as mis-decoded bytes ("πŸ”"); it is the
# magnifying-glass emoji.
st.set_page_config(page_title="Gender Bias Analysis", page_icon="🔍", layout="wide")
st.title('Gender Bias Analysis in Text Generation')

# Password protection function
def check_password():
    def password_entered():
        if password_input == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect Password, please try again.")

    if 'password_correct' not in st.session_state:
        st.session_state['password_correct'] = False

    if not st.session_state['password_correct']:
        password_input = st.text_input("Enter Password:", type="password")
        st.button("Submit", on_click=password_entered)

# Data loading function
def load_data():
    """Load the BOLD ``train`` split into ``st.session_state['bold']`` once per session."""
    if 'bold' in st.session_state:
        # Already loaded on an earlier rerun; skip the expensive reload.
        return
    st.session_state['bold'] = load_dataset("AlexaAI/bold", split="train")

# Sampling function
def sample_data(data_size):
    st.session_state['female_bold'] = sample(
        [p for p in st.session_state['bold'] if p['category'] == 'American_actresses'], data_size)
    st.session_state['male_bold'] = sample(
        [p for p in st.session_state['bold'] if p['category'] == 'American_actors'], data_size)

# Text generation function
def generate_text():
    GPT2 = gpt2()
    st.session_state['male_prompts'] = [p['prompts'][0] for p in st.session_state['male_bold']]
    st.session_state['female_prompts'] = [p['prompts'][0] for p in st.session_state['female_bold']]

    progress_bar = st.progress(0)
    st.write('Generating text for male prompts...')
    male_generation = GPT2.text_generation(st.session_state['male_prompts'], pad_token_id=50256, max_length=50,
                                           do_sample=False, truncation=True)
    st.session_state['male_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
                                              zip(male_generation, st.session_state['male_prompts'])]

    progress_bar.progress(50)

    st.write('Generating text for female prompts...')
    female_generation = GPT2.text_generation(st.session_state['female_prompts'], pad_token_id=50256,
                                             max_length=50, do_sample=False, truncation=True)
    st.session_state['female_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
                                                zip(female_generation, st.session_state['female_prompts'])]

    progress_bar.progress(100)
    st.write('Text generation completed.')

# Display data samples function
def display_samples():
    st.write("### Male Data Samples")
    samples_df = pd.DataFrame({
        'Male Prompt': st.session_state['male_prompts'],
        'Male Continuation': st.session_state['male_continuations'],
    })
    st.dataframe(samples_df)

    st.write("### Female Data Samples")
    samples_df = pd.DataFrame({
        'Female Prompt': st.session_state['female_prompts'],
        'Female Continuation': st.session_state['female_continuations']
    })
    st.dataframe(samples_df)

# Evaluate regard function
def evaluate_regard():
    """Compare male and female continuations with the regard metric and display the results.

    Renders two JSON panels: the raw output of ``Regard("compare").compute``
    and its output with ``aggregation='average'``.
    """
    scorer = Regard("compare")
    st.write('Computing regard results to compare male and female continuations...')

    reports = (
        ('**Raw Regard Results:**', {}),
        ('**Average Regard Results:**', {'aggregation': 'average'}),
    )
    with st.spinner('Computing regard results...'):
        for heading, extra in reports:
            # Male continuations are passed as the data, female ones as the references.
            outcome = scorer.compute(data=st.session_state['male_continuations'],
                                     references=st.session_state['female_continuations'],
                                     **extra)
            st.write(heading)
            st.json(outcome)

def _generation_matches_sample():
    """Return True when the stored continuations belong to the current samples.

    Clicking "Show Data" again replaces ``male_bold`` / ``female_bold`` but
    leaves the previous prompts and continuations in session state. Those
    continuations must not be offered for evaluation against the new sample.
    """
    state = st.session_state
    for prefix in ('male', 'female'):
        rows = state.get(f'{prefix}_bold')
        if not rows or not state.get(f'{prefix}_continuations'):
            return False
        if state.get(f'{prefix}_prompts') != [row['prompts'][0] for row in rows]:
            return False
    return True


# Main app logic
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")
    load_data()

    st.subheader('Step 1: Set Data Size')
    data_size = st.slider('Select number of samples per category:', min_value=1, max_value=50,
                          value=st.session_state.get('data_size', 10))
    st.session_state['data_size'] = data_size

    if st.button('Show Data'):
        sample_data(data_size)
        st.write(f'Sampled {data_size} female and male American actors.')
        display_samples()

    if st.session_state.get('female_bold') and st.session_state.get('male_bold'):
        st.subheader('Step 2: Generate Text')
        if st.button('Generate Text'):
            generate_text()

    # Only offer evaluation for continuations generated from the current sample.
    if _generation_matches_sample():
        st.subheader('Step 3: Evaluate')
        display_samples()
        if st.button('Evaluate'):
            evaluate_regard()