# File size: 5,092 Bytes
# 44466c7 7d9af7f 44466c7 7d9af7f 44466c7 7d9af7f 44466c7 7d9af7f 44466c7 7d9af7f 525f2d6 7d9af7f 1bf893a 7d9af7f 525f2d6 7d9af7f 525f2d6 7bcf67d 525f2d6 7d9af7f 525f2d6 7d9af7f 525f2d6 7d9af7f 7bcf67d 525f2d6 7d9af7f dc084c4 525f2d6 7d9af7f 525f2d6 7d9af7f |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import streamlit as st
import pandas as pd
from datasets import load_dataset
from random import sample
from utils.metric import Regard
from utils.model import gpt2
import os
# Set up the Streamlit interface: browser-tab title/icon, wide layout, and the
# page heading. The icon was a mis-encoded emoji ("π" is residue of a 4-byte
# UTF-8 emoji); any valid single emoji works here.
st.set_page_config(page_title="Gender Bias Analysis", page_icon="📊", layout="wide")
st.title('Gender Bias Analysis in Text Generation')
def check_password():
    """Render a password prompt until the user has authenticated.

    The expected password comes from the ``PASSWORD`` environment variable. If
    it is unset, no entered value is accepted. A correct entry sets
    ``st.session_state['password_correct']`` to True, which the main app logic
    checks on the following rerun.

    Returns:
        bool: the current value of ``st.session_state['password_correct']``.
    """
    import hmac

    def password_entered():
        # Streamlit runs on_click callbacks at the start of the *next* rerun,
        # before the script body. Read the submitted value through the widget
        # key rather than a variable captured during the previous run, which
        # can hold a stale value.
        entered = st.session_state.get('password_input', '')
        expected = os.getenv('PASSWORD')
        # Compare bytes so non-ASCII input cannot raise TypeError, and use a
        # constant-time comparison to avoid leaking timing information.
        if expected is not None and hmac.compare_digest(entered.encode(), expected.encode()):
            st.session_state['password_correct'] = True
            # Don't keep the plaintext password in session state.
            del st.session_state['password_input']
        else:
            st.error("Incorrect Password, please try again.")

    if 'password_correct' not in st.session_state:
        st.session_state['password_correct'] = False
    if not st.session_state['password_correct']:
        st.text_input("Enter Password:", type="password", key='password_input')
        st.button("Submit", on_click=password_entered)
    return st.session_state['password_correct']
def load_data():
    """Load the BOLD ``train`` split into ``st.session_state['bold']`` once per session."""
    # Already loaded on an earlier rerun: nothing to do.
    if 'bold' in st.session_state:
        return
    st.session_state['bold'] = load_dataset("AlexaAI/bold", split="train")
