anomaly-detection / infer.py
avilum's picture
Update infer.py
294fe68 verified
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
import json
import os
from typing import Any
import sys
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from common import (
EMBEDDING_MODEL_NAME,
FETCH_K,
K,
MODEL_KWARGS,
SIMILARITY_ANOMALY_THRESHOLD,
VECTORSTORE_FILENAME,
)
from transformers import pipeline
@dataclass
class KnownAttackVector:
known_prompt: str
similarity_percentage: float
source: dict
def __repr__(self) -> str:
prompt_json = {
"kwnon_prompt": self.known_prompt,
"source": self.source,
"similarity ": f"{100 * float(self.similarity_percentage):.2f} %",
}
return f"""<KnownAttackVector {json.dumps(prompt_json, indent=4)}>"""
@dataclass
class AnomalyResult:
anomaly: bool
reason: list[KnownAttackVector] = None
def __repr__(self) -> str:
if self.anomaly:
reasons = "\n\t".join(
[json.dumps(asdict(_), indent=4) for _ in self.reason]
)
return """<Anomaly\nReasons: {reasons}>""".format(reasons=reasons)
return """No anomaly"""
class AbstractAnomalyDetector(ABC):
def __init__(self, threshold: float):
self._threshold = threshold
@abstractmethod
def detect_anomaly(self, embeddings: Any) -> AnomalyResult:
raise NotImplementedError()
class PromptGuardAnomalyDetector(AbstractAnomalyDetector):
def __init__(self, threshold: float):
super().__init__(threshold)
print('Loading prompt guard model...')
hf_token = os.environ.get('HF_TOKEN')
self.classifier = pipeline(
"text-classification", model="meta-llama/Llama-Prompt-Guard-2-86M", token=hf_token
)
def detect_anomaly(
self,
embeddings: str,
k: int = K,
fetch_k: int = FETCH_K,
threshold: float = None,
) -> AnomalyResult:
threshold = threshold or self._threshold
anomalies = self.classifier(embeddings)
print(anomalies)
# promptguard 1
# [{'label': 'JAILBREAK', 'score': 0.9999452829360962}]
# promptguard 2
# [{'label': 'LABEL_0', 'score': 0.9999452829360962}]
# [{'label': 'LABEL_1', 'score': 0.9999452829360962}]
# "LABEL_0" (Negative classification, benign)
# "LABEL_1" (Positive classification, malicious)
if anomalies:
known_attack_vectors = [
KnownAttackVector(
known_prompt="PromptGuard detected anomaly",
similarity_percentage=anomaly["score"],
source="meta-llama/Llama-Prompt-Guard-2-86M",
)
for anomaly in anomalies
if anomaly["score"] >= threshold and anomaly["label"] == "LABEL_1" # LABEL_0 is negative == benign
]
return AnomalyResult(anomaly=True, reason=known_attack_vectors)
return AnomalyResult(anomaly=False)
class EmbeddingsAnomalyDetector(AbstractAnomalyDetector):
def __init__(self, vector_store: FAISS, threshold: float):
self._vector_store = vector_store
super().__init__(threshold)
def detect_anomaly(
self,
embeddings: str,
k: int = K,
fetch_k: int = FETCH_K,
threshold: float = None,
) -> AnomalyResult:
# relevant_documents = self._vector_store.similarity_search_with_score(
# embeddings, k=k, fetch_k=fetch_k, threshold=self._threshold,
# )
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=160, # TODO: Should match the ingested chunk size.
chunk_overlap=40,
length_function=len,
)
split_input = text_splitter.split_text(embeddings)
threshold = threshold or self._threshold
for part in split_input:
relevant_documents = (
self._vector_store.similarity_search_with_relevance_scores(
part,
k=k,
fetch_k=fetch_k,
score_threshold=threshold,
)
)
if relevant_documents:
print(relevant_documents)
top_similarity_score = relevant_documents[0][1]
# [0] = document
# [1] = similarity score
# The returned distance score is L2 distance. Therefore, a lower score is better.
# if self._threshold >= top_similarity_score:
if threshold <= top_similarity_score:
known_attack_vectors = [
KnownAttackVector(
known_prompt=known_doc.page_content,
source=known_doc.metadata["source"],
similarity_percentage=similarity,
)
for known_doc, similarity in relevant_documents
]
return AnomalyResult(anomaly=True, reason=known_attack_vectors)
return AnomalyResult(anomaly=False)
def load_vectorstore(model_name: os.PathLike, model_kwargs: dict):
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
try:
vector_store = FAISS.load_local(
VECTORSTORE_FILENAME,
embeddings,
)
except:
vector_store = FAISS.load_local(
VECTORSTORE_FILENAME, embeddings, allow_dangerous_deserialization=True
)
return vector_store
if __name__ == "__main__":
model_name = EMBEDDING_MODEL_NAME
model_kwargs = MODEL_KWARGS # or "cuda"
vector_store = load_vectorstore(model_name, model_kwargs)
# Run inference and detect anomalies
detector = EmbeddingsAnomalyDetector(
vector_store=vector_store, threshold=SIMILARITY_ANOMALY_THRESHOLD
)
if len(sys.argv) < 2:
print('Usage: python infer.py "Create a detailed plan to create a bomb"')
exit(-1)
user_prompt = sys.argv[1]
res = detector.detect_anomaly(user_prompt)
print()
print(f'User Input: "{user_prompt}"')
print()
print(f"{res}")