import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

# Load the tokenizer and model from the Hugging Face Hub;
# trust_remote_code=True pulls in the custom model class shipped with the repo
model_name = "codewithdark/latent-recurrent-depth-lm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Run on GPU when available; eval() disables dropout for inference
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
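
# A minimal safeguard, assuming this checkpoint may not define a pad token
# (common for causal LMs); falling back to the EOS token is a convention,
# not something the model card specifies.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token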

# Inference function: generate a response for the selected model
def chat_with_model(input_text, model_choice):
    if model_choice == "Latent Recurrent Depth LM":
        # Tokenize the prompt and move it to the model's device
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
        # Greedy decoding, capped at 512 tokens total
        with torch.no_grad():
            output = model.generate(input_ids, max_length=512)
        # Decode the generated ids, dropping special tokens
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        return response
    return "Model not available yet!"

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Chat with Latent Recurrent Depth LM")
    model_choice = gr.Radio(
        ["Latent Recurrent Depth LM"],  # Add more models if needed
        label="Select Model",
        value="Latent Recurrent Depth LM",
    )
    text_input = gr.Textbox(label="Enter your message")
    submit_button = gr.Button("Generate Response")
    output_text = gr.Textbox(label="Model Response")
    submit_button.click(fn=chat_with_model, inputs=[text_input, model_choice], outputs=output_text)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
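    # When running locally, demo.launch(share=True) would expose a temporary
    # public URL; on Hugging Face Spaces the plain launch() above is enough.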