def sample_data(data_size):
    """Randomly sample BOLD records for American actresses and actors.

    Stores the samples in ``st.session_state['female_bold']`` and
    ``st.session_state['male_bold']``. Any prompts or continuations produced
    from a previous sample are discarded, so later steps never pair a new
    sample with stale generations.

    Args:
        data_size: requested number of records per category. It is clamped to
            the number of records available in that category, because
            ``random.sample`` raises ValueError when asked for more.
    """
    # Single pass over the dataset, bucketing only the two categories we use.
    pools = {'American_actresses': [], 'American_actors': []}
    for record in st.session_state['bold']:
        bucket = pools.get(record['category'])
        if bucket is not None:
            bucket.append(record)

    actresses = pools['American_actresses']
    actors = pools['American_actors']
    st.session_state['female_bold'] = sample(actresses, min(data_size, len(actresses)))
    st.session_state['male_bold'] = sample(actors, min(data_size, len(actors)))

    # Invalidate outputs computed from the previous sample.
    for stale_key in ('male_prompts', 'female_prompts',
                      'male_continuations', 'female_continuations'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
def _strip_prompt(text, prompt):
    """Return *text* with one leading copy of *prompt* removed."""
    # Only the echoed prefix is removed. Greedy decoding often repeats the
    # prompt later in the output, and those repeats belong to the continuation.
    if text.startswith(prompt):
        return text[len(prompt):]
    return text


def _generate_continuations(generator, prompts):
    """Run greedy generation on *prompts* and return only the continuation text."""
    # 50256 is GPT-2's end-of-text token id, passed here as the pad token id.
    outputs = generator.text_generation(prompts, pad_token_id=50256, max_length=50,
                                        do_sample=False, truncation=True)
    return [_strip_prompt(output[0]['generated_text'], prompt)
            for output, prompt in zip(outputs, prompts)]


def generate_text():
    """Generate GPT-2 continuations for the sampled male and female prompts.

    Uses the first prompt of each record in ``st.session_state['male_bold']``
    and ``['female_bold']``. Stores the prompts under ``male_prompts`` /
    ``female_prompts`` and the continuations (prompt prefix removed) under
    ``male_continuations`` / ``female_continuations``. Shows a progress bar
    while generating.
    """
    GPT2 = gpt2()
    male_prompts = [p['prompts'][0] for p in st.session_state['male_bold']]
    female_prompts = [p['prompts'][0] for p in st.session_state['female_bold']]
    st.session_state['male_prompts'] = male_prompts
    st.session_state['female_prompts'] = female_prompts

    progress_bar = st.progress(0)
    st.write('Generating text for male prompts...')
    st.session_state['male_continuations'] = _generate_continuations(GPT2, male_prompts)
    progress_bar.progress(50)
    st.write('Generating text for female prompts...')
    st.session_state['female_continuations'] = _generate_continuations(GPT2, female_prompts)
    progress_bar.progress(100)
    st.write('Text generation completed.')
def display_samples():
    """Show the current male and female samples as tables.

    Prompts are derived from the currently sampled BOLD records
    (``st.session_state['male_bold']`` / ``['female_bold']``), so this works
    right after sampling, before any text has been generated. The continuation
    column is added only when continuations exist for exactly these prompts.
    This keeps stale generations from being paired with a new sample, and
    avoids the length mismatch that makes ``pd.DataFrame`` raise.
    """
    for label in ('Male', 'Female'):
        group = label.lower()
        prompts = [p['prompts'][0] for p in st.session_state.get(f'{group}_bold', [])]
        columns = {f'{label} Prompt': prompts}
        continuations = st.session_state.get(f'{group}_continuations')
        if continuations is not None and st.session_state.get(f'{group}_prompts') == prompts:
            columns[f'{label} Continuation'] = continuations
        st.write(f"### {label} Data Samples")
        st.dataframe(pd.DataFrame(columns))
def evaluate_regard():
    """Compare regard scores of the male continuations against the female ones.

    The male continuations are passed as ``data`` and the female continuations
    as ``references`` to a ``Regard("compare")`` scorer. The raw result and the
    ``aggregation='average'`` result are each rendered as JSON.
    """
    scorer = Regard("compare")
    inputs = {
        'data': st.session_state['male_continuations'],
        'references': st.session_state['female_continuations'],
    }
    st.write('Computing regard results to compare male and female continuations...')
    with st.spinner('Computing regard results...'):
        raw = scorer.compute(**inputs)
    st.write('**Raw Regard Results:**')
    st.json(raw)
    averaged = scorer.compute(**inputs, aggregation='average')
    st.write('**Average Regard Results:**')
    st.json(averaged)
# Main app logic: gate the demo behind the password, then walk through
# Step 1 (sample), Step 2 (generate) and Step 3 (evaluate).
is_authenticated = st.session_state.get('password_correct', False)
if is_authenticated:
    st.sidebar.success("Password Verified. Proceed with the demo.")
    load_data()

    # Step 1: pick how many prompts to sample per category. The last choice is
    # kept in session state so it survives reruns.
    st.subheader('Step 1: Set Data Size')
    remembered_size = st.session_state.get('data_size', 10)
    data_size = st.slider(
        'Select number of samples per category:',
        min_value=1,
        max_value=50,
        value=remembered_size,
    )
    st.session_state['data_size'] = data_size
    show_clicked = st.button('Show Data')
    if show_clicked:
        sample_data(data_size)
        st.write(f'Sampled {data_size} female and male American actors.')
        display_samples()

    # Step 2 appears only once both categories have been sampled.
    samples_ready = (bool(st.session_state.get('female_bold'))
                     and bool(st.session_state.get('male_bold')))
    if samples_ready:
        st.subheader('Step 2: Generate Text')
        if st.button('Generate Text'):
            generate_text()

    # Step 3 appears only once continuations exist for both groups.
    continuations_ready = (bool(st.session_state.get('male_continuations'))
                           and bool(st.session_state.get('female_continuations')))
    if continuations_ready:
        st.subheader('Step 3: Evaluate')
        display_samples()
        if st.button('Evaluate'):
            evaluate_regard()
else:
    check_password()
# |