Spaces:
Sleeping
Sleeping
File size: 44,603 Bytes
99a05f0 |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd \n",
"import ipdb\n",
"import os\n",
"import pickle as pkl\n",
"import os.path as osp\n",
"import numpy as np\n",
"from PIL import Image\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total images found: 9642\n",
"Images after KP filtering: 3895\n",
"['hake_train2015_HICO_train2015_00005476.jpg', 'hake_train2015_HICO_train2015_00008329.jpg', 'hake_train2015_HICO_train2015_00008027.jpg', 'hake_train2015_HICO_train2015_00013408.jpg', 'hake_train2015_HICO_train2015_00010656.jpg']\n"
]
}
],
"source": [
"# Load the per-image keypoint counts from Agniv's ViTPose-base HICO filtering.\n",
"# NOTE: allow_pickle=True unpickles arbitrary objects; only load this file from a trusted source.\n",
"filter_path = './agniv_pose_filter/hico.npy'\n",
"pose_md = np.load(filter_path, allow_pickle=True).item()\n",
"\n",
"print(f'Total images found: {len(pose_md)}')\n",
"\n",
"# Keep only images with strictly more than kp_thresh visible keypoints\n",
"kp_thresh = 10\n",
"filter_img_names = {\n",
"    imgname: pose_num\n",
"    for imgname, pose_num in pose_md.items()\n",
"    if pose_num > kp_thresh\n",
"}\n",
"\n",
"print(f'Images after KP filtering: {len(filter_img_names)}')\n",
"\n",
"# Peek at the first few retained image names\n",
"print(list(filter_img_names)[:5])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# # Load Agniv VITpose-base hot dict\n",
"# filter_path = \"./agniv_pose_filter/hot_dict.pkl\"\n",
"# with open(filter_path, 'rb') as f:\n",
"# pose_md_dict = pkl.load(f)\n",
" \n",
"# hico_dict = {}\n",
"\n",
"# for k, v in pose_md_dict.items():\n",
"# if 'hake' in k:\n",
"# hico_dict[k] = v\n",
" \n",
"# print(f'Total images found: {len(hico_dict)}')\n",
"\n",
"# # Keep only single-person images with more than kp_thresh (10) visible keypoints\n",
"# kp_thresh = 10\n",
"\n",
"# filter_img_names = {}\n",
"\n",
"# for imgname, kp_md in hico_dict.items():\n",
"# if kp_md == 0:\n",
"# continue\n",
"# if kp_md[\"num_persons\"] == 1 and kp_md[\"num_kpt\"][0.5][0] > kp_thresh:\n",
"# filter_img_names[imgname] = kp_md[\"num_kpt\"][0.5][0]\n",
" \n",
"# print(f'Images after KP filtering: {len(filter_img_names)}')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 1 2\n",
"0 1 airplane board\n",
"1 2 airplane direct\n",
"2 3 airplane exit\n",
"3 4 airplane fly\n",
"4 5 airplane inspect\n",
".. ... ... ...\n",
"595 596 zebra feed\n",
"596 597 zebra hold\n",
"597 598 zebra pet\n",
"598 599 zebra watch\n",
"599 600 zebra no_interaction\n",
"\n",
"[600 rows x 3 columns]\n"
]
}
],
"source": [
"# Locations of the HICO image-level annotations and of the HOI list table\n",
"hico_annot_path = '/ps/project/datasets/HICO/hico-image-level/hico-training-set-image-level.json'\n",
"hoi_list_path = '/ps/project/datasets/HICO/hico-image-level/hico_hoi_list.txt'\n",
"\n",
"# Use a context manager so the annotation file handle is closed once parsing is done\n",
"with open(hico_annot_path, 'rb') as f:\n",
"    hico_annot = json.load(f)\n",
"\n",
"# Whitespace-separated table without a header row: column 0 = HOI id, 1 = object, 2 = action.\n",
"# sep=r'\\s+' is the documented equivalent of delim_whitespace=True, which pandas has deprecated.\n",
"hoi_mapping = pd.read_csv(hoi_list_path, header=None, sep=r'\\s+')\n",
"print(hoi_mapping)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final number of images 3154\n"
]
}
],
"source": [
"version = '1'\n",
"out_dir = f'./filtered_data/v_{version}'\n",
"os.makedirs(out_dir, exist_ok=True)\n",
"\n",
"img_dir = '/ps/project/datasets/HICO/hico_20150920/images/train2015'\n",
"\n",
"# Debug switch: set to True to print and display every kept image with more than one object\n",
"show_multi_object_images = False\n",
"\n",
"# HOIs whose object or action appears in these lists are ignored\n",
"bad_object_names = ['bear', 'bird', 'cat', 'cow',\n",
"                    'dog', 'elephant', 'giraffe', 'horse',\n",
"                    'mouse', 'person', 'sheep', 'zebra']\n",
"bad_action_names = ['buy', 'chase', 'direct', 'greet', 'herd', 'hose',\n",
"                    'hug', 'hunt', 'milk', 'no_interaction', 'pet', 'point', 'teach',\n",
"                    'watch', 'wave']\n",
"\n",
"# Build the HOI id -> (object, action) lookup once instead of scanning hoi_mapping for every HOI.\n",
"# setdefault keeps the first row seen for an id, matching the previous `.iloc[0]` selection.\n",
"hoi_lookup = {}\n",
"for hoi_id, obj_name, action_name in zip(hoi_mapping[0], hoi_mapping[1], hoi_mapping[2]):\n",
"    hoi_lookup.setdefault(hoi_id, (obj_name, action_name))\n",
"\n",
"objectwise_img_names = {}  # object name -> list of full image names\n",
"imgwise_object_names = {}  # full image name -> list of object names\n",
"\n",
"for img_name, img_md in hico_annot.items():\n",
"\n",
"    # Apply keypoint number filtering on the images\n",
"    full_img_name = 'hake_train2015_' + img_name\n",
"    if full_img_name not in filter_img_names:\n",
"        continue\n",
"    kp_num = filter_img_names[full_img_name]\n",
"\n",
"    # Collect the object and action of every HOI that survives both exclusion lists\n",
"    obj_names = []\n",
"    action_names = []\n",
"    for hoi_id in img_md['hoi_id']:\n",
"        obj_name, action_name = hoi_lookup[hoi_id]\n",
"        if obj_name in bad_object_names or action_name in bad_action_names:\n",
"            continue\n",
"        obj_names.append(obj_name)\n",
"        action_names.append(action_name)\n",
"\n",
"    # Both lists grow together, so a single emptiness check covers them\n",
"    if not obj_names:\n",
"        continue\n",
"\n",
"    # Sort the unique objects: plain set iteration order depends on string hashing and can\n",
"    # differ between kernel sessions, which made the saved dictionaries non-reproducible.\n",
"    unique_obj_names = sorted(set(obj_names))\n",
"    imgwise_object_names.setdefault(full_img_name, []).extend(unique_obj_names)\n",
"\n",
"    # Optional inspection of images that contain several object categories\n",
"    if show_multi_object_images and len(unique_obj_names) > 1:\n",
"        print(img_name)\n",
"        print(obj_names)\n",
"        print(action_names)\n",
"        print(f'Number of Kps: {kp_num}')\n",
"        display(Image.open(osp.join(img_dir, img_name)))\n",
"\n",
"    for obj_name in unique_obj_names:\n",
"        objectwise_img_names.setdefault(obj_name, []).append(full_img_name)\n",
"\n",
"print(f'Final number of images {len(imgwise_object_names)}')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saved at ./filtered_data/v_1/object_per_image_dict.json\n",
"saved at ./filtered_data/v_1/imgnames_per_object_dict.json\n"
]
}
],
"source": [
"# Write both lookup dictionaries into out_dir as JSON files\n",
"json_outputs = [\n",
"    ('object_per_image_dict.json', imgwise_object_names),    # image name -> object names\n",
"    ('imgnames_per_object_dict.json', objectwise_img_names),  # object name -> image names\n",
"]\n",
"for out_name, out_dict in json_outputs:\n",
"    out_path = osp.join(out_dir, out_name)\n",
"    with open(out_path, 'w') as fp:\n",
"        json.dump(out_dict, fp)\n",
"        print(f'saved at {out_path}')\n",
"\n",
"# # save image_list \n",
"# out_path = osp.join(out_dir, 'hico_imglist_all_140223.txt')\n",
"# with open(out_path, 'w') as f:\n",
"#     f.write('\\n'.join(imgwise_object_names.keys()))\n",
"#     print(f'saved at {out_path}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/stripathi/anaconda3/envs/cliff/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3189\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"# Order the object categories from most to fewest images\n",
"objectwise_img_names = dict(\n",
"    sorted(objectwise_img_names.items(), key=lambda item: len(item[1]), reverse=True)\n",
")\n",
"\n",
"# Parallel lists: category names and the number of images in each category\n",
"obj_names = list(objectwise_img_names)\n",
"img_counts = [len(img_list) for img_list in objectwise_img_names.values()]\n",
"print(sum(img_counts))\n",
"\n",
"# Draw the bar chart on an explicit Axes\n",
"fig, ax = plt.subplots()\n",
"sns.barplot(x=obj_names, y=img_counts, ax=ax)\n",
"ax.set_xlabel('Object')\n",
"ax.set_ylabel('Number of Images')\n",
"\n",
"# Rotate, right-align and shrink the x tick labels\n",
"plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=3)\n",
"\n",
"# Store a high-resolution copy of the figure next to the JSON outputs\n",
"out_path = osp.join(out_dir, 'image_per_object_category.png')\n",
"fig.savefig(out_path, dpi=300)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|