|
import numpy as np
|
|
import pandas as pd
|
|
import json
|
|
import tensorflow as tf
|
|
import mediapipe as mp
|
|
from skimage.transform import resize
|
|
import matplotlib.pyplot as plt
|
|
from mediapipe.framework.formats import landmark_pb2
|
|
from PIL import Image
|
|
|
|
|
|
# Read the inference configuration and keep only its "selected_columns"
# entry, which load_relevant_data_subset passes to pandas as the column list.
# JSON text is UTF-8, so the encoding is stated explicitly instead of
# depending on the platform's locale default.
with open("inference_args.json", "r", encoding="utf-8") as f:
    SEL_COLS = json.load(f)["selected_columns"]
|
|
|
|
|
|
# Create the TensorFlow Lite interpreter for the model file and allocate
# its tensors before asking for tensor metadata.
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()

# Tensor descriptors; predict_from_parquet reads the 'index' entry of the
# first input and of the first output.
output_details = interpreter.get_output_details()
input_details = interpreter.get_input_details()
|
|
|
|
|
|
# Short aliases for the MediaPipe solution modules used by draw_hand_landmarks.
mp_hands = mp.solutions.hands
mp_drawing_styles = mp.solutions.drawing_styles
mp_drawing = mp.solutions.drawing_utils
|
|
|
|
def load_relevant_data_subset(pq_path):
    """Read a parquet file, loading only the columns listed in ``SEL_COLS``.

    Args:
        pq_path: Location of the parquet file; passed unchanged to
            ``pandas.read_parquet``.

    Returns:
        The DataFrame produced by ``pandas.read_parquet``.
    """
    subset = pd.read_parquet(pq_path, columns=SEL_COLS)
    return subset
|
|
|
|
def draw_hand_landmarks(seq_df):
    """Render the right-hand landmarks of every row of *seq_df* as an image.

    Columns are picked by the regex patterns ``x_right_hand.*``,
    ``y_right_hand.*`` and ``z_right_hand.*``.  Within a row, the i-th match
    of each pattern is combined into one normalized landmark; if the three
    patterns match different numbers of columns, the extra ones are ignored
    (``zip`` stops at the shortest).

    Args:
        seq_df: DataFrame whose columns include the right-hand coordinate
            columns described above.  One image is produced per row.

    Returns:
        A list with one ``(600, 600, 3)`` NumPy array per row, in row order,
        each created with ``np.zeros`` and drawn on by MediaPipe's
        ``draw_landmarks``.  An empty DataFrame yields an empty list.
    """
    # The column selection does not depend on the row, so the three regex
    # filters run once here rather than once per row inside the loop.
    x_rows = seq_df.filter(regex="x_right_hand.*").to_numpy()
    y_rows = seq_df.filter(regex="y_right_hand.*").to_numpy()
    z_rows = seq_df.filter(regex="z_right_hand.*").to_numpy()

    images = []
    # Iterating the 2-D arrays yields one coordinate vector per DataFrame row.
    for x_hand, y_hand, z_hand in zip(x_rows, y_rows, z_rows):
        right_hand_image = np.zeros((600, 600, 3))
        right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()

        for x, y, z in zip(x_hand, y_hand, z_hand):
            right_hand_landmarks.landmark.add(x=x, y=y, z=z)

        # draw_landmarks paints directly onto right_hand_image.
        mp_drawing.draw_landmarks(
            right_hand_image,
            right_hand_landmarks,
            mp_hands.HAND_CONNECTIONS,
            landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style(),
        )
        images.append(right_hand_image)
    return images
|
|
|
|
def preprocess_image(image):
    """Turn one rendered image into a single-item batch for the model.

    The image is resized to 64x64 with ``preserve_range=True`` (so the
    resize step keeps the original value range), cast to ``float32``,
    divided by 255 and given a new leading axis.

    Args:
        image: Array-like image accepted by ``skimage.transform.resize``.

    Returns:
        A ``float32`` array equal to the resized, scaled image with an extra
        axis of length 1 in front.
    """
    resized = resize(image, (64, 64), preserve_range=True)
    scaled = resized.astype(np.float32) / 255.0
    # Prepend the batch axis.
    return scaled[np.newaxis, ...]
|
|
|
|
def predict_from_parquet(parquet_path):
    """Run the TFLite model on the landmark data stored in a parquet file.

    The file is loaded with ``load_relevant_data_subset``, every row is
    rendered by ``draw_hand_landmarks``, and the image at the middle index
    of that list is preprocessed and fed to the module-level interpreter.

    Args:
        parquet_path: Path of the parquet file to classify.

    Returns:
        ``np.argmax`` of the interpreter's first output tensor.

    Raises:
        ValueError: If no image was rendered from the file.
    """
    frames = draw_hand_landmarks(load_relevant_data_subset(parquet_path))
    if not frames:
        raise ValueError("No hand image generated.")

    # Only the image in the middle of the sequence is classified.
    middle = len(frames) // 2
    model_input = preprocess_image(frames[middle])

    input_index = input_details[0]['index']
    output_index = output_details[0]['index']

    interpreter.set_tensor(input_index, model_input)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_index)

    return np.argmax(scores)
|
|
|
|
if __name__ == "__main__":
    import sys

    # Command-line entry point: the first argument is the parquet file path.
    if len(sys.argv) < 2:
        # Passing a string to sys.exit prints it to stderr and exits with
        # status 1, so a missing argument is reported as a failure instead
        # of ending with a success status.
        sys.exit("Usage: python tflite_inference.py <parquet_file_path>")

    parquet_file = sys.argv[1]
    pred = predict_from_parquet(parquet_file)
    print("Predicted class index:", pred)
|
|
|