"""Streamlit front-end for a few Canstralian Hugging Face models.

The user picks a model in the sidebar and types some text. The app then shows
either a predicted class (sequence-classification models) or generated text
(encoder-decoder / seq2seq models).
"""

import streamlit as st
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
import torch

# Display name shown in the sidebar -> Hugging Face Hub model id.
model_mapping = {
    "CyberAttackDetection": "Canstralian/CyberAttackDetection",
    "text2shellcommands": "Canstralian/text2shellcommands",
    "pentest_ai": "Canstralian/pentest_ai",
}

# Hub id used if the selected display name is not found in ``model_mapping``.
DEFAULT_MODEL_NAME = "Canstralian/CyberAttackDetection"

# HACK: debugging override that substitutes a known public checkpoint for a
# hub id. Remove an entry to load the real model instead.
_DEBUG_MODEL_OVERRIDES = {
    "Canstralian/text2shellcommands": "t5-small",
}


def _is_seq2seq(model):
    """Return True if *model* is an encoder-decoder model (must use ``generate``)."""
    return bool(getattr(model.config, "is_encoder_decoder", False))


@st.cache_resource
def _load_model_cached(model_name):
    """Load and cache ``(tokenizer, model)`` for *model_name*; raises on failure.

    Streamlit reruns the whole script on every interaction, so without caching
    the model would be reloaded each time. Exceptions are not cached, so a
    failed load is retried on the next rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    # Pick the head from the model's own config rather than from its name:
    # e.g. "t5-small" is encoder-decoder but does not contain "seq2seq".
    if getattr(config, "is_encoder_decoder", False) or "seq2seq" in model_name.lower():
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, or ``(None, None)`` on failure.

    Load errors are shown to the user with ``st.error`` instead of being raised.
    """
    model_name = _DEBUG_MODEL_OVERRIDES.get(model_name, model_name)
    try:
        return _load_model_cached(model_name)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {e}")
        return None, None


def validate_input(user_input):
    """Return True if *user_input* has non-whitespace text; otherwise show an error."""
    if not user_input or not user_input.strip():
        st.error("Please enter some text for prediction.")
        return False
    return True


def make_prediction(model, tokenizer, user_input, max_new_tokens=64):
    """Run *model* on *user_input*.

    Args:
        model: A loaded transformers model.
        tokenizer: The tokenizer matching *model*.
        user_input: Text to feed to the model.
        max_new_tokens: Upper bound on generated tokens (seq2seq models only).

    Returns:
        A tensor of generated token ids for encoder-decoder models, the model's
        forward-pass output otherwise, or None if an error occurred (the error
        is shown with ``st.error``).
    """
    try:
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            if _is_seq2seq(model):
                # A bare forward pass on an encoder-decoder needs decoder inputs;
                # generate() runs the autoregressive decoding loop instead.
                return model.generate(**inputs, max_new_tokens=max_new_tokens)
            return model(**inputs)
    except Exception as e:  # UI boundary: surface inference failures to the user
        st.error(f"Error making prediction: {e}")
        return None


def main():
    """Render the UI and display the selected model's output for the entered text."""
    st.sidebar.header("Model Configuration")
    model_choice = st.sidebar.selectbox("Select a model", list(model_mapping))
    model_name = model_mapping.get(model_choice, DEFAULT_MODEL_NAME)
    tokenizer, model = load_model(model_name)

    st.title(f"{model_choice} Model")
    user_input = st.text_area("Enter text:")

    # validate_input runs first so its message is shown even if loading failed.
    if not validate_input(user_input) or model is None or tokenizer is None:
        return

    outputs = make_prediction(model, tokenizer, user_input)
    if outputs is None:
        return

    if _is_seq2seq(model):
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.write("Generated Shell Command:")
        # st.code shows the text verbatim; st.write would render it as Markdown
        # and mangle characters such as `*`, `_` and backticks.
        st.code(generated_command, language="bash")
    else:
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        st.write(f"Predicted Class: {predicted_class}")
        id2label = getattr(model.config, "id2label", None) or {}
        label = id2label.get(predicted_class)
        if label is not None:
            st.write(f"Predicted Label: {label}")
        st.write(f"Logits: {logits.squeeze(0).tolist()}")


if __name__ == "__main__":
    main()