# NOTE: the three lines below were page-scrape chrome ("Spaces / Running / Running"),
# not source code; they are kept here as a comment so the file parses as Python.
# Standard library.
import argparse
import os
import sys
import warnings

# Third-party numerical stack.
import numpy as np
import pandas as pd
import torch

# Make the repository root importable so that `utils` resolves from this script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import seed_everything

# Suppress noisy library warnings for cleaner CLI output.
warnings.filterwarnings("ignore")
def parse_args(argv=None):
    """Parse command-line arguments for the similar-reaction search.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, which makes argparse read ``sys.argv[1:]`` — so the
            original CLI behavior is unchanged.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Search for similar reactions.")
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--target_embedding",
        type=str,
        required=True,
        help="Path to the target embedding.",
    )
    parser.add_argument(
        "--query_embedding",
        type=str,
        required=True,
        # Fixed copy-paste bug: this help previously said "target embedding".
        help="Path to the query embedding.",
    )
    # NOTE(review): --top_k is accepted but never used later in this script —
    # confirm whether it should feed a top-k retrieval instead of the max.
    parser.add_argument(
        "--top_k",
        type=int,
        default=1,
        help="Number of similar reactions to retrieve.",
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where results are saved.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    config = parse_args()
    seed_everything(42)

    # Load the precomputed embeddings; normalize() below implies they are 2-D
    # (rows = items, columns = embedding dims) — TODO confirm against the
    # script that produced them.
    target_embedding = np.load(config.target_embedding)
    query_embedding = np.load(config.query_embedding)

    # Fall back to CPU when CUDA is unavailable (the original hard-coded
    # .cuda() crashed on CPU-only machines).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    target_embedding = torch.tensor(target_embedding, dtype=torch.float32, device=device)
    query_embedding = torch.tensor(query_embedding, dtype=torch.float32, device=device)

    # L2-normalize rows so the matmul below computes cosine similarity.
    target_embedding = torch.nn.functional.normalize(target_embedding, p=2, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1)

    batch_size = config.batch_size
    distances = []
    # Process queries in batches to bound the peak size of the
    # (batch, num_targets) similarity matrix.
    for i in range(0, query_embedding.shape[0], batch_size):
        print(f"Processing batch {i // batch_size}...")
        batch = query_embedding[i : i + batch_size]
        similarity = torch.matmul(batch, target_embedding.T)
        # Keep only the best match per query row.
        # NOTE(review): despite the name, this is the MAXIMUM cosine
        # similarity (higher = more similar), not a distance.
        distance, _ = torch.max(similarity, dim=1)
        distances.append(distance.cpu().numpy())
    distances = np.concatenate(distances)

    # Attach the per-row best-match score and write alongside the input data.
    df = pd.read_csv(config.input_data)
    df["distance"] = distances
    os.makedirs(config.output_dir, exist_ok=True)  # was a crash if dir missing
    df.to_csv(os.path.join(config.output_dir, "distance.csv"), index=False)