Spaces:
Sleeping
Sleeping
File size: 4,501 Bytes
1b74e0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from huggingface_hub import hf_hub_download\n",
"from transformers import AutoTokenizer\n",
"\n",
"from model.distilbert import DistilBertClassificationModel\n",
"from model.scibert import SciBertClassificationModel\n",
"from model.llama import LlamaClassificationModel\n",
"from model.t5 import T5ClassificationModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Selection\n",
"Uncomment desired `repo_id` and corresponding `model` and input type."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Baseline checkpoints: pair with the [SEP]-delimited input text\n",
"repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
"\n",
"# Prompt checkpoints: pair with the natural-language question input text\n",
"# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
"\n",
"# Initialize the model wrapper; pick the class that matches the architecture in repo_id\n",
"model = DistilBertClassificationModel(repo_id)\n",
"# model = SciBertClassificationModel(repo_id)\n",
"# NOTE(review): the Llama wrapper is called without repo_id, unlike the others -- confirm this is intended\n",
"# model = LlamaClassificationModel()\n",
"# model = T5ClassificationModel(repo_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer published in the selected checkpoint repository\n",
"tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
"\n",
"# Download the classification head weights from the same Hub repository\n",
"classification_head_path = hf_hub_download(\n",
"    repo_id=repo_id,\n",
"    repo_type=\"model\",\n",
"    filename=\"classification_head.pt\",\n",
")\n",
"\n",
"# weights_only=True restricts unpickling to tensors and plain containers, so a\n",
"# tampered checkpoint file cannot execute arbitrary code while being loaded.\n",
"classification_head_state = torch.load(\n",
"    classification_head_path,\n",
"    map_location=torch.device(\"cpu\"),\n",
"    weights_only=True,\n",
")\n",
"model.classifier.load_state_dict(classification_head_state)\n",
"\n",
"# Switch the model to evaluation mode before running inference\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Baseline input: [SEP]-delimited process parameters (use with the *-baseline-* checkpoints).\n",
"# Active by default because the default repo_id selects a baseline checkpoint.\n",
"text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
"\n",
"# Prompt input: natural-language question (use with the *-prompt-* checkpoints)\n",
"# text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tokenize the input, truncating/padding to a fixed length of 256 tokens\n",
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
"\n",
"# Drop `token_type_ids` before calling the model. Originally noted as needed for SciBERT;\n",
"# presumably its tokenizer emits this key and the model wrappers do not accept it -- confirm.\n",
"inputs_kwargs = {key: value for key, value in inputs.items() if key != \"token_type_ids\"}\n",
"\n",
"# Perform inference; no_grad() skips building the autograd graph, which is not needed here\n",
"with torch.no_grad():\n",
"    outputs = model(**inputs_kwargs)\n",
"\n",
"# Extract logits and apply sigmoid activation for multi-label classification\n",
"probs = torch.sigmoid(outputs[\"logits\"])\n",
"\n",
"# Threshold each probability at 0.5 to get a multi-hot label vector (several labels may be set)\n",
"preds = (probs > 0.5).int().squeeze()\n",
"\n",
"# Label names, indexed by their position in the multi-hot vector\n",
"classifications = [\"None\", \"Keyhole\", \"Lack of Fusion\", \"Balling\"]\n",
"\n",
"# Pair each label with its 0/1 flag (as plain Python ints) and keep the predicted ones\n",
"print([label for label, flag in zip(classifications, preds.tolist()) if flag == 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|