rezaenayati commited on
Commit
2684828
·
verified ·
1 Parent(s): f08c441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -9,43 +9,44 @@ model = tf.keras.models.load_model("cnn_model.h5")
9
 
10
  def predict(image_array):
11
  try:
12
-
13
  if isinstance(image_array, dict) and 'composite' in image_array:
 
14
  image_array = image_array['composite']
15
 
16
  if image_array is None or np.sum(image_array) == 0:
17
  return "Please draw a digit."
18
 
19
-
20
  image = Image.fromarray(image_array.astype("uint8"), mode="L")
21
-
22
-
23
  image = image.resize((28, 28), Image.LANCZOS)
 
24
 
25
- image = ImageOps.invert(image)
26
-
27
  image_array = np.array(image).astype("float32") / 255.0
28
  image_array = image_array.reshape(1, 28, 28, 1)
29
 
30
  # Make prediction
31
  logits = model.predict(image_array, verbose=0)
32
  prediction = int(np.argmax(logits))
33
-
34
  confidence = float(tf.nn.softmax(logits)[0][prediction])
35
 
36
  return f"Digit: {prediction} (confidence: {confidence:.2%})"
37
  except Exception as err:
38
  return f"Runtime error: {str(err)}"
39
 
40
- gr.Interface(
 
41
  fn=predict,
42
  inputs=gr.Sketchpad(
43
  image_mode="L",
44
- canvas_size=(280, 280), # Larger canvas
45
  type="numpy",
46
- brush_radius=10, # Thicker brush
47
  ),
48
  outputs="text",
49
  title="EMNIST Digit Classifier",
50
- description="Draw a digit (0-9) in the center of the canvas. For best results with EMNIST, make your digit large and clear."
51
- ).launch()
 
 
 
9
 
10
  def predict(image_array):
11
  try:
12
+ # Handle dictionary input format from newer Gradio versions
13
  if isinstance(image_array, dict) and 'composite' in image_array:
14
+ # Extract the image data from the 'composite' key
15
  image_array = image_array['composite']
16
 
17
  if image_array is None or np.sum(image_array) == 0:
18
  return "Please draw a digit."
19
 
20
+ # Process the image for EMNIST format
21
  image = Image.fromarray(image_array.astype("uint8"), mode="L")
 
 
22
  image = image.resize((28, 28), Image.LANCZOS)
23
+ image = ImageOps.invert(image) # Invert colors
24
 
25
+ # Convert to model input format
 
26
  image_array = np.array(image).astype("float32") / 255.0
27
  image_array = image_array.reshape(1, 28, 28, 1)
28
 
29
  # Make prediction
30
  logits = model.predict(image_array, verbose=0)
31
  prediction = int(np.argmax(logits))
 
32
  confidence = float(tf.nn.softmax(logits)[0][prediction])
33
 
34
  return f"Digit: {prediction} (confidence: {confidence:.2%})"
35
  except Exception as err:
36
  return f"Runtime error: {str(err)}"
37
 
38
+ # Create the Gradio interface with appropriate settings
39
+ interface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Sketchpad(
42
  image_mode="L",
43
+ canvas_size=(280, 280),
44
  type="numpy",
45
+ brush=gr.Brush() # Using default brush settings
46
  ),
47
  outputs="text",
48
  title="EMNIST Digit Classifier",
49
+ description="Draw a digit (0-9) in the center of the canvas. Make it large and clear for best results."
50
+ )
51
+
52
+ interface.launch()