"""Gradio demo: multimodal chat with LLaVA-1.5 (4-bit) and an optional "Arnold Mode" persona."""

import copy
import os
import random
import re
import string
import time

import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline

DESCRIPTION = "# LLaVA 🌋💪 - Now with Arnold Mode!"

model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)

# Role markers of the LLaVA-1.5 conversation template ("USER: ... ASSISTANT: ...").
_TURN_MARKER_RE = re.compile(r"(USER:|ASSISTANT:)")
# Placeholder the LLaVA-1.5 prompt template uses to mark where the image goes
# ("USER: <image>\n<prompt>\nASSISTANT:" per the llava-hf model card).
IMAGE_TOKEN = "<image>"

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Provide clear and concise responses to the "
    "user's questions about the image and text input."
)
ARNOLD_SYSTEM_PROMPT = (
    "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his "
    "iconic style, using his catchphrases and focusing on fitness and motivation."
)

ARNOLD_PHRASES = (
    "Come with me if you want to lift!",
    "I'll be back... after my protein shake.",
    "Hasta la vista, baby weight!",
    "Get to da choppa... I mean, da squat rack!",
    "You lack discipline! But don't worry, I'm here to pump you up!",
)
# Whole-word swaps only, so e.g. "gymnastics" is not turned into "iron paradisenastics".
_ARNOLD_WORD_SWAPS = (
    (re.compile(r"\bgym\b"), "iron paradise"),
    (re.compile(r"\bexercise\b"), "pump iron"),
    (re.compile(r"\bworkout\b"), "sculpt your physique"),
)
# A "." that is NOT a decimal point, i.e. not sitting between two digits ("2.5" stays intact).
_SENTENCE_PERIOD_RE = re.compile(r"(?<!\d)\.|\.(?!\d)")


def extract_response_pairs(text):
    """Parse a LLaVA-style transcript into ``[user, assistant]`` pairs.

    Text before the first marker (e.g. a system prompt) is ignored. Each
    ``USER:`` turn is paired with the ``ASSISTANT:`` turn that follows it; an
    assistant turn without a preceding user turn is skipped, and a user turn
    with no answer yet is not emitted. Empty assistant replies are kept as
    ``""`` so the pairing never shifts. ``<image>`` placeholders are removed.

    Returns:
        A list of ``[user_text, assistant_text]`` lists (possibly empty).
    """
    # re.split with a capturing group yields [prefix, marker, content, marker, content, ...].
    parts = _TURN_MARKER_RE.split(text)[1:]
    pairs = []
    pending_user = None
    for marker, content in zip(parts[0::2], parts[1::2]):
        content = content.replace(IMAGE_TOKEN, "").strip()
        if marker == "USER:":
            pending_user = content
        elif pending_user is not None:
            pairs.append([pending_user, content])
            pending_user = None
    return pairs


def add_text(history, text):
    """Append the user's message to the chat as a pending ``[text, None]`` turn.

    Blank messages are not added. *text* is returned unchanged so the textbox
    keeps it and ``bot`` can read it as its ``text_input``.
    """
    if not text or not text.strip():
        return history, text
    return history + [[text, None]], text


def arnold_speak(text):
    """Rewrite *text* in an "Arnold" voice.

    Sentence periods become "!" (decimal points are left alone), a few fitness
    words are swapped whole-word, and a random catchphrase is appended.
    """
    text = _SENTENCE_PERIOD_RE.sub("!", text)
    for pattern, replacement in _ARNOLD_WORD_SWAPS:
        text = pattern.sub(replacement, text)
    phrase = random.choice(ARNOLD_PHRASES)
    return f"{text} {phrase}" if text else phrase


def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    """Run the LLaVA pipeline on *image* and *prompt* and return the generated text.

    ``max_length``/``min_length`` bound the number of NEW tokens. ``generate``'s
    own ``max_length``/``min_length`` also count the prompt tokens, and this prompt
    carries the system prompt plus the whole chat history, so using them directly
    truncates replies as the conversation grows.

    NOTE(review): ``temperature`` and ``top_p`` only take effect when sampling
    (``do_sample=True``), which is not enabled here; confirm whether it should be.
    """
    max_new_tokens = int(max_length)
    generate_kwargs = {
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        # A minimum above the maximum is meaningless; clamp it.
        "min_new_tokens": min(int(min_length), max_new_tokens),
        "top_p": top_p,
    }
    outputs = pipe(images=image, prompt=prompt, generate_kwargs=generate_kwargs)
    return outputs[0]["generated_text"]


def _completed_turns(history):
    """Return *history* without a trailing pending ``[text, None]`` turn."""
    if history and history[-1][1] is None:
        return history[:-1]
    return history


