File size: 8,589 Bytes
9b671f0
34f6d55
 
 
 
 
cdc4f20
34f6d55
 
 
 
 
 
2280cb8
cdc4f20
8560a15
 
 
 
 
 
 
 
 
 
 
 
 
cdc4f20
8560a15
 
34f6d55
 
 
 
 
 
8560a15
 
34f6d55
3f7fde4
8560a15
 
 
 
 
 
 
 
 
 
3f7fde4
8560a15
 
34f6d55
 
8560a15
34f6d55
8560a15
34f6d55
8560a15
 
34f6d55
8560a15
34f6d55
8560a15
 
 
 
 
 
 
 
34f6d55
 
8560a15
 
34f6d55
 
8560a15
34f6d55
8560a15
 
 
 
 
 
34f6d55
8560a15
 
 
 
 
 
 
 
 
 
34f6d55
8560a15
 
 
34f6d55
 
8560a15
 
34f6d55
 
8560a15
 
 
 
 
 
 
 
 
34f6d55
8560a15
 
 
34f6d55
 
 
8560a15
 
 
 
 
 
34f6d55
 
8560a15
34f6d55
8560a15
 
34f6d55
 
 
 
 
 
8560a15
34f6d55
 
 
8be84fd
cdc4f20
f2d8ee6
 
4c10312
f18a136
8560a15
 
34f6d55
 
f18a136
34f6d55
 
f18a136
 
 
 
 
34f6d55
 
 
 
 
 
 
f18a136
34f6d55
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import subprocess
import time
# REMOVE pip installs from here - manage via requirements.txt
# os.system("pip install gradio==2.4.6") # REMOVE
# os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") # REMOVE

# Dependencies are managed via requirements.txt. Fail fast, with a non-zero
# exit status, if gradio is missing so the host reports a failed start.
try:
    import gradio as gr
except ImportError:
    print("ERROR: Gradio not installed. Ensure it's in requirements.txt")
    raise SystemExit(1)


# --- Repository Setup ---
repo_dir = "bizarre-pose-estimator"
if not os.path.exists(repo_dir):
    print(f"Cloning repository '{repo_dir}'...")
    # Use an argument list instead of a shell string. The explicit destination
    # guarantees the clone lands exactly where the existence check looks.
    clone_command = [
        "git",
        "clone",
        "https://github.com/ShuhongChen/bizarre-pose-estimator.git",
        repo_dir,
    ]
    try:
        subprocess.run(clone_command, check=True, capture_output=True, text=True)
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Failed to clone repository: {e.stderr}")
        # Non-zero status: a failed setup must not look like a clean exit.
        raise SystemExit(1)
    except FileNotFoundError:
        # Without a shell, a missing git binary raises instead of returning 127.
        print("ERROR: Failed to clone repository: 'git' command not found.")
        raise SystemExit(1)
else:
    print(f"Repository '{repo_dir}' already exists.")

# --- Change Directory ---
# Everything below (scripts, checkpoints, sample outputs) uses paths relative
# to the repository root, so the working directory must be the checkout.
try:
    # Check if already in the directory (important for restarts)
    if os.path.basename(os.getcwd()) != repo_dir:
        os.chdir(repo_dir)
        print(f"Changed directory to: {os.getcwd()}")
    else:
        print(f"Already in directory: {os.getcwd()}")
except OSError as e:
    # OSError also covers NotADirectoryError / PermissionError, not just a
    # missing directory.
    print(f"ERROR: Failed to change directory to '{repo_dir}'.")
    print(f"Reason: {e}")
    raise SystemExit(1)

# --- Download Example Image ---
example_img_url = "https://i.imgur.com/IkJzlaE.jpeg"
example_img_file = "IkJzlaE.jpeg"
if not os.path.exists(example_img_file):
    print(f"Downloading example image: {example_img_file}")
    # Download under a temporary name and rename only on success. A failed or
    # interrupted transfer then never leaves a truncated file that later
    # starts would treat as already downloaded (and skip).
    partial_img_file = f"{example_img_file}.part"
    try:
        # --max-time bounds the transfer so a stalled connection cannot block
        # startup indefinitely (curl then exits non-zero).
        subprocess.run(
            ["curl", "-fL", "--max-time", "60", example_img_url, "-o", partial_img_file],
            check=True,
        )
        os.replace(partial_img_file, example_img_file)
        print("Example image downloaded.")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError includes FileNotFoundError (curl not installed).
        print(f"Warning: Failed to download example image {example_img_url}. Error: {e}")
        try:
            os.remove(partial_img_file)
        except OSError:
            pass  # best-effort cleanup; the file may never have been created

