File size: 1,454 Bytes
a4e4a8e
 
0c51978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e4a8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr

# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# One checkpoint id shared by both loaders so the tokenizer and the weights
# always come from the same repository.
MODEL_ID = "Tianlin668/MentalBART"

# Loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Define the inference function
def inference(prompt):
    """Generate the model's response to a user prompt.

    Args:
        prompt: Free-text input from the UI. ``None``, empty, and
            whitespace-only values are rejected without running the model.

    Returns:
        The decoded generation with special tokens stripped, or a short
        instruction string when no usable prompt was supplied.
    """
    # Guard first: a blank prompt carries nothing to respond to, and returning
    # here avoids running a beam search on it.
    if prompt is None or not prompt.strip():
        return "Please enter a prompt."

    # truncation=True clips the encoded prompt to the tokenizer's configured
    # maximum length, so an over-long input is shortened rather than passed to
    # the model at a length it may not support.
    # NOTE(review): the effective limit comes from the checkpoint's tokenizer
    # config (presumably 1024 tokens for a BART model) - confirm.
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True)

    # Beam search with 5 beams, capped at 150 tokens; adjust max_length and
    # num_beams to trade response length/quality against latency.
    output = model.generate(
        input_ids,
        max_length=150,
        num_beams=5,
        early_stopping=True,
    )

    # generate() returns a batch; only one prompt was passed in, so take the
    # first sequence and drop special tokens while decoding.
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Page layout, top to bottom: two centered headings, the prompt box,
# the trigger button, and the area where the response is rendered.
with gr.Blocks() as demo:
    gr.Markdown(
        value="<center><h1> This is your virtual assistant - MentalBART model </h1></center>"
    )
    gr.Markdown(
        value="<center><h2> This assistant will try to diagnose for depression and provide reasoning for it. </h2></center>"
    )

    # Multi-line text input for the user's prompt
    prompt_box = gr.Textbox(lines=3, max_lines=5, label="Prompt")

    # The button starts generation; the Markdown area below it shows the result
    submit_button = gr.Button(value="Generate Response")
    response_view = gr.Markdown(value="Response will appear here.")

    # On click, pass the prompt text to inference and render what it returns
    submit_button.click(inference, inputs=prompt_box, outputs=response_view)


# Start the app only when this file is executed as a script, not on import
if __name__ == "__main__":
    demo.launch()