|
import os |
|
import subprocess |
|
import time |
|
|
|
|
|
|
|
|
|
|
|
# Gradio is required to build the web UI; fail fast with a clear message if it
# is not installed.
try:
    import gradio as gr
except ImportError:
    print("ERROR: Gradio not installed. Ensure it's in requirements.txt")
    # Plain `exit()` is a site-module convenience (absent under `python -S`),
    # and with no argument it reports success (status 0). Exit non-zero so
    # whatever launched the script can see the failure.
    raise SystemExit(1)
|
|
|
|
|
|
|
# Fetch the upstream repository, which provides the `_scripts.pose_estimator`
# module that `inference` runs. Skip the clone when the directory already
# exists, or when the process was started from inside it (otherwise a nested
# copy would be cloned).
repo_dir = "bizarre-pose-estimator"
repo_url = "https://github.com/ShuhongChen/bizarre-pose-estimator.git"

if os.path.basename(os.getcwd()) == repo_dir or os.path.exists(repo_dir):
    print(f"Repository '{repo_dir}' already exists.")
else:
    print(f"Cloning repository '{repo_dir}'...")
    # Argument list (no shell) with an explicit destination, so the clone
    # always lands in `repo_dir`, the directory the next step changes into.
    clone_command = ["git", "clone", repo_url, repo_dir]
    try:
        subprocess.run(clone_command, check=True, capture_output=True, text=True)
        print("Repository cloned successfully.")
    except FileNotFoundError:
        # Raised when the `git` executable itself is not on PATH.
        print("ERROR: Failed to clone repository: 'git' command not found.")
        raise SystemExit(1)
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Failed to clone repository: {e.stderr}")
        # Non-zero status: the app cannot work without the repository.
        raise SystemExit(1)
|
|
|
|
|
# Work from inside the repository: the rest of this script (checkpoint path,
# `_scripts` module, example image, model zip) uses paths relative to it.
try:
    if os.path.basename(os.getcwd()) != repo_dir:
        os.chdir(repo_dir)
        print(f"Changed directory to: {os.getcwd()}")
    else:
        print(f"Already in directory: {os.getcwd()}")
except OSError as e:
    # Covers FileNotFoundError as well as NotADirectoryError / PermissionError,
    # which the narrower handler let escape as a traceback.
    print(f"ERROR: Failed to change directory to '{repo_dir}'.")
    print(f"  Reason: {e}")
    raise SystemExit(1)
|
|
|
|
|
# Optional example image for the Gradio UI. This is best-effort: on failure
# the app simply starts without an example.
example_img_url = "https://i.imgur.com/IkJzlaE.jpeg"
example_img_file = "IkJzlaE.jpeg"

if not os.path.exists(example_img_file):
    print(f"Downloading example image: {example_img_file}")
    try:
        # -f: fail on HTTP errors instead of saving the error page;
        # -L: follow redirects. The timeout keeps a stalled transfer from
        # hanging app startup.
        subprocess.run(
            ["curl", "-fL", example_img_url, "-o", example_img_file],
            check=True,
            timeout=60,
        )
        print("Example image downloaded.")
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError) as e:
        print(f"Warning: Failed to download example image {example_img_url}. Error: {e}")
        # An interrupted transfer can leave a truncated file behind. Remove it
        # so it is neither offered as an example nor mistaken for a complete
        # download on the next start (the existence check above would skip it).
        try:
            os.remove(example_img_file)
        except OSError:
            pass  # nothing was written, or it cannot be removed; still best-effort
|
|
|
|
|
# Make sure the pretrained checkpoint exists. The weights are a zip on Google
# Drive. The archive unpacks into `bizarre_pose_models/`, whose contents are
# then merged into the repository root, which is where the checkpoint path
# below is expected to appear.
model_zip_file = "bizarre_pose_models.zip"
model_extract_dir = "bizarre_pose_models"
model_id = "17N5PutpYJTlKuNB6bdDaiQsPSIkYtiPm"  # Google Drive file ID
download_successful = False
models_ready = False
model_checkpoint_rel_path = "_train/character_pose_estim/runs/feat_concat+data.ckpt"

if os.path.exists(model_checkpoint_rel_path):
    print("Model checkpoint already exists. Skipping download and unzip.")
    models_ready = True
