"""Gradio sketch-recognition demo: classify a hand-drawn digit (0 to 9)."""

import logging

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)

# App title
title = "Welcome to your first sketch recognition app!"

# App description
# NOTE(review): the HTML markup of this literal was lost in the copy under
# review (only the visible text survived). A minimal centered-paragraph
# wrapper is restored here; re-add any image/markup the original carried.
head = (
    "<center>"
    "<p>The model is trained to classify numbers (from 0 to 9). "
    "To test it, draw your number in the space provided.</p>"
    "</center>"
)

# GitHub repository link
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Class names (from 0 to 9)
labels = {
    0: "zero",
    1: "one",
    2: "two",
    3: "three",
    4: "four",
    5: "five",
    6: "six",
    7: "seven",
    8: "eight",
    9: "nine",
}

# Side length pairs expected by the model input (width, height).
_INPUT_SIZE = (28, 28)

# Number of best-scoring classes reported to the user.
_TOP_K = 3

# Load model (trained on MNIST dataset)
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")


def predict(data):
    """Classify the sketch drawn in the Gradio sketchpad.

    Args:
        data: Editor value passed by the sketchpad component; the flattened
            drawing is read from its ``"composite"`` entry.

    Returns:
        A dict mapping the ``_TOP_K`` best class names (see ``labels``) to
        their scores as built-in floats, in descending score order.

    Raises:
        gr.Error: If no drawing is available (nothing has been sketched).
    """
    composite = None if data is None else data.get("composite")
    if composite is None:
        raise gr.Error("Please draw a number before submitting.")

    # Convert to NumPy array
    img = np.asarray(composite)
    if img.size == 0:
        raise gr.Error("Please draw a number before submitting.")

    logger.debug(
        "input shape=%s, non-zero pixels=%d", img.shape, np.count_nonzero(img)
    )

    # Only a 3-D array carries a channel axis; checking ndim first avoids
    # mistaking a 2-D grayscale image that is 3 or 4 pixels wide for RGB(A).
    if img.ndim == 3:
        if img.shape[-1] == 4:  # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        if img.shape[-1] == 3:  # RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize image to the model's input resolution
    img = cv2.resize(img, _INPUT_SIZE)

    # Normalize pixel values to [0, 1]
    img = img / 255.0

    # Reshape to match model input: (batch, height, width, channels)
    img = img.reshape(1, _INPUT_SIZE[1], _INPUT_SIZE[0], 1)

    # Model predictions for the single image in the batch
    preds = np.asarray(model.predict(img)[0])
    logger.debug("preds=%s", preds)

    # Rank class *indices* by score. Ranking by index (rather than looking the
    # index up from the score value) stays correct when several classes share
    # the same score; the stable sort breaks ties in favour of the lower digit.
    top_indices = np.argsort(-preds, kind="stable")[:_TOP_K]

    # Built-in floats keep the result serializable for the Label output.
    return {labels[int(idx)]: float(preds[idx]) for idx in top_indices}


# Top classes shown to the user
label = gr.Label(num_top_classes=_TOP_K)

# Open Gradio interface for sketch recognition
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(type='numpy', image_mode='L', brush=gr.Brush()),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Launched at module level on purpose: the script's import-time behaviour is
# kept unchanged for whatever runner starts this file.
interface.launch(share=True)