import torch
import gradio as gr
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model  # LLaVA/VILA model builder

# Model checkpoint on the Hugging Face Hub
model_path = "MONAI/Llama3-VILA-M3-8B"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model; the LLaVA builder returns (tokenizer, model, image_processor, context_len)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device=device,
)

def generate_response(prompt):
    # Text-only inference: tokenize the prompt and decode the generated continuation
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs, max_length=200)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt..."),
    outputs="text",
    title="LLaVA Llama3-VILA-M3-8B Chatbot",
    description="A chatbot powered by LLaVA and Llama3-VILA-M3-8B",
)

iface.launch()
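
# Usage sketch: once the app is running, the gr.Interface endpoint can be queried
# from another process with gradio_client. The local URL below assumes Gradio's
# default port (7860) and the default "/predict" endpoint name; adjust as needed.
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("Hello, what can you do?", api_name="/predict"))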