else:
    print(f"Model checkpoint not found at '{model_checkpoint_rel_path}'. Checking for zip...")

    # --- Stage 1: obtain the zip ------------------------------------------
    if os.path.exists(model_zip_file):
        print(f"Zip file '{model_zip_file}' already exists. Skipping download.")
        download_successful = True
    else:
        print(f"Attempting to download model weights using gdown (ID: {model_id})...")
        # Argument list instead of a shell string, so a missing `gdown` raises
        # FileNotFoundError (handled below) instead of an opaque shell exit 127.
        # A full Drive URL is passed instead of the `--id` flag, which newer
        # gdown releases deprecate.
        gdown_command = ["gdown", f"https://drive.google.com/uc?id={model_id}", "-O", model_zip_file]
        print(f"Executing: {' '.join(gdown_command)}")
        try:
            subprocess.run(gdown_command, check=True, capture_output=True, text=True, timeout=300)
            print("Gdown download successful.")
            download_successful = True
        except FileNotFoundError:
            print("ERROR: 'gdown' command not found.")
        except subprocess.TimeoutExpired:
            print("ERROR: gdown download timed out.")
        except subprocess.CalledProcessError as e:
            print(f"ERROR: gdown download failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")

        if not download_successful and os.path.exists(model_zip_file):
            # A partial zip would make every later start skip the download and
            # then fail to unzip, so discard it.
            try:
                os.remove(model_zip_file)
                print(f"Removed incomplete download: {model_zip_file}")
            except OSError as e:
                print(f"Warning: Error removing incomplete zip file {model_zip_file}: {e}")

    # --- Stage 2: unzip and install ---------------------------------------
    if download_successful and os.path.exists(model_zip_file):
        print(f"Unzipping '{model_zip_file}'...")
        if os.path.exists(model_extract_dir):
            print(f"Note: '{model_extract_dir}' already exists. Overwriting.")
        try:
            # -o: overwrite without prompting; -q: quiet.
            subprocess.run(["unzip", "-oq", model_zip_file], check=True)
        except FileNotFoundError:
            print("ERROR: 'unzip' command not available.")
        except subprocess.CalledProcessError as e:
            print(f"ERROR: Unzip failed: {e}")
        else:
            print("Unzip successful.")
            print("Copying model files...")
            source_dir_unzip = f"./{model_extract_dir}/"
            dest_dir = "."
            if not os.path.exists(source_dir_unzip):
                print(f"ERROR: Directory '{source_dir_unzip}' not found after unzip.")
            else:
                try:
                    # `<dir>/.` copies the directory's *contents* (dotfiles
                    # included) into dest; -a preserves metadata and symlinks.
                    # check=True surfaces failures that os.system used to ignore.
                    subprocess.run(["cp", "-a", f"{source_dir_unzip}.", dest_dir], check=True)
                except (FileNotFoundError, subprocess.CalledProcessError) as e:
                    print(f"ERROR: Copying model files failed: {e}")
                else:
                    print("Model files copied.")
                    # Only report success if the checkpoint is where inference
                    # expects it.
                    models_ready = os.path.exists(model_checkpoint_rel_path)
                    if not models_ready:
                        print(f"ERROR: Checkpoint still missing at '{model_checkpoint_rel_path}' after copy.")

                    print(f"Attempting to remove zip file: {model_zip_file}")
                    try:
                        os.remove(model_zip_file)
                        print(f"Removed zip file: {model_zip_file}")
                    except OSError as e:
                        print(f"Warning: Error removing zip file {model_zip_file}: {e}")
    elif not download_successful:
        print("Download failed previously, cannot unzip.")
    else:
        print(f"Zip file '{model_zip_file}' not found, cannot unzip.")
|
|
|
|
|
# Diagnostic: show what the setup steps left in the working directory.
# Flush first. print() output is block-buffered when stdout is a pipe (e.g. a
# hosted log), while `ls` writes straight to the inherited descriptor. Without
# the flush, the heading (and earlier messages) can appear after the listing.
print("\nCurrent directory contents:", flush=True)
os.system("ls -la")

# Loud warning only; whether to launch is decided when the UI is built.
if not models_ready:
    print("\n\nERROR: MODEL FILES ARE NOT SET UP CORRECTLY. APP MAY NOT WORK.\n\n")
