cels / app.py
alexandraroze's picture
solution
50bd1fc
raw
history blame contribute delete
910 Bytes
import streamlit as st
import torch
from src.inference import CrossAttentionInference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference = CrossAttentionInference(
model_path="best_attention_classifier.pth",
device=device
)
st.title("Random Image Inference")
st.write(
"Нажмите кнопку ниже, чтобы сгенерировать пару случайных изображений и получить предсказание модели."
)
if st.button("Сгенерировать изображения"):
pred_label, (img1, img2) = inference.predict_random_pair()
col1, col2 = st.columns(2)
with col1:
st.image(img1, caption="Image 1", use_container_width=True)
with col2:
st.image(img2, caption="Image 2", use_container_width=True)
st.write(f"**Предсказанная метка**: {pred_label}")