# LlavaGuard / app.py
# Last commit: "Update app.py" by LukasHug (422fbaa, verified)
import argparse
import datetime
import hashlib
import json
import logging
import os
import sys
import time
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
LlavaOnevisionForConditionalGeneration
)
from qwen_vl_utils import process_vision_info
from taxonomy import policy_v1
# Set up logging: records go both to "gradio_web_server.log" (a path relative
# to the current working directory) and to a StreamHandler (stderr by default).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("gradio_web_server.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("gradio_web_server")

# Constants
# Conversation/vote logs are written under a "logs" directory next to this
# file. The "serve_images" subdirectory is created eagerly at import time
# (NOTE(review): it is not otherwise referenced in this file).
LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)

# Get default model from environment variable or use a fallback
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")
logger.info(f"Using model: {DEFAULT_MODEL}")

# Taxonomy text pre-filled into the prompt textbox and the taxonomy accordion,
# and restored into the textbox after each accepted submission.
default_taxonomy = policy_v1
class SimpleConversation:
    """Single-turn conversation state for one Gradio session.

    Holds the current prompt (either a plain string or a
    ``(text, image, image_process_mode)`` tuple as built by ``add_text``),
    the uploaded image, the model response, and ``messages``: a list of
    mutable ``[prompt, response]`` rows used to render the chatbot.
    """

    def __init__(self):
        self.current_prompt = ""
        self.current_image = None
        self.current_response = None
        # When True, the next llava_bot call yields without running inference.
        self.skip_next = False
        # Conversation history as a list of [prompt, response] rows.
        self.messages = []

    def set_prompt(self, prompt, image=None):
        """Start a new turn: store *prompt*/*image* and reset history to one row."""
        self.current_prompt = prompt
        self.current_image = image
        self.current_response = None
        self.messages = [[prompt, None]]

    def set_response(self, response):
        """Record *response* and attach it to the latest history row, if any."""
        self.current_response = response
        if self.messages:
            self.messages[-1][-1] = response

    def get_prompt(self):
        """Return the prompt text, unwrapping the text part of a tuple prompt."""
        if isinstance(self.current_prompt, tuple):
            return self.current_prompt[0]
        return self.current_prompt

    def get_image(self, return_pil=False):
        """Return a one-element list holding the current image, or None.

        Falls back to a ``PIL.Image.Image`` stored as the second element of a
        tuple prompt. ``return_pil`` is accepted for caller compatibility but
        does not change the result.
        """
        if self.current_image:
            return [self.current_image]
        if isinstance(self.current_prompt, tuple) and len(self.current_prompt) > 1:
            if isinstance(self.current_prompt[1], Image.Image):
                return [self.current_prompt[1]]
        return None

    def to_gradio_chatbot(self):
        """Render ``messages`` as ``[text, response]`` pairs for ``gr.Chatbot``.

        Tuple prompts are reduced to their text part, and ``<image>``
        placeholder tokens are removed from string prompts.
        """
        rendered = []
        for msg in self.messages or []:
            prompt = msg[0]
            if isinstance(prompt, tuple) and len(prompt) > 0:
                prompt = prompt[0]
            if prompt and isinstance(prompt, str) and "<image>" in prompt:
                prompt = prompt.replace("<image>", "")
            rendered.append([prompt, msg[1]])
        return rendered

    def dict(self):
        """Return a JSON-serializable summary of the conversation for logging.

        The image is replaced by a ``[WITH_IMAGE]``/``[NO_IMAGE]`` marker, and
        responses inside ``messages`` by a ``[RESPONSE]`` placeholder. The
        top-level ``response`` field keeps the current response text.
        """
        image_info = "[WITH_IMAGE]" if self.current_image is not None else "[NO_IMAGE]"
        prompt = self.get_prompt()
        if isinstance(prompt, tuple):
            prompt = prompt[0]
        safe_messages = []
        for msg in self.messages or []:
            msg_prompt = msg[0]
            # Tuple prompts carry a PIL image; keep only the JSON-safe text.
            if isinstance(msg_prompt, tuple) and len(msg_prompt) > 0:
                msg_prompt = msg_prompt[0]
            safe_messages.append([msg_prompt, "[RESPONSE]" if msg[1] else None])
        return {
            "prompt": prompt,
            "image": image_info,
            "response": self.current_response,
            "messages": safe_messages
        }

    def copy(self):
        """Return an independent copy of this conversation.

        Every ``[prompt, response]`` row is copied as well, so writing a
        response into the copy (``set_response``, ``llava_bot``,
        ``regenerate``) never mutates the original's history. Prompt objects
        and images themselves are shared, not duplicated.
        """
        new_conv = SimpleConversation()
        new_conv.current_prompt = self.current_prompt
        new_conv.current_image = self.current_image
        new_conv.current_response = self.current_response
        new_conv.skip_next = self.skip_next
        new_conv.messages = [list(msg) for msg in self.messages or []]
        return new_conv
