scontess commited on
Commit
f394b7f
ยท
1 Parent(s): b123fe6
Files changed (2) hide show
  1. src/streamlit_app.py +14 -35
  2. src/test_model.py +33 -0
src/streamlit_app.py CHANGED
@@ -1,7 +1,6 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
4
- import time
5
  import tensorflow.keras as keras
6
  from tensorflow.keras.applications import VGG16
7
  from tensorflow.keras.layers import Dense, Flatten
@@ -11,31 +10,16 @@ import matplotlib.pyplot as plt
11
  from sklearn.model_selection import train_test_split
12
  from sklearn.metrics import confusion_matrix, classification_report
13
  import seaborn as sns
14
- from huggingface_hub import HfApi
15
  import os
16
 
17
- # ๐Ÿ“Œ Percorso della cache
18
- os.environ["HF_HOME"] = "/app/.cache"
19
- os.environ["HF_DATASETS_CACHE"] = "/app/.cache"
20
- HF_TOKEN = os.getenv("HF_TOKEN")
21
-
22
- # ๐Ÿ“Œ Autenticazione Hugging Face
23
- if HF_TOKEN:
24
- api = HfApi()
25
- user_info = api.whoami(HF_TOKEN)
26
- st.write(f"โœ… Autenticato come {user_info.get('name', 'Utente sconosciuto')}")
27
- else:
28
- st.warning("โš ๏ธ Nessun token API trovato! Verifica il Secret nello Space.")
29
-
30
  # ๐Ÿ“Œ Caricamento del dataset
31
  st.write("๐Ÿ”„ Caricamento di 300 immagini da `tiny-imagenet`...")
32
  dataset = load_dataset("zh-plus/tiny-imagenet", split="train")
33
 
34
  image_list = []
35
  label_list = []
36
-
37
  for i, sample in enumerate(dataset):
38
- if i >= 300: # Prende solo 300 immagini
39
  break
40
  image = tf.image.resize(sample["image"], (64, 64)) / 255.0 # Normalizzazione
41
  image_list.append(image.numpy())
@@ -54,11 +38,11 @@ st.write(f"๐Ÿ“Š **Validation:** {X_val.shape[0]} immagini")
54
  force_training = st.checkbox("๐Ÿ”„ Rifai il training anche se Silva.h5 esiste")
55
 
56
  # ๐Ÿ“Œ Caricamento o training del modello
57
- history = None # ๐Ÿ›  Inizializza history
58
 
59
  if os.path.exists("Silva.h5") and not force_training:
60
  model = load_model("Silva.h5")
61
- st.write("โœ… Modello `Silva.h5` caricato, nessun nuovo training necessario!")
62
  else:
63
  st.write("๐Ÿš€ Training in corso...")
64
  base_model = VGG16(weights="imagenet", include_top=False, input_shape=(64, 64, 3))
@@ -75,7 +59,7 @@ else:
75
 
76
  history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))
77
  model.save("Silva.h5")
78
- st.write("โœ… Modello salvato come `Silva.h5`!")
79
 
80
  # ๐Ÿ“Œ Calcolo delle metriche sulla validazione
81
  y_pred_val = np.argmax(model.predict(X_val), axis=1)
@@ -83,38 +67,33 @@ accuracy_val = np.mean(y_pred_val == y_val)
83
  rmse_val = np.sqrt(np.mean((y_pred_val - y_val) ** 2))
84
  report_val = classification_report(y_val, y_pred_val, output_dict=True)
85
 
86
- recall_val = report_val["weighted avg"]["recall"]
87
- precision_val = report_val["weighted avg"]["precision"]
88
- f1_score_val = report_val["weighted avg"]["f1-score"]
89
-
90
  st.write(f"๐Ÿ“Š **Validation Accuracy:** {accuracy_val:.4f}")
91
  st.write(f"๐Ÿ“Š **Validation RMSE:** {rmse_val:.4f}")
92
- st.write(f"๐Ÿ“Š **Validation Precision:** {precision_val:.4f}")
93
- st.write(f"๐Ÿ“Š **Validation Recall:** {recall_val:.4f}")
94
- st.write(f"๐Ÿ“Š **Validation F1-Score:** {f1_score_val:.4f}")
95
 
96
- # ๐Ÿ“Œ Bottone per generare la matrice di confusione sulla validazione
97
- if st.button("๐Ÿ”Ž Genera matrice di confusione per validazione"):
98
  conf_matrix_val = confusion_matrix(y_val, y_pred_val)
99
  fig, ax = plt.subplots(figsize=(10, 7))
100
  sns.heatmap(conf_matrix_val, annot=True, cmap="Blues", fmt="d", ax=ax)
101
  st.pyplot(fig)
102
- st.write("โœ… Matrice di confusione generata!")
103
 
104
- # ๐Ÿ“Œ Grafico per Loss e Accuracy con validazione
105
  if history is not None:
106
  fig, ax = plt.subplots(1, 2, figsize=(12, 5))
107
  ax[0].plot(history.history["loss"], label="Training Loss")
108
  ax[0].plot(history.history["val_loss"], label="Validation Loss")
109
  ax[1].plot(history.history["accuracy"], label="Training Accuracy")
110
  ax[1].plot(history.history["val_accuracy"], label="Validation Accuracy")
111
- ax[0].set_title("Loss durante il training e validazione")
112
- ax[1].set_title("Accuracy durante il training e validazione")
113
  ax[0].legend()
114
  ax[1].legend()
115
  st.pyplot(fig)
