File size: 1,952 Bytes
b33b5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import json
import torch
import numpy as np
from dawo import DAWO, loss_function, Anndata_to_Tensor


class DAWOWrapper:
    """
    Minimal inference wrapper around a pretrained DAWO model, intended for
    checkpoints distributed via the Hugging Face Hub.

    The repository directory passed to the constructor must contain:

    * ``config.json`` -- constructor arguments for ``DAWO`` (the keys listed
      in ``_CONFIG_KEYS``).
    * ``model.pth`` -- a ``state_dict`` loadable with ``torch.load``.

    Attributes:
        model: The loaded ``DAWO`` module, switched to eval mode and kept on
            the CPU.
    """

    # Keys read from config.json and forwarded to DAWO(...) as keyword args.
    _CONFIG_KEYS = (
        "input_dim_X",
        "input_dim_Y",
        "input_dim_Z",
        "latent_dim",
        "Y_emb",
        "Z_emb",
        "num_classes",
    )

    def __init__(self, repo_path):
        """
        Build the DAWO model from ``config.json`` and load ``model.pth``.

        Args:
            repo_path: Path to the directory holding ``config.json`` and
                ``model.pth``.

        Raises:
            FileNotFoundError: If either file is missing.
            KeyError: If ``config.json`` lacks any key in ``_CONFIG_KEYS``.
        """
        config_path = os.path.join(repo_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Report every missing key at once instead of failing on the first.
        missing = [key for key in self._CONFIG_KEYS if key not in config]
        if missing:
            raise KeyError(f"{config_path} is missing required keys: {missing}")

        self.model = DAWO(**{key: config[key] for key in self._CONFIG_KEYS})

        # The model above is created on the CPU, so deserialize the checkpoint
        # onto the CPU too; without map_location a checkpoint saved from a GPU
        # cannot be loaded on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from trusted repositories (consider weights_only=True on newer torch).
        weights_path = os.path.join(repo_path, "model.pth")
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Inference mode for train/eval-sensitive layers, if DAWO has any.
        self.model.eval()

    def predict(self, x, y, z):
        """
        Run a forward pass of the DAWO model without gradient tracking.

        Inputs are passed straight to the model, so they must live on the
        same device as its parameters (the CPU, as loaded by ``__init__``).

        Args:
            x: Gene expression tensor, presumably (batch_size, input_dim_X).
            y: Drug feature tensor, presumably (batch_size, input_dim_Y).
            z: Cell line feature tensor, presumably (batch_size, input_dim_Z).

        Returns:
            Dict with model outputs:
                "x_hat": reconstructed gene expression,
                "mu": latent mean,
                "logvar": latent log variance,
                "y_pred": raw drug response scores,
                "probs": softmax of ``y_pred`` along dim 1.
        """
        with torch.no_grad():
            x_hat, mu, logvar, y_pred = self.model(x, y, z)
            # Assumes y_pred is (batch_size, num_classes), so dim 1 is classes.
            probs = torch.softmax(y_pred, dim=1)

        return {
            "x_hat": x_hat,
            "mu": mu,
            "logvar": logvar,
            "y_pred": y_pred,
            "probs": probs,
        }