# --- Model Download ---
model_zip_file = "bizarre_pose_models.zip"
model_extract_dir = "bizarre_pose_models"
model_id = "17N5PutpYJTlKuNB6bdDaiQsPSIkYtiPm"  # Google Drive file ID of the model archive
download_successful = False
models_ready = False
model_checkpoint_rel_path = "_train/character_pose_estim/runs/feat_concat+data.ckpt"


def _remove_model_zip():
    """Best-effort delete of ``model_zip_file``; a silent no-op if it is absent.

    Used after a successful copy, when the archive is no longer needed. Also
    used after a failed download or unzip: a truncated archive left on disk
    would make every later start skip the download and then fail to unzip.
    """
    try:
        os.remove(model_zip_file)
        print(f"Removed zip file: {model_zip_file}")
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"Warning: Error removing zip file {model_zip_file}: {e}")


if os.path.exists(model_checkpoint_rel_path):
    print("Model checkpoint already exists. Skipping download and unzip.")
    models_ready = True  # Models are already copied and ready
else:
    print(f"Model checkpoint not found at '{model_checkpoint_rel_path}'. Checking for zip...")
    if os.path.exists(model_zip_file):
        print(f"Zip file '{model_zip_file}' already exists. Skipping download.")
        download_successful = True
    else:
        print(f"Attempting to download model weights using gdown (ID: {model_id})...")
        # Pass the Drive URL rather than `--id <ID>`. The flag was deprecated in
        # gdown 4.3.1 (slated for removal in 5.0), while the URL form is
        # accepted by all versions. Use an argument list; no shell needed.
        gdown_command = ["gdown", f"https://drive.google.com/uc?id={model_id}", "-O", model_zip_file]
        print(f"Executing: {' '.join(gdown_command)}")
        try:
            # Timeout of 5 minutes for the download.
            subprocess.run(gdown_command, check=True, capture_output=True, text=True, timeout=300)
            print("Gdown download successful.")
            download_successful = True
        except FileNotFoundError:
            print("ERROR: 'gdown' command not found.")
        except subprocess.TimeoutExpired:
            print("ERROR: gdown download timed out.")
            _remove_model_zip()  # drop any partially written archive
        except subprocess.CalledProcessError as e:
            print(f"ERROR: gdown download failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")
            _remove_model_zip()  # drop any partially written archive

    # --- Unzip and Copy ---
    if download_successful and os.path.exists(model_zip_file):
        print(f"Unzipping '{model_zip_file}'...")
        unzip_ok = False
        try:
            if os.path.exists(model_extract_dir):
                print(f"Note: '{model_extract_dir}' already exists. Overwriting.")
            unzip_result = subprocess.run(["unzip", "-oq", model_zip_file])
            # Info-ZIP unzip exits 1 for warnings where extraction still
            # completed. Only higher codes are real failures.
            if unzip_result.returncode not in (0, 1):
                raise subprocess.CalledProcessError(unzip_result.returncode, unzip_result.args)
            print("Unzip successful.")
            unzip_ok = True
        except FileNotFoundError:
            print("ERROR: 'unzip' command not available.")
        except subprocess.CalledProcessError as e:
            print(f"ERROR: Unzip failed: {e}")
            # Most likely a truncated/corrupt archive. Delete it so the next
            # start downloads a fresh copy instead of failing here again.
            _remove_model_zip()

        if unzip_ok:
            print("Copying model files...")
            source_dir_unzip = f"./{model_extract_dir}/"
            dest_dir = "."
            if os.path.exists(source_dir_unzip):
                try:
                    # "<dir>/." makes cp merge the directory's contents into
                    # dest_dir. check=True so a failed copy is not reported as
                    # success.
                    subprocess.run(["cp", "-a", f"{source_dir_unzip}.", dest_dir], check=True)
                except (subprocess.CalledProcessError, FileNotFoundError) as e:
                    print(f"ERROR: Copying model files failed: {e}")
                else:
                    print("Model files copied.")
                    # Ready only if the checkpoint the app needs actually landed.
                    models_ready = os.path.exists(model_checkpoint_rel_path)
                    if not models_ready:
                        print(f"ERROR: Checkpoint '{model_checkpoint_rel_path}' still missing after copy.")
                    print(f"Attempting to remove zip file: {model_zip_file}")
                    _remove_model_zip()
            else:
                print(f"ERROR: Directory '{source_dir_unzip}' not found after unzip.")
    elif not download_successful:
        print("Download failed previously, cannot unzip.")
    else:
        print(f"Zip file '{model_zip_file}' not found, cannot unzip.")

