|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
from dawo import DAWO, loss_function, Anndata_to_Tensor |
|
|
|
|
|
class DAWOWrapper: |
|
""" |
|
Minimal wrapper for DAWO model to use with Hugging Face Hub |
|
""" |
|
def __init__(self, repo_path): |
|
""" |
|
Initialize the DAWO model |
|
|
|
Args: |
|
repo_path: Path to repository with model files |
|
""" |
|
|
|
config_path = os.path.join(repo_path, "config.json") |
|
with open(config_path, 'r') as f: |
|
config = json.load(f) |
|
|
|
|
|
self.model = DAWO( |
|
input_dim_X=config["input_dim_X"], |
|
input_dim_Y=config["input_dim_Y"], |
|
input_dim_Z=config["input_dim_Z"], |
|
latent_dim=config["latent_dim"], |
|
Y_emb=config["Y_emb"], |
|
Z_emb=config["Z_emb"], |
|
num_classes=config["num_classes"] |
|
) |
|
|
|
|
|
self.model.load_state_dict(torch.load(os.path.join(repo_path, "model.pth"))) |
|
self.model.eval() |
|
|
|
def predict(self, x, y, z): |
|
""" |
|
Make predictions with the DAWO model |
|
|
|
Args: |
|
x: Gene expression tensor (batch_size, input_dim_X) |
|
y: Drug feature tensor (batch_size, input_dim_Y) |
|
z: Cell line feature tensor (batch_size, input_dim_Z) |
|
|
|
Returns: |
|
Dict with model outputs |
|
""" |
|
with torch.no_grad(): |
|
x_hat, mu, logvar, y_pred = self.model(x, y, z) |
|
|
|
return { |
|
"x_hat": x_hat, |
|
"mu": mu, |
|
"logvar": logvar, |
|
"y_pred": y_pred, |
|
"probs": torch.softmax(y_pred, dim=1) |
|
} |