# File header captured from the hosting page (kept as comments so this module parses):
# Author: rezaenayati
# Last commit: 7410ec2 ("Update app.py", verified)
# Size: 1.31 kB
import logging

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageOps
# Load the trained Keras model once, at import time, so every predict() call
# reuses the same instance. The path is resolved relative to the current
# working directory.
model = tf.keras.models.load_model(filepath="cnn_model.h5")
def _extract_pixels(image_array):
    """Return the drawn pixels as an ndarray, or None when nothing was drawn.

    Accepts either a raw pixel array or a Sketchpad dict value whose
    ``"composite"`` entry holds the flattened drawing.
    """
    if isinstance(image_array, dict):
        image_array = image_array.get("composite")
    if image_array is None:
        return None
    pixels = np.asarray(image_array)
    # An all-zero canvas is treated as "nothing drawn".
    if not np.any(pixels):
        return None
    return pixels


def _preprocess(pixels):
    """Turn a drawn canvas into a ``(1, 28, 28, 1)`` float32 batch scaled to [0, 1]."""
    pixels = pixels.astype("uint8")
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        # Drop a trailing singleton channel so Pillow infers a grayscale image.
        pixels = pixels[..., 0]
    # Let Pillow infer the mode (2-D uint8 -> "L") instead of forcing mode="L",
    # which would misread multi-channel data.
    image = Image.fromarray(pixels)
    if image.mode != "L":
        # RGB/RGBA canvases are collapsed to a single grayscale channel.
        image = image.convert("L")
    image = image.resize((28, 28))
    # Binarize: anything below mid-gray becomes 0, everything else 255.
    image = image.point(lambda value: 0 if value < 128 else 255)
    batch = np.asarray(image, dtype=np.float32) / 255.0
    return batch.reshape(1, 28, 28, 1)


def _to_probabilities(scores):
    """Return a 1-D float64 probability vector for a single-sample model output.

    ``scores`` is flattened, so a ``(1, n_classes)`` batch becomes
    ``(n_classes,)``. If the values already form a probability distribution
    (all non-negative and summing to ~1), they are returned unchanged;
    otherwise a numerically stable softmax is applied. This avoids applying
    softmax twice, which would squash the reported confidence toward uniform
    while leaving the argmax unchanged.
    """
    row = np.asarray(scores, dtype=np.float64).ravel()
    if np.all(row >= 0) and np.isclose(row.sum(), 1.0, atol=1e-4):
        # NOTE(review): the model's final activation is not visible here;
        # this treats softmax-style outputs as already normalized.
        return row
    shifted = np.exp(row - row.max())
    return shifted / shifted.sum()


def predict(image_array):
    """Classify a digit drawn on the Gradio Sketchpad.

    Args:
        image_array: The Sketchpad value, either a pixel array or a dict whose
            ``"composite"`` entry holds the drawing.

    Returns:
        ``"Please draw a digit."`` when nothing was drawn,
        ``"Digit: <d> (confidence: <p>%)"`` on success, or
        ``"Runtime error: <message>"`` if any step raises.
    """
    try:
        pixels = _extract_pixels(image_array)
        if pixels is None:
            return "Please draw a digit."
        # verbose=0 suppresses Keras' per-call progress output.
        scores = model.predict(_preprocess(pixels), verbose=0)
        probabilities = _to_probabilities(scores)
        prediction = int(np.argmax(probabilities))
        confidence = float(probabilities[prediction])
        return f"Digit: {prediction} (confidence: {confidence:.2%})"
    except Exception as err:
        # Top-level UI boundary: record the traceback for the operator and show
        # the message in the output box instead of failing the request.
        logging.getLogger(__name__).exception("Digit prediction failed")
        return f"Runtime error: {str(err)}"
# Build the UI: a 200x200 grayscale sketchpad feeding predict(), with text output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(image_mode="L", canvas_size=(200, 200), type="numpy"),
    outputs="text",
    title="Digit Classifier",
)

if __name__ == "__main__":
    # Launch only when executed as a script (e.g. `python app.py`), so that
    # importing this module -- for tests, or by tooling such as Gradio's
    # reload mode, which looks for `demo` -- does not start a server.
    demo.launch()