Spaces:
Build error
Build error
| import tempfile | |
| from share_btn import community_icon_html, loading_icon_html, share_js, save_js | |
| import huggingface_hub | |
| import gradio as gr | |
| from fromage import utils | |
| from fromage import models | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False" | |
| css = """ | |
| #share-btn-container { | |
| display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; | |
| margin-top: 3px; | |
| margin-left: auto; | |
| flex: unset; | |
| } | |
| #share-btn { | |
| all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; | |
| } | |
| #share-btn * { | |
| all: unset; | |
| } | |
| #share-btn-container div:nth-child(-n+2){ | |
| width: auto !important; | |
| min-height: 0px !important; | |
| } | |
| #share-btn-container .wrap { | |
| display: none !important; | |
| } | |
| #chatbot { min-height: 300px; } | |
| #save-btn { | |
| background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); | |
| } | |
| #save-btn:hover { | |
| background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0)); | |
| } | |
| #share-btn-2 { | |
| background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); | |
| } | |
| #share-btn-2:hover { | |
| background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0)); | |
| } | |
| """ | |
| examples = [ | |
| 'examples/sparrow.png', | |
| 'examples/beaver.png', | |
| 'examples/couch.png', | |
| 'examples/guac.png', | |
| 'examples/scraped_knee.png' | |
| ] | |
| # Download model from HF Hub. | |
| ckpt_path = huggingface_hub.hf_hub_download( | |
| repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar') | |
| args_path = huggingface_hub.hf_hub_download( | |
| repo_id='jykoh/fromage', filename='model_args.json') | |
| model = models.load_fromage('./', args_path, ckpt_path) | |
| def upload_image(state, image_input): | |
| conversation = state[0] | |
| chat_history = state[1] | |
| input_image = Image.open(image_input.name).resize( | |
| (224, 224)).convert('RGB') | |
| input_image.save(image_input.name) # Overwrite with smaller image. | |
| conversation += [(f'<img src="/file={image_input.name}" style="display: inline-block;">', "")] | |
| return [conversation, chat_history + [input_image, ""]], conversation | |
| def reset(): | |
| return [[], []], [] | |
| def reset_last(state): | |
| conversation = state[0][:-1] | |
| chat_history = state[1][:-2] | |
| return [conversation, chat_history], conversation | |
| def save_image_to_local(image: Image.Image): | |
| # TODO(jykoh): Update so the url path is used, to prevent repeat saving. | |
| filename = next(tempfile._get_candidate_names()) + '.png' | |
| image.save(filename) | |
| return filename | |
| def generate_for_prompt(input_text, state, ret_scale_factor, max_num_rets, num_words, temperature): | |
| # Ignore empty inputs. | |
| if len(input_text) == 0: | |
| return state, state[0], gr.update(visible=True) | |
| input_prompt = 'Q: ' + input_text + '\nA:' | |
| conversation = state[0] | |
| chat_history = state[1] | |
| print('Generating for', chat_history, flush=True) | |
| # If an image was uploaded, prepend it to the model. | |
| model_inputs = chat_history | |
| model_inputs.append(input_prompt) | |
| top_p = 1.0 | |
| if temperature != 0.0: | |
| top_p = 0.95 | |
| print('Running model.generate_for_images_and_texts with', | |
| model_inputs, flush=True) | |
| model_outputs = model.generate_for_images_and_texts(model_inputs, | |
| num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p, | |
| temperature=temperature, max_num_rets=max_num_rets) | |
| print('model_outputs', model_outputs, flush=True) | |
| im_names = [] | |
| response = '' | |
| text_outputs = [] | |
| for output_i, output in enumerate(model_outputs): | |
| if type(output) == str: | |
| if output_i > 0: | |
| response += '<br/>' | |
| text_outputs.append(output) | |
| response += output | |
| if len(model_outputs) > 1: | |
| response += '<br/>' | |
| elif type(output) == list: | |
| for image in output: | |
| filename = save_image_to_local(image) | |
| response += f'<img src="/file={filename}" style="display: inline-block;">' | |
| elif type(output) == Image.Image: | |
| filename = save_image_to_local(output) | |
| response += f'<img src="/file={filename}" style="display: inline-block;">' | |
| # TODO(jykoh): Persist image inputs. | |
| chat_history = model_inputs + \ | |
| [' '.join([s for s in model_outputs if type(s) == str]) + '\n'] | |
| # Remove [RET] from outputs. | |
| conversation.append((input_text, response.replace('[RET]', ''))) | |
| # Set input image to None. | |
| print('state', state, flush=True) | |
| print('updated state', [conversation, chat_history], flush=True) | |
| return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True) | |
| with gr.Blocks(css=css) as demo: | |
| gr.HTML(""" | |
| <h1>🧀 FROMAGe</h1> | |
| <p>This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p> | |
| <strong>Paper:</strong> <a href="https://arxiv.org/abs/2301.13823" target="_blank">Grounding Language Models to Images for Multimodal Generation</a> | |
| <br/> | |
| <strong>Project Website:</strong> <a href="https://jykoh.com/fromage" target="_blank">FROMAGe Website</a> | |
| <br/> | |
| <strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/fromage" target="_blank">GitHub</a> | |
| <br/> | |
| <br/> | |
| <strong>Tips:</strong> | |
| <ul> | |
| <li>Start by inputting either image or text prompts (or both) and chat with FROMAGe to get image-and-text replies.</li> | |
| <li>Tweak the level of sensitivity to images and text using the parameters on the right.</li> | |
| <li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li> | |
| <li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/fromage?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></li> | |
| </ul> | |
| """) | |
| gr_state = gr.State([[], []]) # conversation, chat_history | |
| with gr.Row(): | |
| with gr.Column(scale=0.7, min_width=500): | |
| with gr.Row(): | |
| chatbot = gr.Chatbot(elem_id="chatbot", label="🧀 FROMAGe Chatbot") | |
| with gr.Row(): | |
| image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"]) | |
| text_input = gr.Textbox(label="Message", placeholder="Type a message") | |
| with gr.Column(): | |
| submit_btn = gr.Button( | |
| "Submit", interactive=True, variant="primary") | |
| clear_last_btn = gr.Button("Undo") | |
| clear_btn = gr.Button("Reset All") | |
| with gr.Row(visible=False) as save_group: | |
| save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn") | |
| with gr.Row(visible=False) as share_group: | |
| share_button = gr.Button("🤗 Share to Community", elem_id="share-btn-2") | |
| with gr.Column(scale=0.3, min_width=200): | |
| ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, | |
| label="Frequency multiplier for returning images (higher means more frequent)") | |
| max_ret_images = gr.Number( | |
| minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return") | |
| gr_max_len = gr.Slider(minimum=1, maximum=64, value=32, | |
| step=1, interactive=True, label="Max # of words") | |
| gr_temperature = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)") | |
| gallery = gr.Gallery( | |
| value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery", | |
| ).style(grid=[2], height="auto") | |
| text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, | |
| max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group]) | |
| text_input.submit(lambda: "", None, text_input) # Reset chatbox. | |
| submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor, | |
| max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group]) | |
| submit_btn.click(lambda: "", None, text_input) # Reset chatbox. | |
| image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot]) | |
| clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot]) | |
| clear_btn.click(reset, [], [gr_state, chatbot]) | |
| share_button.click(None, [], [], _js=share_js) | |
| save_button.click(None, [], [], _js=save_js) | |
| demo.queue(concurrency_count=1, api_open=False, max_size=16) | |
| demo.launch(debug=True, server_name="0.0.0.0") | |
| # demo.launch(debug=True, server_name="127.0.0.1") | |