# --- Final Check and LS ---
# Flush so the heading is written before the child process's output. When
# stdout is not a terminal (e.g. container logs), Python block-buffers it and
# the `ls` listing would otherwise appear before this line.
print("\nCurrent directory contents:", flush=True)
os.system("ls -la")
if not models_ready:
    print("\n\nERROR: MODEL FILES ARE NOT SET UP CORRECTLY. APP MAY NOT WORK.\n\n")

# --- Gradio Interface ---
def inference(img_input):
    """Run the pose-estimator script on one image and return the rendered output.

    Args:
        img_input: Value from the Gradio input component. With
            ``type="filepath"`` this is a path string. A numpy array or a PIL
            image is also accepted and saved to a temporary PNG first.
            ``None`` (submit pressed with no image) is rejected.

    Returns:
        The path of the output image written by ``_scripts.pose_estimator``,
        or ``None`` if the input or checkpoint is missing, the script fails or
        times out, or no output is produced. Details are printed to the log.
    """
    import shlex
    import tempfile

    if img_input is None:
        print("ERROR: No input image provided.")
        return None

    output_sample_path = "./_samples/character_pose_estim.png"
    model_checkpoint = "./_train/character_pose_estim/runs/feat_concat+data.ckpt"
    temp_img_name = None  # set only when we create a temp file that we must clean up

    # The whole body runs under try/finally so the temp file is removed on
    # every exit path, including the early returns below.
    try:
        if isinstance(img_input, str):
            img_path = img_input
        else:
            import numpy as np
            from PIL import Image

            # Anything that is not an ndarray is assumed to already be a PIL image.
            img_pil = Image.fromarray(img_input) if isinstance(img_input, np.ndarray) else img_input
            # A unique name keeps concurrent requests from overwriting each other's input.
            fd, temp_img_name = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            img_pil.save(temp_img_name)
            img_path = temp_img_name
            print(f"Saved Gradio input to temporary file: {img_path}")

        print(f"Running inference on: {img_path}")
        if not os.path.exists(img_path):
            print(f"ERROR: Input path does not exist: {img_path}")
            return None
        if not os.path.exists(model_checkpoint):
            print(f"ERROR: Model checkpoint not found: {model_checkpoint}")
            return None

        # Remove a previous request's result so that a run which exits 0 without
        # writing output cannot hand back someone else's image.
        # NOTE(review): the script writes to a fixed path, so truly concurrent
        # requests could still race on this file.
        try:
            os.remove(output_sample_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Warning: Failed to remove stale output {output_sample_path}: {e}")

        # Use an argument list with no shell. The path derives from an uploaded
        # file name and may contain quotes or shell metacharacters.
        command = ["python3", "-m", "_scripts.pose_estimator", img_path, model_checkpoint]
        print(f"Executing command: {shlex.join(command)}")
        try:
            result = subprocess.run(command, check=True, capture_output=True, text=True, timeout=120)
        except subprocess.TimeoutExpired:
            print("ERROR: Inference script timed out.")
            return None
        except subprocess.CalledProcessError as e:
            print(f"ERROR: Inference script failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")
            return None
        except OSError as e:
            # e.g. python3 not found on PATH (no shell to turn this into exit code 127).
            print(f"ERROR: Could not start inference script: {e}")
            return None

        print("Inference stdout:", result.stdout)
        print("Inference stderr:", result.stderr)
        print("Inference completed.")
        if os.path.exists(output_sample_path):
            return output_sample_path
        print(f"ERROR: Output file not found: {output_sample_path}")
        return None
    finally:
        if temp_img_name is not None and os.path.exists(temp_img_name):
            try:
                os.remove(temp_img_name)
                print(f"Removed temporary file: {temp_img_name}")
            except OSError as e:
                print(f"Warning: Failed to remove temp file {temp_img_name}: {e}")


title = "bizarre-pose-estimator"
description = "Gradio demo for Transfer Learning for Pose Estimation of Illustrated Characters. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.01819' target='_blank'>Transfer Learning for Pose Estimation of Illustrated Characters</a> | <a href='https://github.com/ShuhongChen/bizarre-pose-estimator' target='_blank'>Github Repo</a></p>"
# Reuse the module-level example_img_file from the download step instead of
# redefining it, and offer the example only if the download actually succeeded.
examples_list = [[example_img_file]] if os.path.exists(example_img_file) else []

print("Setting up Gradio Interface...")
if not models_ready:
    print("Gradio launch aborted because model files are not ready.")
    # Non-zero status so the host reports a failed start rather than a clean exit.
    raise SystemExit(1)

ui = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="filepath", label="Input"),  # inference() receives a path string
    outputs=gr.Image(type="filepath", label="Output"),  # inference() returns a path string
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    examples=examples_list,
)
print("Launching Gradio...")
ui.launch()