# Gradio text-prompt demo for the MONAI/Llama3-VILA-M3-8B model.
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.builder import load_pretrained_model  # Import LLaVA model builder

# Hugging Face repository id of the checkpoint served by this app.
model_name = "MONAI/Llama3-VILA-M3-8B"

# Load the tokenizer and model through the LLaVA builder.
#
# Two details matter here:
#   * The builder takes a `model_name` argument in addition to `model_path`;
#     it is conventionally the last path component of the checkpoint id, so
#     derive it from the repo id instead of omitting it.
#   * The builder returns more than (tokenizer, model) -- upstream also returns
#     the image processor and the context length. Star-unpacking the tail keeps
#     this line working regardless of how many extra values come back.
# NOTE(review): confirm the signature against the installed `llava` package.
tokenizer, model, *_ = load_pretrained_model(
    model_path=model_name,
    model_base=None,
    model_name=model_name.rstrip("/").split("/")[-1],
    device="cuda" if torch.cuda.is_available() else "cpu",
)

def generate_response(prompt):
    """Generate a text reply for ``prompt`` with the module-level model.

    Args:
        prompt: The user's text, as received from the Gradio textbox.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed, or an empty string when ``prompt`` is empty or contains
        only whitespace.
    """
    # Nothing to condition on: skip tokenization and the model call entirely.
    if not prompt or not prompt.strip():
        return ""

    # Place the inputs on whatever device the model actually lives on, rather
    # than re-deriving the device from CUDA availability a second time.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # `max_new_tokens` bounds only the generated continuation. The previous
        # `max_length=200` also counted the prompt tokens, so a long prompt
        # left little or no room for a reply.
        output = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Gradio Interface: a single multi-line textbox in, plain text out, wired to
# generate_response.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt..."),
    outputs="text",
    title="LLaVA Llama3-VILA-M3-8B Chatbot",
    description="A chatbot powered by LLaVA and Llama3-VILA-M3-8B",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tooling or tests) does not block on launch().
if __name__ == "__main__":
    iface.launch()