# ASL-TFLite-Edge / tflite_inference.py
import numpy as np
import pandas as pd
import json
import tensorflow as tf
import mediapipe as mp
from skimage.transform import resize
from mediapipe.framework.formats import landmark_pb2
# Load selected columns for inference
with open("inference_args.json", "r") as f:
    SEL_COLS = json.load(f)["selected_columns"]
# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
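# NOTE: the model is assumed to accept a single (1, 64, 64, 3) float32 batch,
# which is what preprocess_image() below produces; check input_details /
# output_details if the .tflite file is swapped out.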
# Drawing utilities
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
def load_relevant_data_subset(pq_path):
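    """Load only the selected landmark columns from a sequence parquet file."""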
    return pd.read_parquet(pq_path, columns=SEL_COLS)
def draw_hand_landmarks(seq_df):
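    """Render the right-hand landmarks of every frame in seq_df onto a blank
    600x600 canvas and return the rendered frames as a list of images."""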
    images = []
    for seq_idx in range(len(seq_df)):
        x_hand = seq_df.iloc[seq_idx].filter(regex="x_right_hand.*").values
        y_hand = seq_df.iloc[seq_idx].filter(regex="y_right_hand.*").values
        z_hand = seq_df.iloc[seq_idx].filter(regex="z_right_hand.*").values
        right_hand_image = np.zeros((600, 600, 3))
        right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
        for x, y, z in zip(x_hand, y_hand, z_hand):
            right_hand_landmarks.landmark.add(x=x, y=y, z=z)
        mp_drawing.draw_landmarks(
            right_hand_image,
            right_hand_landmarks,
            mp_hands.HAND_CONNECTIONS,
            landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
        )
        images.append(right_hand_image)
    return images
def preprocess_image(image):
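    """Resize a rendered frame to 64x64, scale pixel values to [0, 1], and
    add a leading batch dimension for the TFLite interpreter."""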
    img = resize(image, (64, 64), preserve_range=True).astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)
def predict_from_parquet(parquet_path):
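    """Draw landmarks for the sequence in parquet_path, run the TFLite model on
    the middle frame, and return the index of the highest-scoring class."""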
    df = load_relevant_data_subset(parquet_path)
    image_seq = draw_hand_landmarks(df)
    if not image_seq:
        raise ValueError("No hand image generated.")
    img = preprocess_image(image_seq[len(image_seq) // 2])
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    prediction = np.argmax(output)
    return prediction
if __name__ == "__main__":
    import sys
    if len(sys.argv) < 2:
        print("Usage: python tflite_inference.py <parquet_file_path>")
    else:
        parquet_file = sys.argv[1]
        pred = predict_from_parquet(parquet_file)
        print("Predicted class index:", pred)