rezaenayati's picture
Update app.py
f08c441 verified
raw
history blame
1.55 kB
import logging
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageOps
# Load the trained Keras model once, at import time, so every prediction
# request reuses the same in-memory model.
# The path is resolved next to this script rather than against the current
# working directory, so the app still starts when launched from another folder.
MODEL_PATH = Path(__file__).resolve().parent / "cnn_model.h5"
model = tf.keras.models.load_model(str(MODEL_PATH))
def _is_blank(pixels):
    """Return True when there is nothing to classify.

    Nothing to classify means: no array, an empty array, or an image of a
    single uniform value.
    """
    if pixels is None:
        return True
    arr = np.asarray(pixels)
    # Testing for uniformity (instead of "sum == 0") treats an untouched canvas
    # as blank whatever its background colour, including a white one.
    return bool(arr.size == 0 or arr.min() == arr.max())


def _to_grayscale(pixels):
    """Convert a 2-D (H, W) or 3-D (H, W, C) pixel array to a PIL image in mode "L"."""
    arr = np.asarray(pixels).astype("uint8")
    if arr.ndim == 3 and arr.shape[2] == 1:
        # PIL cannot build an image from an explicit single-channel axis.
        arr = arr[:, :, 0]
    image = Image.fromarray(arr)  # a 2-D uint8 array yields mode "L" directly
    if image.mode == "RGBA":
        # convert("L") ignores alpha, so flatten transparent pixels onto white
        # first; otherwise they would read as black "ink" after the conversion.
        backdrop = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(backdrop, image)
    return image.convert("L")


def predict(image_array):
    """Classify a hand-drawn digit and describe the result as text.

    Args:
        image_array: The sketchpad value. Either a numpy pixel array, or a dict
            whose ``"composite"`` entry holds that array (it is unwrapped
            first). ``None`` is treated as "nothing drawn".

    Returns:
        One of the following strings:

        - ``"Digit: <d> (confidence: <p>)"`` on success.
        - ``"Please draw a digit."`` when the canvas is missing, empty or
          uniform.
        - ``"Runtime error: <message>"`` if preprocessing or inference raises.
    """
    try:
        if isinstance(image_array, dict) and 'composite' in image_array:
            image_array = image_array['composite']
        if _is_blank(image_array):
            return "Please draw a digit."

        image = _to_grayscale(image_array)
        image = image.resize((28, 28), Image.LANCZOS)
        # Invert so dark strokes on a light canvas become light-on-dark.
        # NOTE(review): presumably this matches the (E)MNIST-style training
        # data of cnn_model.h5 -- confirm against the training pipeline.
        image = ImageOps.invert(image)
        model_input = np.array(image).astype("float32") / 255.0
        model_input = model_input.reshape(1, 28, 28, 1)  # batch of one, single channel

        # NOTE(review): softmax assumes the model outputs raw logits. If its last
        # layer already applies softmax, this double-normalises and understates
        # the confidence -- confirm against the model definition.
        logits = model.predict(model_input, verbose=0)
        prediction = int(np.argmax(logits))
        confidence = float(tf.nn.softmax(logits)[0][prediction])
        return f"Digit: {prediction} (confidence: {confidence:.2%})"
    except Exception as err:
        # UI boundary: show the failure in the output box instead of crashing
        # the app, but keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Runtime error: {err}"
# Drawing input: a 280x280 grayscale canvas (10x the 28x28 size predict()
# downsamples to), delivered to predict() as numpy data.
sketchpad = gr.Sketchpad(
    image_mode="L",
    canvas_size=(280, 280),
    type="numpy",
    brush_radius=10,
)

# Wire the canvas to predict() and show its returned string as plain text.
demo = gr.Interface(
    fn=predict,
    inputs=sketchpad,
    outputs="text",
    title="EMNIST Digit Classifier",
    description="Draw a digit (0-9) in the center of the canvas. For best results with EMNIST, make your digit large and clear.",
)
demo.launch()