# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.1
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---
# # LINCS SMILES Integration Notebook
#
# **Requires:**
# - `lincs_full_pp.h5ad` or `lincs_pp.h5ad`
#
# **Outputs:**
# - `lincs_full_smiles.h5ad` or `lincs_smiles.h5ad`
#
# ## Description
#
# The aim of this notebook is to integrate SMILES strings into the LINCS dataset produced by the previous preprocessing notebook.
#
# ### Loading LINCS and Reference Data
#
# The notebook begins by loading two primary datasets:
#
# 1. **LINCS Dataset (`adata_in`)**: Contains perturbation IDs (`pert_id`) representing different drugs or compounds.
#
# 2. **Reference Dataset (`reference_df`)**: Loaded from a TSV file (`GSE92742_Broad_LINCS_pert_info.txt`), which provides `pert_id` and the corresponding `canonical_smiles`.
#
# Both datasets contain `pert_id` columns, which are used for merging.
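#
# A toy picture of the two relevant reference columns (values are illustrative, not
# taken from the actual file; `-666` is the placeholder used for missing SMILES):
#
# ```python
# import pandas as pd
#
# toy_reference = pd.DataFrame({
#     "pert_id": ["BRD-K12345678", "BRD-K87654321"],  # hypothetical LINCS pert ids
#     "canonical_smiles": ["CC(=O)Oc1ccccc1C(=O)O", "-666"],
# })
# ```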
#
# ### Left Merge Between AnnData and SMILES
#
# - The reference dataset is restricted to include only drugs present in the LINCS dataset (`adata.obs.pert_id`).
# - A left merge is performed on `adata.obs` with `reference_df` using `pert_id` as the key, adding the `canonical_smiles` column to `adata.obs`.
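#
# A minimal sketch of this merge pattern (toy values; the `reset_index`/`set_index`
# round trip preserves the original `adata.obs` index through the merge):
#
# ```python
# import pandas as pd
#
# obs = pd.DataFrame({"pert_id": ["BRD-A", "BRD-B"]}, index=["cell_0", "cell_1"])
# ref = pd.DataFrame({"pert_id": ["BRD-A"], "canonical_smiles": ["CCO"]})
# merged = obs.reset_index().merge(ref, how="left").set_index("index")
# # "cell_1" gets NaN for canonical_smiles, since "BRD-B" is missing from the reference
# ```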
#
# ### Cleaning and Additional Validation
#
# 1. **Removing Invalid SMILES**:
# - The cleaning process involves removing invalid or restricted SMILES strings such as `-666`, `'restricted'`, or `NaN`.
# 2. **Validation with RDKit**:
# - RDKit is used to validate chemical structures, ensuring that only valid SMILES are retained (see the sketch after this list).
# 3. **Filtering Perturbations**:
# - Perturbations (`pert_id`) with insufficient replicates or invalid dose values (e.g., `pert_dose` of `-666`) are removed to ensure a robust dataset.
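#
# For the RDKit validation in step 2, parsing and sanitization are separate checks: a
# SMILES string can be syntactically parseable yet chemically impossible. A minimal
# sketch (`c1ccc1` is a toy example that parses but fails sanitization):
#
# ```python
# from rdkit import Chem
#
# m = Chem.MolFromSmiles("c1ccc1", sanitize=False)  # parses: the syntax is valid
# try:
#     Chem.SanitizeMol(m)  # raises: a 4-membered aromatic ring cannot be kekulized
# except Exception:
#     print("invalid chemistry")
# ```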
#
# +
import logging
import os
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from rdkit import Chem

from chemCPA.paths import DATA_DIR, PROJECT_DIR

# make the repository root importable so that `raw_data` and `notebook_utils` resolve
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)

import raw_data.datasets as datasets
from notebook_utils import suppress_output

logging.basicConfig(level=logging.INFO)

with suppress_output():
    sc.set_figure_params(dpi=80, frameon=False)
    sc.logging.print_header()
    warnings.filterwarnings('ignore')
# -
# %load_ext autoreload
# %autoreload 2
# ### Loading LINCS and reference data
# +
full = True
load_adata = True

if full:
    adata_in = DATA_DIR / 'lincs_full_pp.h5ad'
    adata_out = PROJECT_DIR / 'datasets' / 'lincs_full_smiles.h5ad'
else:
    adata_in = DATA_DIR / 'lincs_pp.h5ad'
    adata_out = PROJECT_DIR / 'datasets' / 'lincs_smiles.h5ad'

logging.info(f"Starting to load data from {adata_in}")
adata = sc.read(adata_in) if load_adata else None
logging.info(f"Data loaded from {adata_in}")
# -
# Checking number of drugs for LINCS
pert_id_unique = pd.Series(np.unique(adata.obs.pert_id))
print(f"# of unique perturbations: {len(pert_id_unique)}")
# Loading reference dataframe that contains SMILES
# restricting to `'pert_id'` and `'canonical_smiles'`
reference_df = pd.read_csv(datasets.lincs_pert_info(), delimiter="\t")
reference_df = reference_df.loc[reference_df.pert_id.isin(pert_id_unique), ['pert_id', 'canonical_smiles']]
reference_df.canonical_smiles.value_counts()
cond = ~pert_id_unique.isin(reference_df.pert_id)
print(f"From {len(pert_id_unique)} total drugs, {cond.sum()} were not part of the reference dataframe.")
# Adding `'canonical_smiles'` column to `adata.obs` via `pd.merge`
adata.obs = adata.obs.reset_index().merge(reference_df, how="left").set_index('index')
# Removing invalid SMILES strings
adata.obs.pert_id
reference_df
adata.obs.loc[:, 'canonical_smiles'] = adata.obs.canonical_smiles.astype('str')
invalid_smiles = adata.obs.canonical_smiles.isin(['-666', 'restricted', 'nan'])
print(f'Among {len(adata)} observations, {100*invalid_smiles.sum()/len(adata):.2f}% ({invalid_smiles.sum()}) have an invalid SMILES string')
adata = adata[~invalid_smiles]
# Remove invalid `'pert_dose'` value: `-666`
cond = adata.obs.pert_dose.isin([-666])
adata = adata[~cond]
print(f"A total of {cond.sum()} observations have invalid dose values")
# Removing drugs with fewer than 6 replicates
drugs_validation = adata.obs.canonical_smiles.value_counts() < 6
valid_drugs = drugs_validation.index[~drugs_validation]
cond = adata.obs.canonical_smiles.isin(valid_drugs)

print(f"A total of {(~cond).sum()} observations belong to drugs which do not have enough replicates")
adata = adata[cond]
# Checking that SMILES are valid according to `rdkit`
# +
def check_smiles(smiles):
    m = Chem.MolFromSmiles(smiles, sanitize=False)
    if m is None:
        print('invalid SMILES')
        return False
    else:
        try:
            Chem.SanitizeMol(m)
        except Exception:
            print('invalid chemistry')
            return False
    return True


def remove_invalid_smiles(dataframe, smiles_key: str = 'SMILES', return_condition: bool = False):
    unique_drugs = pd.Series(np.unique(dataframe[smiles_key]))
    valid_drugs = unique_drugs.apply(check_smiles)
    print(f"A total of {(~valid_drugs).sum()} drugs have invalid SMILES strings")
    _validation_map = dict(zip(unique_drugs, valid_drugs))
    cond = dataframe[smiles_key].apply(lambda x: _validation_map[x])
    if return_condition:
        return cond
    dataframe = dataframe[cond].copy()
    return dataframe
adata
# -
cond = remove_invalid_smiles(adata.obs, smiles_key='canonical_smiles', return_condition=True)
adata = adata[cond]
# ### Add additional drugbank info to `adata.obs`
# +
drugbank_path = Path(datasets.drugbank_all())

if drugbank_path.exists():
    drugbank_df = pd.read_csv(drugbank_path)
else:
    print(f'Invalid path: {drugbank_path}')
# +
from rdkit.Chem import CanonSmiles
drugs_canonical = pd.Series(np.unique(adata.obs.canonical_smiles)).apply(CanonSmiles)
db_canonical_smiles = drugbank_df.SMILES.apply(CanonSmiles)
n_overlap = drugs_canonical.isin(db_canonical_smiles).sum()
print(f'From a total of {len(drugs_canonical)} drugs, {100*n_overlap/len(drugs_canonical):.2f}% ({n_overlap}) are also available in DrugBank.')
# -
cond = db_canonical_smiles.isin(drugs_canonical)
drugbank_df.loc[cond, ['ATC_level_1']].value_counts()
# ### Add `train`, `test`, `ood` split for full lincs dataset (if not already part of `adata.obs`)
# +
from sklearn.model_selection import train_test_split
if 'split' not in list(adata.obs):
    print("Adding 'split' to 'adata.obs'.")
    unique_drugs = np.unique(adata.obs.canonical_smiles)
    drugs_train, drugs_tmp = train_test_split(unique_drugs, test_size=0.2, random_state=42)
    drugs_val, drugs_test = train_test_split(drugs_tmp, test_size=0.5, random_state=42)

    adata.obs['split'] = 'train'
    adata.obs.loc[adata.obs.canonical_smiles.isin(drugs_val), 'split'] = 'test'
    adata.obs.loc[adata.obs.canonical_smiles.isin(drugs_test), 'split'] = 'ood'
# -
# ### Check that `.obs.split=='test'` has sufficient samples for `pert_id` and `cell_id`
adata.obs.split.value_counts()
cond_test = adata.obs.split.isin(['test'])
adata.obs.loc[cond_test, 'cell_id'].value_counts()
adata.obs.loc[cond_test, 'pert_id'].value_counts()
# +
pert_count_threshold = 5
cov_count_threshold = 20

pert_id_neg = adata.obs.loc[cond_test, 'pert_id'].value_counts() < pert_count_threshold
print(f"pert_id: {pert_id_neg.sum()}/{len(pert_id_neg)} converted back to 'train' due to insufficient # of samples.")

cov_id_neg = adata.obs.loc[cond_test, 'cell_id'].value_counts() < cov_count_threshold
print(f"cell_id: {cov_id_neg.sum()}/{len(cov_id_neg)} converted back to 'train' due to insufficient # of samples.")
cond = cond_test & adata.obs.pert_id.isin(pert_id_neg.index[pert_id_neg])
cond |= cond_test & adata.obs.cell_id.isin(cov_id_neg.index[cov_id_neg])
# -
adata.obs['split1'] = adata.obs.split.copy()
adata.obs.loc[cond, 'split1'] = 'train'
print(f"split['test']: {cond.sum()}/{len(cond)} samples are converted back to 'train'.")
adata.obs.split1.value_counts()
# ### Add random split
adata.obs_names
train_obs, val_test_obs = train_test_split(adata.obs_names, test_size=0.15, random_state=42)
val_obs, test_obs = train_test_split(val_test_obs, test_size=0.5, random_state=42)
# +
adata.obs['random_split'] = ''
adata.obs.loc[train_obs, 'random_split'] = 'train'
adata.obs.loc[val_obs, 'random_split'] = 'test'
adata.obs.loc[test_obs, 'random_split'] = 'ood'
adata.obs['random_split'].value_counts()
# -
# Check that perturbations occur in train split (no explicit ood!)
len(adata.obs.loc[adata.obs.random_split == 'train', 'pert_id'].unique())
len(adata.obs.pert_id.unique())
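# Making the comparison explicit (a minimal sketch; any perturbation absent from the
# random 'train' split would be entirely unseen during training):
missing = set(adata.obs.pert_id) - set(adata.obs.loc[adata.obs.random_split == 'train', 'pert_id'])
print(f"{len(missing)} perturbations do not occur in the random 'train' split")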
# ## Save adata
logging.info(f"Writing file to disk at {adata_out}")
adata.write(adata_out)
logging.info(f"File was written successfully at {adata_out}.")
adata
# ### Loading the result from `adata_out`
adata = sc.read(adata_out)
# **Additional**: Check that `adata.uns[rank_genes_groups_cov]` has all entries in `adata.obs.cov_drug_name` as keys
for i, k in enumerate(adata.obs.cov_drug_name.unique()):
    try:
        adata.uns['rank_genes_groups_cov'][k]
    except KeyError:
        if 'DMSO' not in k:
            print(f"{i}: {k}")