# Template conversation. Request handlers hand out copies via .copy() rather
# than sharing this instance between sessions.
default_conversation = SimpleConversation()
def wrap_taxonomy(text):
    """Prefix *text* with the default safety taxonomy unless it already contains it.

    The taxonomy and the user text are separated by one blank line.
    """
    if policy_v1 in text:
        return text
    return "\n\n".join((policy_v1, text))
# UI component states
# Button update values returned by event handlers for the five action buttons.
# no_change_btn is returned on skipped/invalid submissions; enable_btn and
# disable_btn toggle interactivity.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
def get_conv_log_filename():
    """Return the path of today's conversation log file inside LOGDIR.

    The file name is ``YYYY-MM-DD-conv.json`` based on the local date. Callers
    append one JSON object per line. The containing directory is created if
    it is missing.
    """
    today = datetime.datetime.now()
    file_name = "{}-{:02d}-{:02d}-conv.json".format(today.year, today.month, today.day)
    log_path = os.path.join(LOGDIR, file_name)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    return log_path
# Inference function
@spaces.GPU
def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
    """Run the loaded safety model on one image + text prompt.

    Dispatches on the type of the module-level ``model``. Qwen2.5-VL models
    build their vision inputs with ``process_vision_info``; any other model
    is treated as a LLaVA-OneVision (LlavaGuard) model.

    Args:
        prompt: Text prompt (taxonomy plus user text).
        image: Image handed to the processor (a PIL image from the UI).
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, only used when sampling.
        max_tokens: Maximum number of new tokens; clamped to at least 1
            because the UI slider allows 0.

    Returns:
        The stripped decoded response, or a user-facing error string. Errors
        are logged with their traceback instead of raised, so the UI always
        receives text.
    """
    if model is None or processor is None:
        return "Model not loaded. Please wait for model to initialize."
    try:
        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
            # Qwen expects the image inline in the chat message.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
            text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text_prompt],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        else:
            # LlavaGuard: an image placeholder in the template, pixels passed separately.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt},
                    ],
                },
            ]
            text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=text_prompt, images=image, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        generation_kwargs = {"max_new_tokens": max(int(max_tokens), 1)}
        if temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # temperature/top_p only apply to sampling; omit them for greedy decoding.
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **generation_kwargs)

        # Drop the echoed prompt tokens; decode only the newly generated ones.
        generated_ids_trimmed = generated_ids[0, inputs["input_ids"].shape[1]:]
        response = processor.decode(generated_ids_trimmed, skip_special_tokens=True)
        logger.debug("Model response: %s", response)
        return response.strip()
    except Exception:
        # UI boundary: log the full traceback, show a generic message.
        logger.exception("Error during inference")
        return "Error processing image. Please try again."
# Gradio UI functions
# JavaScript callback source that collects the page's URL query parameters
# into an object. NOTE(review): not referenced by build_demo (its demo.load
# call passes no js), so it appears unused in this file.
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Page-load handler: log the client and URL params, return a fresh state.

    Not wired up by build_demo, which registers load_demo_refresh instead.
    """
    client_ip = request.client.host
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")
    return default_conversation.copy()
def load_demo_refresh(request: gr.Request):
    """Page-load handler used by build_demo: log the client, return a fresh state."""
    client_ip = request.client.host
    logger.info(f"load_demo. ip: {client_ip}")
    return default_conversation.copy()
def vote_last_response(state, vote_type, request: gr.Request):
    """Append one feedback record to today's conversation log.

    The record is written as a single JSON line containing the timestamp,
    the vote type (e.g. "upvote", "downvote", "flag"), the model name, the
    serialized conversation state and the client IP.
    """
    log_path = get_conv_log_filename()
    with open(log_path, "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": DEFAULT_MODEL,
            "state": state.dict(),
            "ip": request.client.host,
        }
        log_file.write(f"{json.dumps(record)}\n")
def upvote_last_response(state, request: gr.Request):
    """Log an upvote for the current response.

    Returns an empty textbox value and disabled upvote/downvote/flag buttons.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    button_updates = (disable_btn, disable_btn, disable_btn)
    return ("", *button_updates)
def downvote_last_response(state, request: gr.Request):
    """Log a downvote for the current response.

    Returns an empty textbox value and disabled upvote/downvote/flag buttons.
    """
    client_ip = request.client.host
    logger.info(f"downvote. ip: {client_ip}")
    vote_last_response(state, "downvote", request)
    return ("", *([disable_btn] * 3))