|
|
|
|
|
def inference(img_input):
    """Run the repository's pose-estimation script on a single image.

    Args:
        img_input: Image from the Gradio input component. The interface built
            in this file uses ``type="filepath"``, so this is normally a path
            string. A NumPy array or PIL image is also accepted and is saved to
            a temporary PNG first. ``None`` (nothing submitted) is rejected.

    Returns:
        ``"./_samples/character_pose_estim.png"`` if the script succeeds and
        that file exists, otherwise ``None``. Failures are reported with
        ``print`` rather than raised.
    """
    import shlex
    import sys
    import tempfile

    if img_input is None:
        print("ERROR: No input image provided.")
        return None

    output_sample_path = "./_samples/character_pose_estim.png"
    model_checkpoint = "./_train/character_pose_estim/runs/feat_concat+data.ckpt"
    # Set only when this call creates a file that it must delete afterwards.
    temp_img_path = None

    try:
        if isinstance(img_input, str):
            img_path = img_input
        else:
            # Only needed for in-memory inputs, so imported lazily.
            import numpy as np
            from PIL import Image

            img_pil = Image.fromarray(img_input) if isinstance(img_input, np.ndarray) else img_input
            # Unique temp file instead of a fixed name in the working directory,
            # so concurrent calls cannot overwrite each other's input.
            fd, temp_img_path = tempfile.mkstemp(prefix="gradio_input_", suffix=".png")
            os.close(fd)
            img_pil.save(temp_img_path)
            img_path = temp_img_path
            print(f"Saved Gradio input to temporary file: {img_path}")

        print(f"Running inference on: {img_path}")

        if not os.path.exists(img_path):
            print(f"ERROR: Input path does not exist: {img_path}")
            return None
        if not os.path.exists(model_checkpoint):
            print(f"ERROR: Model checkpoint not found: {model_checkpoint}")
            return None

        # Argument list with shell=False. The image path comes from a user
        # upload and may contain quotes or shell metacharacters, which broke
        # (or could inject into) the previous single-quoted shell string.
        # sys.executable runs the script under this app's interpreter instead
        # of whatever `python3` happens to be on PATH.
        command = [sys.executable, "-m", "_scripts.pose_estimator", img_path, model_checkpoint]
        print(f"Executing command: {' '.join(shlex.quote(arg) for arg in command)}")
        try:
            result = subprocess.run(command, check=True, capture_output=True, text=True, timeout=120)
        except subprocess.TimeoutExpired:
            print("ERROR: Inference script timed out.")
            return None
        except subprocess.CalledProcessError as e:
            print(f"ERROR: Inference script failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")
            return None

        print("Inference stdout:", result.stdout)
        print("Inference stderr:", result.stderr)
        print("Inference completed.")

        if not os.path.exists(output_sample_path):
            print(f"ERROR: Output file not found: {output_sample_path}")
            return None
        return output_sample_path
    finally:
        # Runs on every exit path, including the early returns above, so a
        # temp file created by this call is never leaked. Caller-owned paths
        # are left alone.
        if temp_img_path is not None and os.path.exists(temp_img_path):
            try:
                os.remove(temp_img_path)
                print(f"Removed temporary file: {temp_img_path}")
            except OSError as e:
                print(f"Warning: Failed to remove temp file {temp_img_path}: {e}")
|
|
|
|
|
# Static text shown around the Gradio interface.
title = "bizarre-pose-estimator"

description = (
    "Gradio demo for Transfer Learning for Pose Estimation of Illustrated Characters. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)

# Footer: links to the paper and to the upstream repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2108.01819' target='_blank'>"
    "Transfer Learning for Pose Estimation of Illustrated Characters</a>"
    " | "
    "<a href='https://github.com/ShuhongChen/bizarre-pose-estimator' target='_blank'>Github Repo</a>"
    "</p>"
)

# Offer the example image only if it is present on disk.
example_img_file = "IkJzlaE.jpeg"
examples_list = []
if os.path.exists(example_img_file):
    # One row per example, one value per input component (a single image here).
    examples_list.append([example_img_file])
|
|
|
print("Setting up Gradio Interface...")

if not models_ready:
    # Every request would fail without the checkpoint, so don't start the server.
    print("Gradio launch aborted because model files are not ready.")
else:
    # File paths flow in both directions: `inference` receives the uploaded
    # image's path and returns the path of the rendered result.
    interface_options = dict(
        fn=inference,
        inputs=gr.Image(type="filepath", label="Input"),
        outputs=gr.Image(type="filepath", label="Output"),
        title=title,
        description=description,
        article=article,
        # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
        # confirm against the Gradio version pinned for this app.
        allow_flagging="never",
        examples=examples_list,
    )
    ui = gr.Interface(**interface_options)

    print("Launching Gradio...")
    ui.launch()