import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

# Load the model and tokenizer from the Hugging Face Hub
model_name = "codewithdark/latent-recurrent-depth-lm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
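# trust_remote_code=True loads the repo's custom model class,
# which this script relies on to expose generate()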
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()  # Set to evaluation mode

# Define the inference function
def chat_with_model(input_text, model_choice):
    if model_choice != "Latent Recurrent Depth LM":
        return "Model not available yet!"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():  # no gradients needed for inference
        # max_length caps prompt + generated tokens combined
        output = model.generate(input_ids, max_length=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Create Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Chat with Latent Recurrent Depth LM")
    
    model_choice = gr.Radio(
        ["Latent Recurrent Depth LM"],  # Add more models if needed
        label="Select Model",
        value="Latent Recurrent Depth LM"
    )
    
    text_input = gr.Textbox(label="Enter your message")
    submit_button = gr.Button("Generate Response")
    output_text = gr.Textbox(label="Model Response")

    submit_button.click(fn=chat_with_model, inputs=[text_input, model_choice], outputs=output_text)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()