def flag_last_response(state, request: gr.Request):
    """Log a flag (inappropriate answer report) for the current response.

    Returns an empty textbox value and disabled upvote/downvote/flag buttons.
    """
    client_ip = request.client.host
    logger.info(f"flag. ip: {client_ip}")
    vote_last_response(state, "flag", request)
    outputs = [""]
    outputs.extend([disable_btn] * 3)
    return tuple(outputs)
def regenerate(state, image_process_mode, request: gr.Request):
    """Prepare the state so llava_bot re-runs the model on the last prompt.

    Clears the last response. If there is an earlier row whose prompt is a
    ``(text, image, mode)`` tuple, re-applies *image_process_mode* to it.
    Returns the state, its chatbot rendering, an empty textbox, a cleared
    image and five disabled action buttons.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    history = state.messages
    if history:
        history[-1][-1] = None
        if len(history) > 1:
            previous = history[-2]
            first = previous[0]
            if isinstance(first, tuple) and len(first) >= 2:
                updated = list(previous)
                if len(first) >= 3:
                    updated[0] = (first[0], first[1], image_process_mode)
                history[-2] = updated
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Reset the session to an empty conversation.

    Returns the fresh state, its (empty) chatbot rendering, an empty textbox,
    a cleared image and five disabled action buttons.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    buttons = tuple(disable_btn for _ in range(5))
    return (fresh, fresh.to_gradio_chatbot(), "", None, *buttons)
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Validate a submission and start a fresh single-turn conversation.

    Both a non-empty *text* and an *image* are required. Otherwise the state is
    flagged with ``skip_next`` so the following ``llava_bot`` call does
    nothing, and the user's text and image are handed back unchanged so
    neither is lost. On success the text is prefixed with the default
    taxonomy (``wrap_taxonomy``), and a new conversation is created holding
    the ``(text, image, image_process_mode)`` prompt. The textbox is then
    refilled with ``default_taxonomy`` and the image box cleared.

    Returns:
        ``(state, chatbot, textbox, imagebox)`` followed by five button updates.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        # Echo the inputs back: clearing them here would discard the
        # taxonomy text or the uploaded image the user still needs.
        return (state, state.to_gradio_chatbot(), text, image) + (no_change_btn,) * 5

    text = wrap_taxonomy(text)
    # Every image-based query starts from a clean conversation.
    state = default_conversation.copy()
    state.set_prompt(prompt=(text, image, image_process_mode), image=image)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
def llava_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
    """Generate the model's safety assessment for the pending prompt.

    Generator used as a Gradio event handler. It yields ``(state, chatbot)``
    followed by updates for the five action buttons (upvote, downvote, flag,
    regenerate, clear).

    - If ``state.skip_next`` is set (invalid input), it yields without change.
    - If no image is available, it records an error row and enables only
      regenerate/clear.
    - Otherwise it runs inference, stores the output, and enables all five
      buttons so the response can be voted on, flagged, regenerated or
      cleared. It then appends a "chat" record to the conversation log.
      Log-write failures are logged, not raised.
    """
    start_tstamp = time.time()
    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    prompt = state.get_prompt()
    all_images = state.get_image(return_pil=True)
    if not all_images:
        if not state.messages:
            state.messages = [["Error: No image provided", None]]
        else:
            state.messages[-1][-1] = "Error: No image provided"
        # Nothing to vote on, but the user can still regenerate or clear.
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,) * 2
        return

    output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
    if not state.messages:
        state.messages = [[prompt, output]]
    else:
        state.messages[-1][-1] = output
    state.current_response = output
    # All action buttons start disabled and add_text/regenerate keep them
    # disabled while generating; a response now exists, so re-enable them.
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"Generated response in {finish_tstamp - start_tstamp:.2f}s")
    try:
        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": DEFAULT_MODEL,
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "images": ['image'],
                "ip": request.client.host,
            }
            fout.write(json.dumps(data) + "\n")
    except Exception as e:
        # Logging is best-effort; never fail the request because of it.
        logger.error(f"Error writing log: {str(e)}")
# UI Components
# Markdown header shown above the demo (omitted in embed mode). The first line
# names the model repo id being served.
title_markdown = f'Demo Model Version: {DEFAULT_MODEL}' + """
# LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
[[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
[[Code](https://github.com/ml-research/LlavaGuard)]
[[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
[[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
[[LlavaGuard](https://arxiv.org/abs/2406.05113)]
"""
# Terms-of-use markdown shown below the demo (omitted in embed mode).
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""
# License markdown shown after the terms of use (omitted in embed mode).
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""
# Custom CSS passed to gr.Blocks: a minimum width for the action-button row.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Build the LlavaGuard Gradio Blocks UI and wire its event handlers.

    Args:
        embed_mode: When truthy, the title, terms-of-use and license markdown
            are omitted.
        cur_dir: Directory searched for ``examples/image1.png`` ..
            ``examples/image5.png``; defaults to this file's directory. Only
            files that exist are offered as examples.
        concurrency_count: ``concurrency_limit`` applied to the ``llava_bot``
            step of the regenerate and submit event chains.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet queued or launched).
    """
    with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
        # Per-session SimpleConversation, populated by load_demo_refresh on page load.
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
        with gr.Row():
            # Left column: image input, examples, generation parameters, taxonomy.
            with gr.Column(scale=3):
                # Model selector removed
                imagebox = gr.Image(type="pil", label="Image", container=False)
                # Hidden; its value is only forwarded to add_text and regenerate.
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)
                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/image{i}.png"] for i in range(1, 6) if
                    os.path.exists(f"{cur_dir}/examples/image{i}.png")
                ], inputs=imagebox)
                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
                                            label="Temperature")
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P")
                    # NOTE(review): minimum=0 allows max_new_tokens=0, which the
                    # transformers generation config may reject — confirm.
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
                                                  label="Max output tokens")
                # NOTE(review): taxonomy_textbox is not passed to any event handler
                # below, so edits here do not reach the model (wrap_taxonomy always
                # uses policy_v1). Confirm whether it should be wired in.
                with gr.Accordion("Safety Risk Taxonomy", open=False):
                    taxonomy_textbox = gr.Textbox(
                        label="Safety Risk Taxonomy",
                        show_label=True,
                        placeholder="Enter your safety policy here",
                        value=default_taxonomy,
                        lines=20)
            # Right column: model output, prompt textbox and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLavaGuard Safety Assessment",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        # Pre-filled with the taxonomy; add_text restores it after each submission.
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter your message here",
                            container=True,
                            value=default_taxonomy,
                            lines=3,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                # All action buttons start non-interactive; handlers return
                # enable/disable updates for them.
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # NOTE(review): url_params is not connected to any event in this function.
        url_params = gr.JSON(visible=False)
        # Register listeners
        # Order must match the five trailing button updates the handlers return.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            flag_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        # Regenerate: reset the last response, then run inference again.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            llava_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )
        # Submission (Enter or Send): add_text validates and builds the prompt,
        # then llava_bot runs inference. NOTE(review): only the textbox path
        # runs add_text with queue=False — confirm whether that is intentional.
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            llava_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            llava_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )
        # Give every new page load its own fresh conversation state.
        demo.load(
            load_demo_refresh,
            None,
            [state],
            queue=False
        )
    return demo
# Command-line options.
# NOTE(review): --moderate is parsed but never read in this file.
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
# No default: None is forwarded to launch(server_port=...), presumably letting
# Gradio pick its default port — confirm.
parser.add_argument("--port", type=int)
parser.add_argument("--concurrency-count", type=int, default=5)
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
# Create log directory if it doesn't exist
# (LOGDIR/serve_images was already created at import time; this is a safeguard.)
os.makedirs(LOGDIR, exist_ok=True)
# GPU Check: informational only; placement is left to device_map="auto" below.
if torch.cuda.is_available():
    logger.info(f"CUDA available with {torch.cuda.device_count()} devices")
else:
    logger.warning("CUDA not available! Models will run on CPU which may be very slow.")
# Hugging Face token handling: log in only when the "token" environment
# variable is set. This must happen before the model download below.
api_key = os.getenv("token")
if api_key:
    from huggingface_hub import login
    login(token=api_key)
    logger.info("Logged in to Hugging Face Hub")
# Load model at startup. `model` and `processor` are module-level globals that
# run_inference reads, so they must be bound before the demo is launched.
model_path = DEFAULT_MODEL
logger.info(f"Loading model: {model_path}")
# Check if it's a Qwen model (selected when "qwenguard" appears in the repo id)
if "qwenguard" in model_path.lower():
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
# Otherwise assume it's a LlavaGuard model
else:
    model = LlavaOnevisionForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    # NOTE(review): `tokenizer` is not referenced elsewhere in this file
    # (run_inference decodes via processor.decode).
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# NOTE(review): context_len is not used elsewhere in this file, and the 8048
# fallback looks like a possible typo for 8192 — confirm.
context_len = getattr(model.config, "max_position_embeddings", 8048)
logger.info(f"Model {model_path} loaded successfully")
# cur_dir='./' resolves the example images relative to the working directory,
# not to this file's location.
demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
demo.queue(
    status_update_rate=10,
    api_open=False
).launch(
    server_name=args.host,
    server_port=args.port,
    share=args.share
)