In [1]:
import ast
import numpy as np
import random
import torch

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

from model.llama import LlamaClassificationModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# One seed shared by every random number generator used in this notebook
SEED = 42

# Python's built-in RNG, NumPy's legacy global RNG, and PyTorch's default generator
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Seed the CUDA generators as well when a GPU is present
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)      # current device
    torch.cuda.manual_seed_all(SEED)  # every device, for multi-GPU setups

# Ask cuDNN for deterministic algorithms and switch off its auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Build the classifier and put it straight into evaluation mode
model = LlamaClassificationModel().eval()

# Leave the model as the cell's final expression so its structure is displayed
model

LlamaClassificationModel(
  (base_model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    

In [4]:
# Fetch the dataset from the Hugging Face Hub and pick out its three prompt splits
dataset = load_dataset("ppak10/melt-pool-classification")
train_dataset, test_dataset, validation_dataset = (
    dataset[split_name]
    for split_name in ("train_prompt", "test_prompt", "validation_prompt")
)

# Tokenizer stored in the baseline checkpoint repository
tokenizer = AutoTokenizer.from_pretrained(
    "ppak10/defect-classification-llama-baseline-20-epochs"
)

# Preprocessing function
def preprocess_function(examples):
    """Tokenize a batch of examples and parse their string-encoded labels.

    Args:
        examples: Batch mapping passed in by ``Dataset.map(batched=True)``.
            Reads ``examples["text"]`` (the prompts) and ``examples["label"]``
            (one Python-literal string per example, e.g. ``"[1, 0, 0, 0]"``
            -- NOTE(review): exact label format inferred from the
            ``ast.literal_eval`` call; confirm against the dataset card).

    Returns:
        The tokenizer output for the batch (truncated/padded to 256 tokens)
        with an additional ``"label"`` entry holding one ``float32`` NumPy
        array per example.

    Raises:
        ValueError / SyntaxError: propagated from ``ast.literal_eval`` when a
            label string is not a valid Python literal.
    """
    # Parse and convert each label in a single pass.
    parsed_labels = [
        np.array(ast.literal_eval(raw_label), dtype=np.float32)
        for raw_label in examples["label"]
    ]

    encoded = tokenizer(
        examples["text"], truncation=True, padding="max_length", max_length=256
    )

    # Return the converted labels explicitly instead of relying on in-place
    # mutation of the input batch being carried over by `Dataset.map`.
    encoded["label"] = parsed_labels
    return encoded

# Run the preprocessing over every split in batches, spread across worker processes
train_dataset_tokenized = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=64,
)
test_dataset_tokenized = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=32,
)
validation_dataset_tokenized = validation_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=32,
)

In [5]:
# Download the classification-head checkpoint from the Hugging Face Hub
classification_head_path = hf_hub_download(
    repo_id="ppak10/defect-classification-llama-baseline-20-epochs",
    repo_type="model",
    filename="classification_head.pt",
)

# The file is passed to load_state_dict, so it only needs to contain tensors:
#  - weights_only=True restricts unpickling to tensors and plain containers,
#    which avoids executing arbitrary pickled code from a downloaded file and
#    silences the torch.load warning about the default value.
#  - map_location="cpu" lets the checkpoint load even if it was saved from a
#    device that does not exist on this machine; load_state_dict then copies
#    the values into the model's existing parameters.
model.classifier.load_state_dict(
    torch.load(classification_head_path, map_location="cpu", weights_only=True)
)
model.eval()  # Set the model to evaluation mode

  model.classifier.load_state_dict(torch.load(classification_head_path))


LlamaClassificationModel(
  (base_model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    

In [6]:
# Choose where inference runs: the GPU with index 2 when CUDA is usable, else the CPU.
# NOTE(review): "cuda:2" assumes at least three visible GPUs -- confirm for the target machine.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")
model = model.to(device)

# Number of validation examples scored per forward pass
batch_size = 64

# Walk the tokenized validation split in its stored order (no shuffling)
validation_loader = DataLoader(
    validation_dataset_tokenized,
    batch_size=batch_size,
    shuffle=False,
)

def label_to_classifications_batch(labels):
    """Translate a batch of multi-hot label vectors into lists of class names.

    Args:
        labels: Iterable of label vectors. Position ``i`` of a vector refers to
            the ``i``-th name in ("Desirable", "Keyhole", "Lack of Fusion",
            "Balling"); an entry equal to 1 marks that class as present.

    Returns:
        A list with one entry per input vector, each entry being the list of
        class names whose position holds a 1 (empty when none do).
    """
    class_names = ("Desirable", "Keyhole", "Lack of Fusion", "Balling")
    return [
        [class_names[position] for position, flag in enumerate(row) if flag == 1]
        for row in labels
    ]

# Running sum of (batch mean accuracy * batch size) over the validation split
accuracy_total = 0

# Process the validation dataset in batches.
# Evaluation never needs gradients: model.eval() alone does not disable
# autograd, so the loop runs under no_grad to avoid building a graph (and
# holding its activations) on every forward pass.
with torch.no_grad():
    for batch in tqdm(validation_loader):
        texts = batch["text"]

        # The collated labels are transposed before use -- presumably from
        # (num_labels, batch) to (batch, num_labels); confirm against the
        # DataLoader's collate output.
        labels = np.array(batch["label"]).T

        # Move labels to the inference device
        labels = torch.tensor(labels).to(device)

        # Tokenize the raw prompts for the whole batch and move them to the device
        inputs = tokenizer(list(texts), return_tensors="pt", truncation=True, padding="max_length", max_length=256)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Perform inference
        outputs = model(**inputs)

        # Extract logits and apply sigmoid activation for multi-label classification
        logits = outputs["logits"]
        probs = torch.sigmoid(logits)

        # Threshold the probabilities at 0.5 to get a 0/1 prediction per label
        preds = (probs > 0.5).int()

        # Fraction of labels predicted correctly, averaged along dim 1 and then over the batch
        accuracy_per_label = (preds == labels).float().mean(dim=1)  # Mean per sample
        accuracy_batch_mean = accuracy_per_label.mean().item()  # Mean for the batch

        # Weight by the number of rows so a smaller final batch does not skew the total
        accuracy_total += accuracy_batch_mean * len(labels)

# Calculate overall accuracy
overall_accuracy = accuracy_total / len(validation_dataset_tokenized)
print(f"Overall Accuracy: {overall_accuracy}")

100%|██████████| 11325/11325 [6:18:13<00:00,  2.00s/it] 

Overall Accuracy: 0.4907985512251359



