"""Predict a class index from a parquet file of hand landmarks.

The script reads the columns listed in ``inference_args.json``, draws the
right-hand landmarks of the middle frame onto a blank canvas, downsizes that
canvas and feeds it to the TFLite model stored in ``asl_model.tflite``.

Importing this module reads ``inference_args.json`` and loads the model, so
both files must be present in the working directory.
"""

import json
import sys

import matplotlib.pyplot as plt
import mediapipe as mp
import numpy as np
import pandas as pd
import tensorflow as tf
from mediapipe.framework.formats import landmark_pb2
from PIL import Image
from skimage.transform import resize

# Load selected columns for inference
with open("inference_args.json", "r", encoding="utf-8") as f:
    SEL_COLS = json.load(f)["selected_columns"]

# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Drawing utilities
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands

# Canvas the landmarks are drawn on, and the size it is reduced to before
# being handed to the model.
_CANVAS_SHAPE = (600, 600, 3)
_MODEL_IMAGE_SIZE = (64, 64)


def load_relevant_data_subset(pq_path):
    """Read only the ``SEL_COLS`` columns of the parquet file at *pq_path*.

    Args:
        pq_path: Path of the parquet file to read.

    Returns:
        A ``pandas.DataFrame`` restricted to the selected columns.
    """
    return pd.read_parquet(pq_path, columns=SEL_COLS)


def _render_right_hand(frame):
    """Draw the right-hand landmarks of a single frame on a blank canvas.

    Args:
        frame: One row of the landmark frame (a ``pandas.Series`` indexed by
            column name).

    Returns:
        A new ``(600, 600, 3)`` array holding the drawing.
    """
    # Series.filter matches the regex against the index labels, i.e. the
    # column names of the originating DataFrame.
    xs = frame.filter(regex="x_right_hand.*").values
    ys = frame.filter(regex="y_right_hand.*").values
    zs = frame.filter(regex="z_right_hand.*").values

    landmarks = landmark_pb2.NormalizedLandmarkList()
    for x, y, z in zip(xs, ys, zs):
        landmarks.landmark.add(x=x, y=y, z=z)

    canvas = np.zeros(_CANVAS_SHAPE)
    # draw_landmarks paints onto ``canvas`` in place.
    mp_drawing.draw_landmarks(
        canvas,
        landmarks,
        mp_hands.HAND_CONNECTIONS,
        landmark_drawing_spec=(
            mp_drawing_styles.get_default_hand_landmarks_style()
        ),
    )
    return canvas


def draw_hand_landmarks(seq_df):
    """Render the right-hand landmarks of every frame in *seq_df*.

    Args:
        seq_df: DataFrame with one row per frame and ``x_right_hand*``,
            ``y_right_hand*`` and ``z_right_hand*`` columns.

    Returns:
        A list with one ``(600, 600, 3)`` image per row, in row order. The
        list is empty when *seq_df* has no rows.
    """
    return [_render_right_hand(frame) for _, frame in seq_df.iterrows()]


def preprocess_image(image):
    """Turn a rendered canvas into a model input batch.

    Args:
        image: Image array as produced by ``draw_hand_landmarks``.

    Returns:
        A ``float32`` array of shape ``(1, 64, 64, channels)``: the image
        resized to 64x64, divided by 255 and given a leading batch axis.
    """
    # preserve_range keeps the original value scale so that the explicit
    # division below is the only rescaling applied.
    resized = resize(image, _MODEL_IMAGE_SIZE, preserve_range=True)
    scaled = resized.astype(np.float32) / 255.0
    return np.expand_dims(scaled, axis=0)


def predict_from_parquet(parquet_path):
    """Predict the class index for the landmark sequence in a parquet file.

    Only the middle frame of the sequence is rendered and classified.

    Args:
        parquet_path: Path of the parquet file holding the landmark frames.

    Returns:
        The index of the largest value in the model output.

    Raises:
        ValueError: If the parquet file contains no frames.
    """
    df = load_relevant_data_subset(parquet_path)
    num_frames = len(df)
    if num_frames == 0:
        raise ValueError("No hand image generated.")

    # Render just the frame that is used instead of the whole sequence; each
    # canvas is a 600x600x3 float array, so drawing every frame is costly.
    # NOTE(review): the middle frame is used even if it holds no usable
    # right-hand landmarks - confirm that is acceptable for the data.
    middle_frame = df.iloc[num_frames // 2]
    img = preprocess_image(_render_right_hand(middle_frame))

    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    prediction = np.argmax(output)
    return prediction


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Exit non-zero so that callers can detect the misuse.
        sys.exit("Usage: python tflite_inference.py <parquet_file>")
    parquet_file = sys.argv[1]
    pred = predict_from_parquet(parquet_file)
    print("Predicted class index:", pred)