Spaces:
Runtime error
Runtime error
# motion_extractor.py | |
import os | |
import sys | |
import cv2 | |
import torch | |
import pickle | |
import torchvision | |
# Load the TorchScript model once at import time so every call to
# extract_pkl_from_video reuses it. The path is resolved relative to the
# current working directory.
model_path = 'nlf_l_multi_0.3.2.torchscript'
if not os.path.exists(model_path):
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"Model file not found at {model_path}")
model = torch.jit.load(model_path).cuda().eval()
def _to_cpu(value):
    """Return *value* with every torch tensor (also inside lists/tuples) copied to CPU."""
    if isinstance(value, torch.Tensor):
        return value.cpu()
    if isinstance(value, list):
        return [_to_cpu(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_cpu(item) for item in value)
    return value


def extract_pkl_from_video(video_path, output_file="temp_motion.pkl"):
    """Run the module-level NLF model on every frame of a video and pickle the results.

    Frames are decoded with OpenCV, converted from BGR to RGB, and passed one
    at a time to ``model.detect_smpl_batched``. For each collected key the
    per-frame prediction is copied to CPU memory and appended to a list
    (``None`` is appended when the prediction lacks that key).

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture``.
        output_file: Where to write the pickle. Defaults to
            ``"temp_motion.pkl"`` in the current working directory; pass a
            unique path when several extractions may run concurrently.

    Returns:
        ``output_file``, the path of the written pickle. It contains a dict
        with ``video_path``, ``video_length`` (frame count from container
        metadata), ``video_width``, ``video_height`` and ``pose`` (the
        per-frame prediction lists).

    Raises:
        OSError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {video_path}")

    # Container metadata; the frame count may not match the number of
    # frames actually decoded below.
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    pose_results = {
        'joints3d_nonparam': [],
    }

    try:
        with torch.inference_mode(), torch.device('cuda'):
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes frames as BGR; the NLF model presumably
                # expects RGB (its reference usage loads images with
                # torchvision.io.read_image) -- convert before inference.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # (H, W, C) uint8 array -> (1, C, H, W) tensor on the GPU.
                frame_batch = (
                    torch.from_numpy(rgb_frame).cuda().permute(2, 0, 1).unsqueeze(0)
                )
                pred = model.detect_smpl_batched(frame_batch)
                for key, per_frame in pose_results.items():
                    # Copy to host memory so GPU usage does not grow with the
                    # video length and the pickle can be loaded without CUDA.
                    per_frame.append(_to_cpu(pred[key]) if key in pred else None)
    finally:
        cap.release()

    output_data = {
        'video_path': video_path,
        'video_length': frame_count,
        'video_width': video_width,
        'video_height': video_height,
        'pose': pose_results,
    }
    with open(output_file, 'wb') as f:
        pickle.dump(output_data, f)
    return output_file