charliebaby2023's picture
Update app.py
8d3f8cd verified
raw
history blame
4.13 kB
import gradio as gr
from random import randint
from all_models import models_groups
from datetime import datetime
# Default prompt text: pre-fills the prompt textbox built in make_me() and is
# the prefix of the string returned by get_current_time().
kii = "mohawk femboy racecar driver"
# NOTE(review): no model arrays are defined in this file; model names appear
# to come from the `all_models` import above — confirm.
def get_current_time():
    """Return the default prompt followed by the current timestamp.

    The timestamp comes from ``datetime.now()`` and is rendered as
    ``YYYY-MM-DD HH:MM:SS``; a single space separates it from ``kii``.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return f'{kii} {current_time}'
def load_fn(models):
    """Load every distinct model name into a callable Gradio interface.

    Args:
        models: Iterable of model identifiers. Each one is loaded from
            ``models/<identifier>`` via ``gr.load``.

    Returns:
        dict: Maps each distinct identifier to its loaded interface. When
        loading raises, the identifier maps to a placeholder
        ``gr.Interface`` whose function always returns ``None``, so every
        entry can be called the same way.
    """
    models_load = {}
    for model in models:
        if model in models_load:
            # Duplicate name: already loaded (or already replaced by the
            # placeholder), nothing more to do.
            continue
        try:
            m = gr.load(f'models/{model}')  # Adjust `gr.load` as needed
        except Exception as error:
            # Stay best-effort (one bad model must not abort start-up), but
            # report the failure instead of swallowing it silently.
            print(f"Error loading model {model}: {error}")
            m = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[model] = m
    return models_load
# `models` is not defined or imported anywhere in this file (only
# `models_groups` is), so the previous `load_fn(models)` raised NameError at
# import time. Load the models of every group instead; load_fn() skips
# duplicates.
# NOTE(review): assumes `models_groups` maps a group name to a list of model
# names, as the group lookups in make_me() suggest — confirm against
# `all_models`.
models_load = load_fn(
    [model for group in models_groups.values() for model in group]
)
def extend_choices(choices, num_models):
    """Pad ``choices`` with ``'NA'`` placeholders up to ``num_models`` entries.

    A new list is returned. When ``choices`` already holds ``num_models`` or
    more entries, nothing is appended and nothing is truncated.
    """
    shortfall = num_models - len(choices)
    # A zero or negative repeat count yields an empty padding list.
    padding = ['NA'] * shortfall
    return choices + padding
def update_imgbox(choices, num_models):
    """Build one empty image component per slot for the given choices.

    ``choices`` is padded to ``num_models`` entries with ``'NA'``; each entry
    becomes the label of a ``gr.Image`` with no value, and the ``'NA'``
    padding entries are created invisible.
    """
    image_boxes = []
    for label in extend_choices(choices, num_models):
        is_real_model = label != 'NA'
        image_boxes.append(gr.Image(None, label=label, visible=is_real_model))
    return image_boxes
def gen_fn(model_str, prompt, tallies):
    """Generate one image with ``model_str`` and count successful results.

    Args:
        model_str: Key into the module-level ``models_load`` mapping, or the
            placeholder ``'NA'`` for an unused slot.
        prompt: User prompt. The model name and a random integer in
            ``0..9999`` are appended before it is sent to the model.
        tallies: Dict of per-model success counts; updated in place.

    Returns:
        tuple: ``(result, tallies)``. ``result`` is ``None`` for ``'NA'``,
        for an unknown model, or when generation raised.
    """
    if model_str == 'NA':
        return None, tallies
    noise = str(randint(0, 9999))
    combined_prompt = f'{prompt} {model_str} {noise}'
    print(f"Generating with prompt: {combined_prompt}")  # Debug line
    try:
        # Unknown models fall back to a callable that yields no image.
        generator = models_load.get(model_str, lambda txt: None)
        result = generator(combined_prompt)
    except Exception as e:
        print(f"Error generating for {model_str}: {e}")
        return None, tallies
    if result is not None:
        # Counted outside the try block and with a default of 0: a model
        # missing from `tallies` used to raise KeyError here, which was then
        # reported as a generation error and discarded the finished image.
        tallies[model_str] = tallies.get(model_str, 0) + 1
    return result, tallies
def make_me():
    """Build the prompt controls and per-model output slots, and wire events.

    Intended to be called inside an open ``gr.Blocks`` context so that the
    components created here are attached to that page.

    Layout: one row with the model-group dropdown, the prompt textbox and the
    Generate/Stop buttons, then one row holding an image and a tally textbox
    per model slot. The slots are created once, sized for the largest group;
    selecting a group relabels them and hides the ones it does not need.

    NOTE(review): assumes ``models_groups`` maps a group name to a list of
    model names — confirm against ``all_models``.
    """
    # The imported mapping is `models_groups`; the former `model_groups`
    # spelling was an undefined name.
    group_names = list(models_groups.keys())
    # Keep the historical "Group A" default when that group exists; otherwise
    # fall back to the first group so the dropdown never starts on a value
    # that is not among its choices.
    if "Group A" in group_names:
        default_group = "Group A"
    else:
        default_group = group_names[0] if group_names else None
    default_models = list(models_groups.get(default_group, []))
    slot_count = max((len(group) for group in models_groups.values()), default=0)

    def slot_names(selected_models):
        # Pad with 'NA' so there is exactly one name per pre-built slot.
        return extend_choices(list(selected_models), slot_count)

    with gr.Row():
        # Input elements
        model_group_selector = gr.Dropdown(
            choices=group_names, label="Select Model Group", value=default_group
        )
        txt_input = gr.Textbox(label='Your prompt:', lines=2, value=kii)
        gen_button = gr.Button('Generate images', elem_id="generate-btn")
        # Rendered disabled; no event is attached to it.
        gr.Button('Stop', variant='secondary', interactive=False)
        gr.HTML(
            '\n<div style="text-align: center; max-width: 100%; margin: 0 auto;">\n'
            '<body></body>\n'
            '</div>\n'
        )

    with gr.Row():
        # Output elements
        # Per-model success counts for the currently selected group.
        tally_boxes = gr.State({model: 0 for model in default_models})
        with gr.Column():
            # Components are created here, at build time, rather than inside
            # an event handler, so they are part of the page layout.
            result_images = []
            tally_counters = []
            for name in slot_names(default_models):
                in_use = name != 'NA'
                result_images.append(
                    gr.Image(label=name, width=170, height=170, visible=in_use)
                )
                tally_counters.append(
                    gr.Textbox(
                        value="0",
                        label=f"Tally for {name}",
                        interactive=False,
                        visible=in_use,
                    )
                )

    slot_components = [*result_images, *tally_counters]

    def update_outputs(group_name):
        """Reset the tallies and relabel the slots for the selected group."""
        selected_models = list(models_groups.get(group_name, []))
        names = slot_names(selected_models)
        image_updates = [
            gr.update(value=None, label=name, visible=(name != 'NA'))
            for name in names
        ]
        tally_updates = [
            gr.update(value="0", label=f"Tally for {name}", visible=(name != 'NA'))
            for name in names
        ]
        # One return value per declared output component, in the same order.
        return [{model: 0 for model in selected_models}, *image_updates, *tally_updates]

    model_group_selector.change(
        update_outputs, [model_group_selector], [tally_boxes, *slot_components]
    )

    def generate_images(prompt, tallies):
        """Run the prompt through every model of the current group."""
        names = slot_names(tallies)
        images = []
        for name in names:
            # gen_fn returns (None, tallies) for the 'NA' padding slots.
            result, tallies = gen_fn(name, prompt, tallies)
            images.append(result)
        counts = [str(tallies.get(name, 0)) for name in names]
        # Results reach the page as return values; calling `.update()` on a
        # component instance and dropping the result changes nothing.
        return [tallies, *images, *counts]

    gen_button.click(
        generate_images,
        inputs=[txt_input, tally_boxes],
        outputs=[tally_boxes, *slot_components],
    )
# Front-end snippet meant to keep the page from scrolling while a toast is
# shown: showing a toast replaces `window.scrollTo` with a no-op, and the
# original function is put back 3 seconds later.
# The restore timer is started inside the toast hook and uses 3000 ms; it used
# to run once at load time with a 3 ms delay, so scrolling was never restored
# after a toast.
# NOTE(review): relies on a global `gradio.Toast.show` existing in the page —
# confirm for the Gradio version in use.
js_code = """
<script>
const originalScroll = window.scrollTo;
const originalShowToast = gradio.Toast.show;
gradio.Toast.show = function() {
    originalShowToast.apply(this, arguments);
    window.scrollTo = function() {};
    setTimeout(() => {
        window.scrollTo = originalScroll;
    }, 3000); // Restore scroll function after 3 seconds
};
</script>
"""
# Assemble the page: a placeholder div, the controls built by make_me(), and
# the scroll-handling snippet, all inside a single Blocks container.
# NOTE(review): js_code is a <script> element passed through gr.Markdown —
# confirm the browser actually executes it when injected this way.
demo = gr.Blocks()
with demo:
    gr.Markdown("<div></div>")
    make_me()
    gr.Markdown(js_code)
# Configure the request queue, then start serving.
# NOTE(review): `concurrency_count` is presumably tied to a specific Gradio
# release line — verify against the installed version.
demo.queue(concurrency_count=50).launch()