rezaenayati committed on
Commit
a8538e2
·
verified ·
1 Parent(s): 7410ec2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -7
app.py CHANGED
@@ -8,6 +8,7 @@ model = tf.keras.models.load_model("cnn_model.h5")
8
 
9
  def predict(image_array):
10
  try:
 
11
  if isinstance(image_array, dict) and 'composite' in image_array:
12
  # Extract the image data from the 'composite' key
13
  image_array = image_array['composite']
@@ -15,22 +16,101 @@ def predict(image_array):
15
  if image_array is None or np.sum(image_array) == 0:
16
  return "Please draw a digit."
17
 
 
18
  image = Image.fromarray(image_array.astype("uint8"), mode="L")
19
- image = image.resize((28, 28))
20
- image = image.point(lambda x: 0 if x < 128 else 255, 'L')
21
- image = image.resize((28, 28))
 
 
 
 
 
 
 
22
  image_array = np.array(image).astype("float32") / 255.0
23
  image_array = image_array.reshape(1, 28, 28, 1)
24
- logits = model.predict(image_array)
 
 
25
  prediction = int(np.argmax(logits))
 
 
 
26
  confidence = float(tf.nn.softmax(logits)[0][prediction])
 
27
  return f"Digit: {prediction} (confidence: {confidence:.2%})"
28
  except Exception as err:
29
  return f"Runtime error: {str(err)}"
30
 
 
31
  gr.Interface(
32
  fn=predict,
33
- inputs=gr.Sketchpad(image_mode="L", canvas_size=(200, 200), type="numpy"),
 
 
 
 
 
34
  outputs="text",
35
- title="Digit Classifier"
36
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def predict(image_array):
10
  try:
11
+ # Handle dictionary input format from newer Gradio versions
12
  if isinstance(image_array, dict) and 'composite' in image_array:
13
  # Extract the image data from the 'composite' key
14
  image_array = image_array['composite']
 
16
  if image_array is None or np.sum(image_array) == 0:
17
  return "Please draw a digit."
18
 
19
+ # Create PIL image from array
20
  image = Image.fromarray(image_array.astype("uint8"), mode="L")
21
+
22
+ # EMNIST digits dataset might need specific preprocessing
23
+ # Resize to 28x28 first
24
+ image = image.resize((28, 28), Image.LANCZOS)
25
+
26
+ # EMNIST may have different orientation than your drawing
27
+ # Try rotating/flipping if needed
28
+ image = ImageOps.invert(image) # Invert colors
29
+
30
+ # Convert to model input format - EXACTLY as in your training code
31
  image_array = np.array(image).astype("float32") / 255.0
32
  image_array = image_array.reshape(1, 28, 28, 1)
33
+
34
+ # Make prediction
35
+ logits = model.predict(image_array, verbose=0)
36
  prediction = int(np.argmax(logits))
37
+
38
+ # Since your model uses linear activation and SparseCategoricalCrossentropy(from_logits=True),
39
+ # we need to apply softmax to get probabilities
40
  confidence = float(tf.nn.softmax(logits)[0][prediction])
41
+
42
  return f"Digit: {prediction} (confidence: {confidence:.2%})"
43
  except Exception as err:
44
  return f"Runtime error: {str(err)}"
45
 
46
+ # Create the interface with settings optimized for EMNIST
47
  gr.Interface(
48
  fn=predict,
49
+ inputs=gr.Sketchpad(
50
+ image_mode="L",
51
+ canvas_size=(280, 280), # Larger canvas
52
+ type="numpy",
53
+ brush_radius=10, # Thicker brush
54
+ ),
55
  outputs="text",
56
+ title="EMNIST Digit Classifier",
57
+ description="Draw a digit (0-9) in the center of the canvas. For best results with EMNIST, make your digit large and clear."
58
+ ).launch()
59
+ This updated version should better match how your model was trained with the EMNIST dataset. EMNIST digits might have different characteristics than standard MNIST, which could affect how preprocessing should be done.
60
+ If you're still having accuracy issues, let's try to add a debug output that shows the processed image before it goes to the model. This will help us understand if the preprocessing is correct:
61
+ import gradio as gr
62
+ import numpy as np
63
+ from PIL import Image, ImageOps
64
+ import tensorflow as tf
65
+
66
+ # Load the model
67
+ model = tf.keras.models.load_model("digit_model(CNN).h5")
68
+
69
+ def predict(image_array):
70
+ try:
71
+ # Handle dictionary input format from newer Gradio versions
72
+ if isinstance(image_array, dict) and 'composite' in image_array:
73
+ # Extract the image data from the 'composite' key
74
+ image_array = image_array['composite']
75
+
76
+ if image_array is None or np.sum(image_array) == 0:
77
+ return "Please draw a digit.", None
78
+
79
+ # Process the image
80
+ original = Image.fromarray(image_array.astype("uint8"), mode="L")
81
+ processed = original.resize((28, 28), Image.LANCZOS)
82
+ processed = ImageOps.invert(processed)
83
+
84
+ # Create a copy for display (enlarged for visibility)
85
+ display_img = processed.resize((140, 140), Image.NEAREST)
86
+
87
+ # Convert to model input format
88
+ input_array = np.array(processed).astype("float32") / 255.0
89
+ input_array = input_array.reshape(1, 28, 28, 1)
90
+
91
+ # Make prediction
92
+ logits = model.predict(input_array, verbose=0)
93
+ prediction = int(np.argmax(logits))
94
+ confidence = float(tf.nn.softmax(logits)[0][prediction])
95
+
96
+ return f"Digit: {prediction} (confidence: {confidence:.2%})", display_img
97
+ except Exception as err:
98
+ return f"Runtime error: {str(err)}", None
99
+
100
+ # Create interface with image output for debugging
101
+ gr.Interface(
102
+ fn=predict,
103
+ inputs=gr.Sketchpad(
104
+ image_mode="L",
105
+ canvas_size=(280, 280),
106
+ type="numpy",
107
+ brush_radius=5,
108
+ ),
109
+ outputs=[
110
+ "text",
111
+ gr.Image(type="pil", label="Processed Image (28x28)")
112
+ ],
113
+ title="EMNIST Digit Classifier with Debug View",
114
+ description="Draw a digit (0-9). The right panel shows the actual 28x28 image being fed to the model."
115
+ ).launch()
116
+ This version adds a visual debug output so you can see exactly what the model is receiving. This should help us diagnose why the accuracy is low. If the processed image doesn't look like a clear digit, we'll need to adjust the preprocessing steps. [Retry] Claude can make mistakes. Please double-check responses. (3.7 Sonnet)