{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HICO image filtering by keypoint count and object/action type\n",
    "\n",
    "This notebook selects a subset of the HICO training images:\n",
    "\n",
    "1. keep images whose keypoint count in the precomputed pose filter file is greater than `kp_thresh`;\n",
    "2. keep images that have at least one HOI annotation whose object is not in `bad_object_names` and whose action is not in `bad_action_names`.\n",
    "\n",
    "It writes a per-image object dictionary, a per-object image-name dictionary, and a bar plot of images per object into `out_dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import os.path as osp\n",
    "import pickle as pkl\n",
    "\n",
    "import ipdb\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from PIL import Image\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keypoint-count filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total images found: 9642\n",
      "Images after KP filtering: 3895\n",
      "['hake_train2015_HICO_train2015_00005476.jpg', 'hake_train2015_HICO_train2015_00008329.jpg', 'hake_train2015_HICO_train2015_00008027.jpg', 'hake_train2015_HICO_train2015_00013408.jpg', 'hake_train2015_HICO_train2015_00010656.jpg']\n"
     ]
    }
   ],
   "source": [
    "# Load Agniv VITpose-base hico filtering\n",
    "filter_path = './agniv_pose_filter/hico.npy'\n",
    "# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load files from a trusted source.\n",
    "# The file stores a 0-d object array wrapping a dict, hence the .item() call.\n",
    "pose_md = np.load(filter_path, allow_pickle=True).item()\n",
    "\n",
    "print(f'Total images found: {len(pose_md)}')\n",
    "\n",
    "# Keep only images with MORE than kp_thresh visible keypoints\n",
    "# (images with exactly kp_thresh keypoints are dropped as well)\n",
    "kp_thresh = 10\n",
    "\n",
    "# image name -> keypoint count, in the same order as pose_md\n",
    "filter_img_names = {\n",
    "    imgname: pose_num\n",
    "    for imgname, pose_num in pose_md.items()\n",
    "    if pose_num > kp_thresh\n",
    "}\n",
    "\n",
    "print(f'Images after KP filtering: {len(filter_img_names)}')\n",
    "\n",
    "print(list(filter_img_names.keys())[:5])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative keypoint filter based on hot_dict.pkl, kept for reference.\n",
    "\n",
    "# # Load Agniv VITpose-base hot dict\n",
    "# filter_path = \"./agniv_pose_filter/hot_dict.pkl\"\n",
    "# with open(filter_path, 'rb') as f:\n",
    "#     pose_md_dict = pkl.load(f)\n",
    "\n",
    "# hico_dict = {}\n",
    "\n",
    "# for k, v in pose_md_dict.items():\n",
    "#     if 'hake' in k:\n",
    "#         hico_dict[k] = v\n",
    "\n",
    "# print(f'Total images found: {len(hico_dict)}')\n",
    "\n",
    "# # Filter out images with < 10 visible keypoints\n",
    "# kp_thresh = 10\n",
    "\n",
    "# filter_img_names = {}\n",
    "\n",
    "# for imgname, kp_md in hico_dict.items():\n",
    "#     if kp_md == 0:\n",
    "#         continue\n",
    "#     if kp_md[\"num_persons\"] == 1 and kp_md[\"num_kpt\"][0.5][0] > kp_thresh:\n",
    "#         filter_img_names[imgname] = kp_md[\"num_kpt\"][0.5][0]\n",
    "\n",
    "# print(f'Images after KP filtering: {len(filter_img_names)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load HICO annotations and the HOI list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       0         1               2\n",
      "0      1  airplane           board\n",
      "1      2  airplane          direct\n",
      "2      3  airplane            exit\n",
      "3      4  airplane             fly\n",
      "4      5  airplane         inspect\n",
      "..   ...       ...             ...\n",
      "595  596     zebra            feed\n",
      "596  597     zebra            hold\n",
      "597  598     zebra             pet\n",
      "598  599     zebra           watch\n",
      "599  600     zebra  no_interaction\n",
      "\n",
      "[600 rows x 3 columns]\n"
     ]
    }
   ],
   "source": [
    "hico_annot_path = '/ps/project/datasets/HICO/hico-image-level/hico-training-set-image-level.json'\n",
    "hoi_list_path = '/ps/project/datasets/HICO/hico-image-level/hico_hoi_list.txt'\n",
    "\n",
    "# Context manager so the annotation file handle is closed after reading\n",
    "with open(hico_annot_path, 'rb') as f:\n",
    "    hico_annot = json.load(f)\n",
    "\n",
    "# Columns: 0 = hoi id, 1 = object name, 2 = action name.\n",
    "# sep=r'\\s+' is the equivalent of the deprecated delim_whitespace=True.\n",
    "hoi_mapping = pd.read_csv(hoi_list_path, header=None, sep=r'\\s+')\n",
    "print(hoi_mapping)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter images by object and action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final number of images 3154\n"
     ]
    }
   ],
   "source": [
    "version = '1'\n",
    "out_dir = f'./filtered_data/v_{version}'\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "objectwise_img_names = {}\n",
    "imgwise_object_names = {}\n",
    "img_dir = '/ps/project/datasets/HICO/hico_20150920/images/train2015'\n",
    "\n",
    "bad_object_names = ['bear', 'bird', 'cat', 'cow',\n",
    "                    'dog', 'elephant', 'giraffe', 'horse',\n",
    "                    'mouse', 'person', 'sheep', 'zebra']\n",
    "bad_action_names = ['buy', 'chase', 'direct', 'greet', 'herd', 'hose',\n",
    "                    'hug', 'hunt', 'milk', 'no_interaction', 'pet', 'point', 'teach',\n",
    "                    'watch', 'wave']\n",
    "\n",
    "# Debug switches, both disabled by default.\n",
    "# Show the image for keypoint-filtered entries among the first `num_preview_images` annotations.\n",
    "num_preview_images = 0\n",
    "# Print details and show the image whenever more than one object survives for an image.\n",
    "show_multi_object_images = False\n",
    "\n",
    "# Build hoi_id -> (object name, action name) once instead of scanning the\n",
    "# DataFrame for every HOI of every image. setdefault keeps the first row\n",
    "# for an id, matching a first-match DataFrame lookup.\n",
    "hoi_lookup = {}\n",
    "for hoi_id, obj_name, action_name in zip(hoi_mapping[0], hoi_mapping[1], hoi_mapping[2]):\n",
    "    hoi_lookup.setdefault(hoi_id, (obj_name, action_name))\n",
    "\n",
    "for i, (img_name, img_md) in enumerate(hico_annot.items()):\n",
    "\n",
    "    # Apply keypoint number filtering on the images\n",
    "    full_img_name = 'hake_train2015_' + img_name\n",
    "    if full_img_name not in filter_img_names:\n",
    "        continue\n",
    "\n",
    "    # show the image (debug only)\n",
    "    if i < num_preview_images:\n",
    "        display(Image.open(osp.join(img_dir, img_name)))\n",
    "\n",
    "    obj_names = []\n",
    "    action_names = []\n",
    "    kp_num = filter_img_names[full_img_name]\n",
    "\n",
    "    # travel through all hoi in the metadata, save obj_names and action_names for the hois\n",
    "    for hoi_id in img_md['hoi_id']:\n",
    "        # raises KeyError if the annotation references an id missing from the HOI list\n",
    "        obj_name, action_name = hoi_lookup[hoi_id]\n",
    "        if obj_name in bad_object_names or action_name in bad_action_names:\n",
    "            continue\n",
    "\n",
    "        obj_names.append(obj_name)\n",
    "        action_names.append(action_name)\n",
    "\n",
    "    # obj_names and action_names grow together, so one emptiness check covers both\n",
    "    if not obj_names:\n",
    "        continue\n",
    "\n",
    "    # Sorted so the saved JSON (list order and key order) is identical across runs;\n",
    "    # plain set iteration order for strings changes between kernels.\n",
    "    unique_obj_names = sorted(set(obj_names))\n",
    "    imgwise_object_names[full_img_name] = list(unique_obj_names)\n",
    "\n",
    "    # Display images with multiple objects (debug only)\n",
    "    if show_multi_object_images and len(unique_obj_names) > 1:\n",
    "        print(img_name)\n",
    "        print(obj_names)\n",
    "        print(action_names)\n",
    "        print(f'Number of Kps: {kp_num}')\n",
    "        display(Image.open(osp.join(img_dir, img_name)))\n",
    "\n",
    "    for obj_name in unique_obj_names:\n",
    "        objectwise_img_names.setdefault(obj_name, []).append(full_img_name)\n",
    "\n",
    "print(f'Final number of images {len(imgwise_object_names)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the filtered dictionaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Save the imagewise_object dict\n",
    "out_path = osp.join(out_dir, 'object_per_image_dict.json')\n",
    "with open(out_path, 'w') as fp:\n",
    "    json.dump(imgwise_object_names, fp)\n",
    "# report only after the file has been closed\n",
    "print(f'saved at {out_path}')\n",
    "\n",
    "# # save image_list\n",
    "# out_path = osp.join(out_dir, 'hico_imglist_all_140223.txt')\n",
    "# with open(out_path, 'w') as f:\n",
    "#     f.write('\\n'.join(imgwise_object_names.keys()))\n",
    "#     print(f'saved at {out_path}')\n",
    "\n",
    "\n",
    "# Save the object_wise dict\n",
    "out_path = osp.join(out_dir, 'imgnames_per_object_dict.json')\n",
    "with open(out_path, 'w') as fp:\n",
    "    json.dump(objectwise_img_names, fp)\n",
    "print(f'saved at {out_path}')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Images per object category"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort objects by number of images, most frequent first.\n",
    "# A new name is used so that objectwise_img_names is not rebound by this cell.\n",
    "objectwise_sorted = dict(sorted(objectwise_img_names.items(), key=lambda x: len(x[1]), reverse=True))\n",
    "\n",
    "# Extract object names and image counts\n",
    "plot_obj_names = list(objectwise_sorted.keys())\n",
    "img_counts = [len(names) for names in objectwise_sorted.values()]\n",
    "# An image with several objects is counted once per object here\n",
    "print(sum(img_counts))\n",
    "\n",
    "# Create bar plot\n",
    "fig, ax = plt.subplots()\n",
    "sns.barplot(x=plot_obj_names, y=img_counts, ax=ax)\n",
    "\n",
    "# Add x-axis and y-axis labels\n",
    "ax.set_xlabel('Object')\n",
    "ax.set_ylabel('Number of Images')\n",
    "\n",
    "plt.xticks(rotation=45, ha='right', fontsize=3)\n",
    "\n",
    "# Save the plot as a high-resolution image file\n",
    "out_path = osp.join(out_dir, 'image_per_object_category.png')\n",
    "fig.savefig(out_path, dpi=300)\n",
    "\n",
    "# Show plot\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}