116
- else:
117
- st.warning("โš ๏ธ Nessun training eseguito, impossibile mostrare il grafico!")
 
 
 
 
118
 
119
  # ๐Ÿ“Œ Bottone per scaricare il modello
120
  if os.path.exists("Silva.h5"):
 
1
  import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
 
4
  import tensorflow.keras as keras
5
  from tensorflow.keras.applications import VGG16
6
  from tensorflow.keras.layers import Dense, Flatten
 
10
  from sklearn.model_selection import train_test_split
11
  from sklearn.metrics import confusion_matrix, classification_report
12
  import seaborn as sns
 
13
  import os
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ๐Ÿ“Œ Caricamento del dataset
16
  st.write("๐Ÿ”„ Caricamento di 300 immagini da `tiny-imagenet`...")
17
  dataset = load_dataset("zh-plus/tiny-imagenet", split="train")
18
 
19
  image_list = []
20
  label_list = []
 
21
  for i, sample in enumerate(dataset):
22
+ if i >= 300:
23
  break
24
  image = tf.image.resize(sample["image"], (64, 64)) / 255.0 # Normalizzazione
25
  image_list.append(image.numpy())
 
38
  force_training = st.checkbox("๐Ÿ”„ Rifai il training anche se Silva.h5 esiste")
39
 
40
  # ๐Ÿ“Œ Caricamento o training del modello
41
+ history = None
42
 
43
  if os.path.exists("Silva.h5") and not force_training:
44
  model = load_model("Silva.h5")
45
+ st.write("โœ… Modello `Silva.h5` caricato!")
46
  else:
47
  st.write("๐Ÿš€ Training in corso...")
48
  base_model = VGG16(weights="imagenet", include_top=False, input_shape=(64, 64, 3))
 
59
 
60
  history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))
61
  model.save("Silva.h5")
62
+ st.write("โœ… Modello salvato!")
63
 
64
  # ๐Ÿ“Œ Calcolo delle metriche sulla validazione
65
  y_pred_val = np.argmax(model.predict(X_val), axis=1)
 
67
  rmse_val = np.sqrt(np.mean((y_pred_val - y_val) ** 2))
68
  report_val = classification_report(y_val, y_pred_val, output_dict=True)
69
 
 
 
 
 
70
  st.write(f"๐Ÿ“Š **Validation Accuracy:** {accuracy_val:.4f}")
71
  st.write(f"๐Ÿ“Š **Validation RMSE:** {rmse_val:.4f}")
72
+ st.write(f"๐Ÿ“Š **Validation F1-Score:** {report_val['weighted avg']['f1-score']:.4f}")
 
 
73
 
74
+ # ๐Ÿ“Œ Bottone per generare la matrice di confusione
75
+ if st.button("๐Ÿ”Ž Genera matrice di confusione"):
76
  conf_matrix_val = confusion_matrix(y_val, y_pred_val)
77
  fig, ax = plt.subplots(figsize=(10, 7))
78
  sns.heatmap(conf_matrix_val, annot=True, cmap="Blues", fmt="d", ax=ax)
79
  st.pyplot(fig)
 
80
 
81
+ # ๐Ÿ“Œ Grafico per Loss e Accuracy
82
  if history is not None:
83
  fig, ax = plt.subplots(1, 2, figsize=(12, 5))
84
  ax[0].plot(history.history["loss"], label="Training Loss")
85
  ax[0].plot(history.history["val_loss"], label="Validation Loss")
86
  ax[1].plot(history.history["accuracy"], label="Training Accuracy")
87
  ax[1].plot(history.history["val_accuracy"], label="Validation Accuracy")
 
 
88
  ax[0].legend()
89
  ax[1].legend()
90
  st.pyplot(fig)
91
+
92
+ # ๐Ÿ“Œ Bottone per avviare il test su nuove immagini
93
+ if st.button("๐Ÿ”Ž Testa il modello con un'immagine nuova"):
94
+ st.write("๐Ÿš€ Avviando il test...")
95
+ os.system("streamlit run test_model.py")
96
+
97
 
98
  # ๐Ÿ“Œ Bottone per scaricare il modello
99
  if os.path.exists("Silva.h5"):
src/test_model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+
8
+ # ๐Ÿ“Œ Carica il modello
9
+ MODEL_PATH = "Silva.h5"
10
+
11
+ if not MODEL_PATH:
12
+ st.error("โŒ Modello non trovato! Esegui prima il training.")
13
+ else:
14
+ model = load_model(MODEL_PATH)
15
+ st.write("โœ… Modello caricato correttamente!")
16
+
17
+ # ๐Ÿ“Œ Carica un'immagine per il test
18
+ uploaded_file = st.file_uploader("๐Ÿ“ค Carica un'immagine per testare il modello", type=["jpg", "png", "jpeg"])
19
+
20
+ if uploaded_file:
21
+ # Converti l'immagine in formato compatibile
22
+ image = Image.open(uploaded_file).convert("RGB")
23
+ image = image.resize((64, 64)) # ๐Ÿ“Œ Stessa dimensione usata nel training
24
+ image_array = np.array(image) / 255.0 # Normalizzazione
25
+ image_array = np.expand_dims(image_array, axis=0) # Aggiungi batch dimension
26
+
27
+ st.image(image, caption="๐Ÿ” Immagine di test", use_column_width=True)
28
+
29
+ # ๐Ÿ“Œ Esegui la previsione
30
+ prediction = model.predict(image_array)
31
+ predicted_class = np.argmax(prediction)
32
+
33
+ st.write(f"๐Ÿ”ฎ **Classe Predetta:** {predicted_class}")