import subprocess
import sys

# Install transformers at runtime (a requirements.txt entry is the usual alternative)
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("transformers")

import os
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# Global cache for loaded models
model_cache = {}

# Load a model with a progress bar and cache it for reuse
def load_model(model_name, progress=gr.Progress(track_tqdm=False)):
    if model_name not in model_cache:
        token = os.getenv("HF_TOKEN")
        progress(0, desc="Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        progress(0.5, desc="Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=token
        )
        model_cache[model_name] = (tokenizer, model)
        progress(1, desc="Model ready.")
        return f"{model_name} loaded and ready!"
    else:
        return f"{model_name} already loaded."

# Inference function using the GPU
@spaces.GPU
def generate_text(model_name, prompt):
    # Load the model on first use in case the dropdown change handler has not run yet
    if model_name not in model_cache:
        load_model(model_name)
    tokenizer, model = model_cache[model_name]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Available models
model_choices = [
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "google/gemma-7b"
]

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Analysis with LLMs (LLaMA, DeepSeek, Gemma)")
    
    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)

    input_text = gr.Textbox(label="Input Clinical Text")
    output_text = gr.Textbox(label="Generated Output")
    
    analyze_button = gr.Button("Analyze")

    # Load model when changed
    model_selector.change(fn=load_model, inputs=model_selector, outputs=model_status)
    
    # Generate output
    analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)

demo.launch()
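
# Example of driving the functions directly, e.g. for a local smoke test (assumption:
# the gated models require an HF_TOKEN with access; the prompt below is illustrative):
#   load_model("meta-llama/Llama-3.2-3B-Instruct")
#   print(generate_text("meta-llama/Llama-3.2-3B-Instruct",
#                       "Summarize the following clinical note: ..."))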