File size: 2,784 Bytes
8037081
a9bbd1e
8037081
 
a9bbd1e
8037081
a9bbd1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8037081
a9bbd1e
 
 
 
 
 
 
 
 
8037081
a9bbd1e
 
 
 
 
 
 
8037081
a9bbd1e
 
 
 
 
 
8037081
a9bbd1e
 
 
8037081
a9bbd1e
 
 
 
 
8037081
a9bbd1e
8037081
a9bbd1e
 
 
 
 
 
 
 
 
 
8037081
a9bbd1e
 
8037081
a9bbd1e
 
 
 
 
8037081
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import subprocess
import sys
import uuid

import gradio as gr
import pandas as pd

# Helper to save SMILES string to a temporary CSV
def save_smiles_to_csv(smiles: str, temp_dir="temp_inputs"):
    """Write a single SMILES string to a new, uniquely named CSV file.

    The CSV has one column, ``smiles``, holding one data row. The file is
    created inside ``temp_dir`` (created on demand) and is not removed here;
    the caller owns its lifetime.

    Args:
        smiles: The SMILES string to store.
        temp_dir: Directory in which the CSV is created.

    Returns:
        Path of the newly written CSV file.
    """
    os.makedirs(temp_dir, exist_ok=True)
    # A random hex name keeps concurrent requests from overwriting each other.
    target = os.path.join(temp_dir, uuid.uuid4().hex + ".csv")
    pd.DataFrame.from_dict({"smiles": [smiles]}).to_csv(target, index=False)
    return target

# Core prediction logic
def predict(model_version, dataset_name, input_type, file=None, smiles=None):
    """Run the Chemprop prediction CLI on the given input and return its output.

    Args:
        model_version: "Vanilla Chemprop" selects ``best_unbalanced.pt`` and the
            ``chemprop/`` CLI; any other non-empty choice selects
            ``best_bert_fusion.pt`` and the ``chemprop_update/`` CLI.
        dataset_name: Sub-directory of ``model_weight/`` holding the checkpoint.
        input_type: "Upload CSV" reads ``file``; "Single SMILES" reads ``smiles``.
        file: The uploaded CSV, either as a path (str / os.PathLike) or as an
            object exposing the path via ``.name``. Which form arrives depends
            on the Gradio version.
        smiles: A single SMILES string. It is written to a temporary CSV that
            is removed once the CLI has finished.

    Returns:
        A human-readable message: a validation prompt, the CLI's stdout on
        success, its stderr on a non-zero exit, or the launch error.
    """
    # The radios have no guaranteed selection, so reject missing choices
    # instead of silently picking a model or building "model_weight/None/...".
    if not model_version:
        return "Please select a model version."
    if not dataset_name:
        return "Please select a dataset."

    if model_version == "Vanilla Chemprop":
        model_path = f"model_weight/{dataset_name}/best_unbalanced.pt"
        script_path = "chemprop/chemprop/cli/predict.py"
    else:
        model_path = f"model_weight/{dataset_name}/best_bert_fusion.pt"
        script_path = "chemprop_update/chemprop/cli/predict.py"

    temp_input = None  # set only when this call creates a file to clean up
    if input_type == "Upload CSV":
        if file is None:
            return "Please upload a CSV file."
        # Accept both a plain path and a tempfile-like object with `.name`.
        if isinstance(file, (str, os.PathLike)):
            input_path = os.fspath(file)
        else:
            input_path = file.name
    elif input_type == "Single SMILES":
        smiles = (smiles or "").strip()
        if not smiles:
            return "Please enter a SMILES string."
        input_path = temp_input = save_smiles_to_csv(smiles)
    else:
        return "Please select an input type."

    # sys.executable runs the CLI under the same interpreter/environment as
    # this app, rather than whatever "python" resolves to on PATH.
    cmd = [
        sys.executable, script_path,
        "--test-path", input_path,
        "--model-paths", model_path,
        "--smiles-columns", "smiles",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        return f"Execution Failed: {e}"
    finally:
        if temp_input is not None:
            try:
                os.remove(temp_input)
            except OSError:
                pass  # best-effort cleanup; a stale temp file is not fatal

    if result.returncode != 0:
        return f"Error:\n{result.stderr}"
    return f"Prediction Output:\n{result.stdout}"

# Gradio UI setup
with gr.Blocks() as demo:
    gr.Markdown("## 🧪 Molecular Property Prediction using Chemprop and Transformers")

    model_version = gr.Radio(
        ["Vanilla Chemprop", "Updated Fusion Model"],
        label="Select Model Version"
    )
    dataset_name = gr.Radio(["BBBP", "ClinTox"], label="Select Dataset")

    # Default to "Upload CSV" so the selected input type matches the widget
    # that is initially visible (the file picker below); without a default
    # the radio starts unselected while the upload box is shown.
    input_type = gr.Radio(
        ["Upload CSV", "Single SMILES"], value="Upload CSV", label="Input Type"
    )

    file_input = gr.File(file_types=[".csv"], label="Upload CSV", visible=True)
    smiles_input = gr.Textbox(label="Enter SMILES string", visible=False)

    def toggle_inputs(choice):
        """Show only the input widget that matches the selected input type."""
        return {
            file_input: gr.update(visible=(choice == "Upload CSV")),
            smiles_input: gr.update(visible=(choice == "Single SMILES")),
        }

    input_type.change(toggle_inputs, input_type, [file_input, smiles_input])

    predict_button = gr.Button("Predict")
    output = gr.Textbox(label="Output")

    predict_button.click(
        fn=predict,
        inputs=[model_version, dataset_name, input_type, file_input, smiles_input],
        outputs=output,
    )

# Launch only when executed as a script, so importing this module does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch()