import gradio as gr
import torch
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from PIL import Image

# Load model and processor once at startup.
# device_map="auto" lets accelerate place the weights (GPU if available);
# float16 halves memory for the 7B checkpoint.
model_id = "llava-hf/llava-v1.5-7b"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)


def llava_infer(image, text):
    """Run one round of LLaVA visual question answering.

    Args:
        image: PIL.Image from the Gradio image widget (or None).
        text: user prompt from the textbox (may be None or blank).

    Returns:
        The model's answer as a string, or a help message in Chinese
        when either input is missing.
    """
    # Guard both missing inputs; `text` can be None when the box is cleared.
    if image is None or not text or text.strip() == "":
        return "请提供图片和文本输入"

    # llava-v1.5 was trained on this chat template; the <image> placeholder
    # tells the processor where to splice in the vision tokens.
    prompt = f"USER: <image>\n{text.strip()} ASSISTANT:"

    # Send inputs to wherever the model actually lives (works on CPU too,
    # unlike a hard-coded "cuda"), and match the model's float16 dtype so
    # pixel_values don't arrive as float32.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(
        model.device, torch.float16
    )

    # Generate the answer; no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=100)

    # Decode only the newly generated tokens — decoding the full sequence
    # would echo the prompt back to the user.
    generated = output[:, inputs["input_ids"].shape[1]:]
    result = processor.batch_decode(generated, skip_special_tokens=True)[0]
    return result.strip()


# Build the Gradio interface.
iface = gr.Interface(
    fn=llava_infer,
    inputs=[gr.Image(type="pil"), gr.Textbox(placeholder="输入文本...")],
    outputs="text",
    title="LLaVA Web UI",
    description="上传图片并输入文本,LLaVA 将返回回答",
)

# Only launch the server when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()