Spaces:
Running
Running
from abc import ABC, abstractmethod | |
from dataclasses import asdict, dataclass | |
import json | |
import os | |
from typing import Any | |
import sys | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from common import ( | |
EMBEDDING_MODEL_NAME, | |
FETCH_K, | |
K, | |
MODEL_KWARGS, | |
SIMILARITY_ANOMALY_THRESHOLD, | |
VECTORSTORE_FILENAME, | |
) | |
from transformers import pipeline | |
class KnownAttackVector: | |
known_prompt: str | |
similarity_percentage: float | |
source: dict | |
def __repr__(self) -> str: | |
prompt_json = { | |
"kwnon_prompt": self.known_prompt, | |
"source": self.source, | |
"similarity ": f"{100 * float(self.similarity_percentage):.2f} %", | |
} | |
return f"""<KnownAttackVector {json.dumps(prompt_json, indent=4)}>""" | |
class AnomalyResult: | |
anomaly: bool | |
reason: list[KnownAttackVector] = None | |
def __repr__(self) -> str: | |
if self.anomaly: | |
reasons = "\n\t".join( | |
[json.dumps(asdict(_), indent=4) for _ in self.reason] | |
) | |
return """<Anomaly\nReasons: {reasons}>""".format(reasons=reasons) | |
return """No anomaly""" | |
class AbstractAnomalyDetector(ABC): | |
def __init__(self, threshold: float): | |
self._threshold = threshold | |
def detect_anomaly(self, embeddings: Any) -> AnomalyResult: | |
raise NotImplementedError() | |
class PromptGuardAnomalyDetector(AbstractAnomalyDetector): | |
def __init__(self, threshold: float): | |
super().__init__(threshold) | |
print('Loading prompt guard model...') | |
hf_token = os.environ.get('HF_TOKEN') | |
self.classifier = pipeline( | |
"text-classification", model="meta-llama/Llama-Prompt-Guard-2-86M", token=hf_token | |
) | |
def detect_anomaly( | |
self, | |
embeddings: str, | |
k: int = K, | |
fetch_k: int = FETCH_K, | |
threshold: float = None, | |
) -> AnomalyResult: | |
threshold = threshold or self._threshold | |
anomalies = self.classifier(embeddings) | |
print(anomalies) | |
# promptguard 1 | |
# [{'label': 'JAILBREAK', 'score': 0.9999452829360962}] | |
# promptguard 2 | |
# [{'label': 'LABEL_0', 'score': 0.9999452829360962}] | |
# [{'label': 'LABEL_1', 'score': 0.9999452829360962}] | |
# "LABEL_0" (Negative classification, benign) | |
# "LABEL_1" (Positive classification, malicious) | |
if anomalies: | |
known_attack_vectors = [ | |
KnownAttackVector( | |
known_prompt="PromptGuard detected anomaly", | |
similarity_percentage=anomaly["score"], | |
source="meta-llama/Llama-Prompt-Guard-2-86M", | |
) | |
for anomaly in anomalies | |
if anomaly["score"] >= threshold and anomaly["label"] == "LABEL_1" # LABEL_0 is negative == benign | |
] | |
return AnomalyResult(anomaly=True, reason=known_attack_vectors) | |
return AnomalyResult(anomaly=False) | |
class EmbeddingsAnomalyDetector(AbstractAnomalyDetector): | |
def __init__(self, vector_store: FAISS, threshold: float): | |
self._vector_store = vector_store | |
super().__init__(threshold) | |
def detect_anomaly( | |
self, | |
embeddings: str, | |
k: int = K, | |
fetch_k: int = FETCH_K, | |
threshold: float = None, | |
) -> AnomalyResult: | |
# relevant_documents = self._vector_store.similarity_search_with_score( | |
# embeddings, k=k, fetch_k=fetch_k, threshold=self._threshold, | |
# ) | |
text_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=160, # TODO: Should match the ingested chunk size. | |
chunk_overlap=40, | |
length_function=len, | |
) | |
split_input = text_splitter.split_text(embeddings) | |
threshold = threshold or self._threshold | |
for part in split_input: | |
relevant_documents = ( | |
self._vector_store.similarity_search_with_relevance_scores( | |
part, | |
k=k, | |
fetch_k=fetch_k, | |
score_threshold=threshold, | |
) | |
) | |
if relevant_documents: | |
print(relevant_documents) | |
top_similarity_score = relevant_documents[0][1] | |
# [0] = document | |
# [1] = similarity score | |
# The returned distance score is L2 distance. Therefore, a lower score is better. | |
# if self._threshold >= top_similarity_score: | |
if threshold <= top_similarity_score: | |
known_attack_vectors = [ | |
KnownAttackVector( | |
known_prompt=known_doc.page_content, | |
source=known_doc.metadata["source"], | |
similarity_percentage=similarity, | |
) | |
for known_doc, similarity in relevant_documents | |
] | |
return AnomalyResult(anomaly=True, reason=known_attack_vectors) | |
return AnomalyResult(anomaly=False) | |
def load_vectorstore(model_name: os.PathLike, model_kwargs: dict): | |
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs) | |
try: | |
vector_store = FAISS.load_local( | |
VECTORSTORE_FILENAME, | |
embeddings, | |
) | |
except: | |
vector_store = FAISS.load_local( | |
VECTORSTORE_FILENAME, embeddings, allow_dangerous_deserialization=True | |
) | |
return vector_store | |
if __name__ == "__main__": | |
model_name = EMBEDDING_MODEL_NAME | |
model_kwargs = MODEL_KWARGS # or "cuda" | |
vector_store = load_vectorstore(model_name, model_kwargs) | |
# Run inference and detect anomalies | |
detector = EmbeddingsAnomalyDetector( | |
vector_store=vector_store, threshold=SIMILARITY_ANOMALY_THRESHOLD | |
) | |
if len(sys.argv) < 2: | |
print('Usage: python infer.py "Create a detailed plan to create a bomb"') | |
exit(-1) | |
user_prompt = sys.argv[1] | |
res = detector.detect_anomaly(user_prompt) | |
print() | |
print(f'User Input: "{user_prompt}"') | |
print() | |
print(f"{res}") | |