Spaces:
Running
Running
File size: 2,289 Bytes
08ccc8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import argparse
import os
import sys
import warnings
import numpy as np
import pandas as pd
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import seed_everything
warnings.filterwarnings("ignore")
def parse_args():
    """Parse command-line arguments for the similar-reaction search script.

    Returns:
        argparse.Namespace with fields: input_data, target_embedding,
        query_embedding, top_k, batch_size, output_dir.
    """
    parser = argparse.ArgumentParser(description="Search for similar reactions.")
    parser.add_argument(
        "--input_data",
        type=str,
        required=True,
        help="Path to the input data.",
    )
    parser.add_argument(
        "--target_embedding",
        type=str,
        required=True,
        help="Path to the target embedding.",
    )
    parser.add_argument(
        "--query_embedding",
        type=str,
        required=True,
        # Fixed: help text previously duplicated the --target_embedding text.
        help="Path to the query embedding.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1,
        help="Number of similar reactions to retrieve.",
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory where results are saved.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    config = parse_args()
    seed_everything(42)

    # Fixed: unconditional .cuda() raised on CPU-only machines; fall back
    # to CPU when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load embedding matrices (assumed 2-D: rows are reactions — the
    # dim=1 normalization below requires rank 2; confirm against producer).
    target_embedding = torch.tensor(
        np.load(config.target_embedding), dtype=torch.float32, device=device
    )
    query_embedding = torch.tensor(
        np.load(config.query_embedding), dtype=torch.float32, device=device
    )

    # L2-normalize rows so the matmul below computes cosine similarity.
    target_embedding = torch.nn.functional.normalize(target_embedding, p=2, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1)

    # NOTE(review): --top_k is parsed but unused here — only the top-1
    # (max) similarity per query is kept; confirm whether top-k retrieval
    # was intended.
    batch_size = config.batch_size
    distances = []
    # Batch over queries to bound the (batch x num_targets) similarity matrix.
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for start in range(0, query_embedding.shape[0], batch_size):
            print(f"Processing batch {start // batch_size}...")
            batch = query_embedding[start : start + batch_size]
            similarity = torch.matmul(batch, target_embedding.T)
            # Best (maximum cosine) target score for each query in the batch.
            best, _ = torch.max(similarity, dim=1)
            distances.append(best.cpu().numpy())
    distances = np.concatenate(distances)

    df = pd.read_csv(config.input_data)
    df["distance"] = distances
    # Fixed: ensure the output directory exists before writing.
    os.makedirs(config.output_dir, exist_ok=True)
    df.to_csv(os.path.join(config.output_dir, "distance.csv"), index=False)
|