def _build_prompt(system_prompt, completed_turns, text_input):
    """Render the system prompt, finished turns and new message in the LLaVA-1.5 template.

    The image placeholder goes into the first user turn only; the prompt ends
    with an open ``ASSISTANT:`` for the model to complete.
    """
    parts = [system_prompt]
    messages = [list(turn) for turn in completed_turns] + [[text_input, None]]
    for index, (user_msg, bot_msg) in enumerate(messages):
        image_prefix = f"{IMAGE_TOKEN}\n" if index == 0 else ""
        parts.append(f"USER: {image_prefix}{user_msg}")
        parts.append("ASSISTANT:" if bot_msg is None else f"ASSISTANT: {bot_msg}")
    return "\n".join(parts)


def _extract_reply(generated_text, turn_index):
    """Pick the assistant reply for conversation turn *turn_index* out of the pipeline output."""
    pairs = extract_response_pairs(generated_text)
    if len(pairs) > turn_index:
        return pairs[turn_index][1]
    if pairs:
        return pairs[-1][1]
    # No role markers at all: assume the output is the bare reply.
    return generated_text.strip()


def bot(history_chat, text_input, image, temperature, length_penalty, repetition_penalty,
        max_length, min_length, top_p, arnold_mode=False):
    """Generate the assistant's answer and stream it into the chat character by character.

    Args:
        history_chat: The chatbot value: ``[user, assistant]`` pairs, possibly
            ending with the pending ``[text, None]`` turn added by ``add_text``.
        text_input: The new user message.
        image: The uploaded image (``None`` if missing).
        arnold_mode: When true, use the Arnold persona and post-process the reply
            with ``arnold_speak``.

    Yields:
        The updated chat history, with the last reply growing one character at a time.
    """
    completed = _completed_turns(history_chat or [])
    if not text_input or not text_input.strip():
        gr.Warning("Please input text")
        yield completed
        return
    if image is None:
        gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
        # Drop the pending turn; the text stays in the textbox for a retry.
        yield completed
        return

    system_prompt = ARNOLD_SYSTEM_PROMPT if arnold_mode else DEFAULT_SYSTEM_PROMPT
    prompt = _build_prompt(system_prompt, completed, text_input)
    inference_result = infer(image, prompt, temperature, length_penalty, repetition_penalty,
                             max_length, min_length, top_p)
    response = _extract_reply(inference_result, len(completed))
    if arnold_mode:
        response = arnold_speak(response)

    chat_state = [list(turn) for turn in completed] + [[text_input, ""]]
    # Show the user's message right away (also guarantees one yield for an empty reply).
    yield chat_state
    for character in response:
        chat_state[-1][1] += character
        time.sleep(0.05)
        yield chat_state


css = """
#mkd {
  height: 500px;
  overflow: auto;
  border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️
See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
    chatbot = gr.Chatbot(label="Chat", show_label=False)
    gr.Markdown("Input image and text and start chatting 👇")
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)

    # Cleared together with the chatbot below; the chatbot value itself is the history.
    history_chat = gr.State(value=[])
    arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode", value=False)

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", info="Used with nucleus sampling.",
                                minimum=0.5, maximum=1.0, step=0.1, value=1.0)
        length_penalty = gr.Slider(label="Length Penalty",
                                   info="Set to larger for longer sequence, used with beam search.",
                                   minimum=-1.0, maximum=2.0, step=0.2, value=1.0)
        repetition_penalty = gr.Slider(label="Repetition Penalty", info="Larger value prevents repetition.",
                                       minimum=1.0, maximum=5.0, step=0.5, value=1.5)
        max_length = gr.Slider(label="Max Length", info="Maximum number of new tokens to generate.",
                               minimum=1, maximum=500, step=1, value=200)
        min_length = gr.Slider(label="Minimum Length", info="Minimum number of new tokens to generate.",
                               minimum=1, maximum=100, step=1, value=1)
        top_p = gr.Slider(label="Top P", info="Used with nucleus sampling.",
                          minimum=0.5, maximum=1.0, step=0.1, value=0.9)

    # Must match bot()'s positional parameters one-to-one.
    chat_inputs = [chatbot, text_input, image, temperature, length_penalty, repetition_penalty,
                   max_length, min_length, top_p, arnold_mode]

    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        cancel_btn = gr.Button("Stop Generation")
        chat_button = gr.Button("Submit", variant="primary")

    chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
        bot, chat_inputs, chatbot
    )
    chat_event2 = text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
        bot, chat_inputs, chatbot
    )
    clear_chat_button.click(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat],
                            queue=False, api_name="clear")
    # A new image starts a new conversation.
    image.change(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False)
    cancel_btn.click(None, [], [], cancels=[chat_event1, chat_event2])

    examples = [
        ["./examples/baklava.png", "How to make this pastry?"],
        ["./examples/bee.png", "Describe this image."],
    ]
    gr.Examples(examples=examples, inputs=[image, text_input])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)