sushmit00 commited on
Commit
34f6d55
·
verified ·
1 Parent(s): 8560a15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -101
app.py CHANGED
@@ -1,15 +1,17 @@
1
  import os
2
- import subprocess # Using subprocess for better command execution
3
- import time # For potential retries or delays if needed
 
 
 
4
 
5
- # --- Dependency Installation ---
6
- # Note: Ideally, manage dependencies via requirements.txt for HF Spaces build process.
7
- # Running pip here might be slow on startup or cause unexpected version conflicts.
8
- print("Ensuring dependencies are installed...")
9
- # os.system("pip install gradio==2.4.6") # Prefer version from requirements.txt
10
- # os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") # Prefer from requirements.txt
11
 
12
- import gradio as gr # Import after potential install
13
 
14
  # --- Repository Setup ---
15
  repo_dir = "bizarre-pose-estimator"
@@ -21,18 +23,21 @@ if not os.path.exists(repo_dir):
21
  print("Repository cloned successfully.")
22
  except subprocess.CalledProcessError as e:
23
  print(f"ERROR: Failed to clone repository: {e.stderr}")
24
- # Consider exiting if clone fails
25
  exit()
26
  else:
27
  print(f"Repository '{repo_dir}' already exists.")
28
 
29
  # --- Change Directory ---
30
  try:
31
- os.chdir(repo_dir)
32
- print(f"Changed directory to: {os.getcwd()}")
 
 
 
 
33
  except FileNotFoundError:
34
  print(f"ERROR: Failed to change directory to '{repo_dir}'.")
35
- exit() # Cannot continue without the repo code
36
 
37
  # --- Download Example Image ---
38
  example_img_url = "https://i.imgur.com/IkJzlaE.jpeg"
@@ -47,21 +52,17 @@ if not os.path.exists(example_img_file):
47
 
48
  # --- Model Download ---
49
  model_zip_file = "bizarre_pose_models.zip"
50
- model_extract_dir = "bizarre_pose_models" # Expected folder name inside zip
51
- model_id = "17N5PutpYJTlKuNB6bdDaiQsPSIkYtiPm" # The correct FILE ID
52
  download_successful = False
53
-
54
- # Check if model files already exist (e.g., from previous run/unzip)
55
- # Check for a key file/folder that should exist after copy
56
- # Using the checkpoint file as an indicator
57
  model_checkpoint_rel_path = "_train/character_pose_estim/runs/feat_concat+data.ckpt"
 
58
  if os.path.exists(model_checkpoint_rel_path):
59
  print("Model checkpoint already exists. Skipping download and unzip.")
60
- download_successful = True # Treat as success for logic flow
61
- models_ready = True
62
  else:
63
- print(f"Model checkpoint not found at '{model_checkpoint_rel_path}'. Attempting download...")
64
- models_ready = False
65
  if os.path.exists(model_zip_file):
66
  print(f"Zip file '{model_zip_file}' already exists. Skipping download.")
67
  download_successful = True
@@ -70,26 +71,21 @@ else:
70
  gdown_command = f"gdown --id {model_id} -O {model_zip_file}"
71
  print(f"Executing: {gdown_command}")
72
  try:
73
- result = subprocess.run(gdown_command, shell=True, check=True, capture_output=True, text=True)
 
74
  print("Gdown download successful.")
75
- # print("stdout:", result.stdout) # Often too verbose
76
- # print("stderr:", result.stderr)
77
  download_successful = True
78
- except FileNotFoundError:
79
- print("ERROR: 'gdown' command not found. Install via requirements.txt.")
80
  except subprocess.CalledProcessError as e:
81
- print(f"ERROR: gdown download failed with exit code {e.returncode}")
82
- print("Stderr:", e.stderr)
83
- print("Stdout:", e.stdout) # May contain useful info from gdown
84
- print("Download might fail due to quotas or permissions on Google Drive.")
85
- print("Consider manually uploading 'bizarre_pose_models.zip' to the Space root if errors persist.")
86
 
87
  # --- Unzip and Copy ---
88
  if download_successful and os.path.exists(model_zip_file):
89
  print(f"Unzipping '{model_zip_file}'...")
90
  try:
91
  if os.path.exists(model_extract_dir): print(f"Note: '{model_extract_dir}' already exists. Overwriting.")
92
- subprocess.run(["unzip", "-oq", model_zip_file], check=True) # -o overwrite, -q quiet
93
  print("Unzip successful.")
94
 
95
  print("Copying model files...")
@@ -100,23 +96,16 @@ else:
100
  print("Model files copied.")
101
  models_ready = True # Models should now be ready
102
 
103
- # --- Delete Zip File ---
104
- print(f"Attempting to remove zip file to save space: {model_zip_file}")
105
  try:
106
  os.remove(model_zip_file)
107
  print(f"Removed zip file: {model_zip_file}")
108
- except OSError as e:
109
- print(f"Warning: Error removing zip file {model_zip_file}: {e}")
110
- # --- End Zip Deletion ---
111
-
112
- else:
113
- print(f"ERROR: Directory '{source_dir_unzip}' not found after unzip.")
114
-
115
  except FileNotFoundError: print(f"ERROR: 'unzip' command not available.")
116
  except subprocess.CalledProcessError as e: print(f"ERROR: Unzip failed: {e}")
117
- else:
118
- if not download_successful: print("Download failed previously, cannot unzip.")
119
- elif not os.path.exists(model_zip_file): print(f"Zip file '{model_zip_file}' not found, cannot unzip.")
120
 
121
  # --- Final Check and LS ---
122
  print("\nCurrent directory contents:")
@@ -126,78 +115,60 @@ if not models_ready:
126
 
127
  # --- Gradio Interface ---
128
  def inference(img_input):
129
- # Gradio passes a PIL Image or Numpy array depending on config, need filepath
130
- # Save the uploaded image temporarily if it's not already a path
131
- if isinstance(img_input, str):
132
- img_path = img_input # Already a path (e.g., from examples)
133
  else:
134
- # Save PIL/Numpy input to a temporary file
135
  from PIL import Image
136
  import numpy as np
137
- if isinstance(img_input, np.ndarray):
138
- img_pil = Image.fromarray(img_input)
139
- else: # Assume PIL Image
140
- img_pil = img_input
141
- temp_img_name = "temp_gradio_input.png"
142
- img_pil.save(temp_img_name)
143
- img_path = temp_img_name
144
  print(f"Saved Gradio input to temporary file: {img_path}")
145
 
146
  print(f"Running inference on: {img_path}")
147
  output_sample_path = "./_samples/character_pose_estim.png"
148
  model_checkpoint = "./_train/character_pose_estim/runs/feat_concat+data.ckpt"
149
 
150
- # Check paths relative to current directory ('bizarre-pose-estimator')
151
- if not os.path.exists(img_path):
152
- print(f"ERROR: Input image path does not exist: {img_path}"); return None
153
- if not os.path.exists(model_checkpoint):
154
- print(f"ERROR: Model checkpoint not found at: {model_checkpoint}"); return None
155
 
156
- command = f"python3 -m _scripts.pose_estimator '{img_path}' '{model_checkpoint}'" # Add quotes for paths
157
  print(f"Executing command: {command}")
158
  try:
159
- result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True, timeout=120) # Add timeout
160
- print("Inference stdout:", result.stdout)
161
- print("Inference stderr:", result.stderr)
162
- print("Inference completed.")
163
- if os.path.exists(output_sample_path):
164
- return output_sample_path # Return path to the output image
165
- else:
166
- print(f"ERROR: Inference ran but output file not found: {output_sample_path}")
167
- return None
168
- except subprocess.TimeoutExpired:
169
- print("ERROR: Inference script timed out.")
170
- return None
171
- except subprocess.CalledProcessError as e:
172
- print(f"ERROR: Inference script failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")
173
- return None
174
  finally:
175
- # Clean up temporary input file if created
176
- if not isinstance(img_input, str) and os.path.exists(temp_img_name):
177
- try:
178
- os.remove(temp_img_name)
179
- print(f"Removed temporary file: {temp_img_name}")
180
- except OSError as e:
181
- print(f"Warning: Failed to remove temp file {temp_img_name}: {e}")
182
 
183
 
184
  title = "bizarre-pose-estimator"
185
  description = "Gradio demo for Transfer Learning for Pose Estimation of Illustrated Characters. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
186
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.01819' target='_blank'>Transfer Learning for Pose Estimation of Illustrated Characters</a> | <a href='https://github.com/ShuhongChen/bizarre-pose-estimator' target='_blank'>Github Repo</a></p>"
187
-
188
- # Ensure example file exists before adding it
189
  examples_list = [[example_img_file]] if os.path.exists(example_img_file) else []
190
 
191
- gr.Interface(
192
- fn=inference, # Use fn= alias
193
- # Input needs to accept various types, filepath ensures it works with examples
194
- inputs=gr.Image(type="pil", label="Input"), # Use PIL input type, save internally
195
- outputs=gr.Image(type="filepath", label="Output"), # Output is path to generated file
196
- title=title,
197
- description=description,
198
- article=article,
199
- allow_flagging="never", # Original setting
200
- examples=examples_list,
201
- enable_queue=True # Original setting
202
- # Consider adding cache_examples=True if examples are static
203
- ).launch() # share=True is needed only if not running on HF platform
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import subprocess
3
+ import time
4
+ # REMOVE pip installs from here - manage via requirements.txt
5
+ # os.system("pip install gradio==2.4.6") # REMOVE
6
+ # os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") # REMOVE
7
 
8
+ # Import gradio AFTER potentially installing it (though it should already be there)
9
+ try:
10
+ import gradio as gr
11
+ except ImportError:
12
+ print("ERROR: Gradio not installed. Ensure it's in requirements.txt")
13
+ exit()
14
 
 
15
 
16
  # --- Repository Setup ---
17
  repo_dir = "bizarre-pose-estimator"
 
23
  print("Repository cloned successfully.")
24
  except subprocess.CalledProcessError as e:
25
  print(f"ERROR: Failed to clone repository: {e.stderr}")
 
26
  exit()
27
  else:
28
  print(f"Repository '{repo_dir}' already exists.")
29
 
30
  # --- Change Directory ---
31
  try:
32
+ # Check if already in the directory (important for restarts)
33
+ if os.path.basename(os.getcwd()) != repo_dir:
34
+ os.chdir(repo_dir)
35
+ print(f"Changed directory to: {os.getcwd()}")
36
+ else:
37
+ print(f"Already in directory: {os.getcwd()}")
38
  except FileNotFoundError:
39
  print(f"ERROR: Failed to change directory to '{repo_dir}'.")
40
+ exit()
41
 
42
  # --- Download Example Image ---
43
  example_img_url = "https://i.imgur.com/IkJzlaE.jpeg"
 
52
 
53
  # --- Model Download ---
54
  model_zip_file = "bizarre_pose_models.zip"
55
+ model_extract_dir = "bizarre_pose_models"
56
+ model_id = "17N5PutpYJTlKuNB6bdDaiQsPSIkYtiPm"
57
  download_successful = False
58
+ models_ready = False
 
 
 
59
  model_checkpoint_rel_path = "_train/character_pose_estim/runs/feat_concat+data.ckpt"
60
+
61
  if os.path.exists(model_checkpoint_rel_path):
62
  print("Model checkpoint already exists. Skipping download and unzip.")
63
+ models_ready = True # Models are already copied and ready
 
64
  else:
65
+ print(f"Model checkpoint not found at '{model_checkpoint_rel_path}'. Checking for zip...")
 
66
  if os.path.exists(model_zip_file):
67
  print(f"Zip file '{model_zip_file}' already exists. Skipping download.")
68
  download_successful = True
 
71
  gdown_command = f"gdown --id {model_id} -O {model_zip_file}"
72
  print(f"Executing: {gdown_command}")
73
  try:
74
+ # Set timeout for gdown (e.g., 5 minutes)
75
+ result = subprocess.run(gdown_command, shell=True, check=True, capture_output=True, text=True, timeout=300)
76
  print("Gdown download successful.")
 
 
77
  download_successful = True
78
+ except FileNotFoundError: print("ERROR: 'gdown' command not found.")
79
+ except subprocess.TimeoutExpired: print("ERROR: gdown download timed out.")
80
  except subprocess.CalledProcessError as e:
81
+ print(f"ERROR: gdown download failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}")
 
 
 
 
82
 
83
  # --- Unzip and Copy ---
84
  if download_successful and os.path.exists(model_zip_file):
85
  print(f"Unzipping '{model_zip_file}'...")
86
  try:
87
  if os.path.exists(model_extract_dir): print(f"Note: '{model_extract_dir}' already exists. Overwriting.")
88
+ subprocess.run(["unzip", "-oq", model_zip_file], check=True)
89
  print("Unzip successful.")
90
 
91
  print("Copying model files...")
 
96
  print("Model files copied.")
97
  models_ready = True # Models should now be ready
98
 
99
+ print(f"Attempting to remove zip file: {model_zip_file}")
 
100
  try:
101
  os.remove(model_zip_file)
102
  print(f"Removed zip file: {model_zip_file}")
103
+ except OSError as e: print(f"Warning: Error removing zip file {model_zip_file}: {e}")
104
+ else: print(f"ERROR: Directory '{source_dir_unzip}' not found after unzip.")
 
 
 
 
 
105
  except FileNotFoundError: print(f"ERROR: 'unzip' command not available.")
106
  except subprocess.CalledProcessError as e: print(f"ERROR: Unzip failed: {e}")
107
+ elif not download_successful: print("Download failed previously, cannot unzip.")
108
+ elif not os.path.exists(model_zip_file): print(f"Zip file '{model_zip_file}' not found, cannot unzip.")
 
109
 
110
  # --- Final Check and LS ---
111
  print("\nCurrent directory contents:")
 
115
 
116
  # --- Gradio Interface ---
117
  def inference(img_input):
118
+ if isinstance(img_input, str): img_path = img_input
 
 
 
119
  else:
 
120
  from PIL import Image
121
  import numpy as np
122
+ if isinstance(img_input, np.ndarray): img_pil = Image.fromarray(img_input)
123
+ else: img_pil = img_input # Assume PIL
124
+ temp_img_name = "temp_gradio_input.png"; img_pil.save(temp_img_name); img_path = temp_img_name
 
 
 
 
125
  print(f"Saved Gradio input to temporary file: {img_path}")
126
 
127
  print(f"Running inference on: {img_path}")
128
  output_sample_path = "./_samples/character_pose_estim.png"
129
  model_checkpoint = "./_train/character_pose_estim/runs/feat_concat+data.ckpt"
130
 
131
+ if not os.path.exists(img_path): print(f"ERROR: Input path does not exist: {img_path}"); return None
132
+ if not os.path.exists(model_checkpoint): print(f"ERROR: Model checkpoint not found: {model_checkpoint}"); return None
 
 
 
133
 
134
+ command = f"python3 -m _scripts.pose_estimator '{img_path}' '{model_checkpoint}'"
135
  print(f"Executing command: {command}")
136
  try:
137
+ result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True, timeout=120)
138
+ print("Inference stdout:", result.stdout); print("Inference stderr:", result.stderr); print("Inference completed.")
139
+ if os.path.exists(output_sample_path): return output_sample_path
140
+ else: print(f"ERROR: Output file not found: {output_sample_path}"); return None
141
+ except subprocess.TimeoutExpired: print("ERROR: Inference script timed out."); return None
142
+ except subprocess.CalledProcessError as e: print(f"ERROR: Inference script failed: {e.returncode}\nStderr:{e.stderr}\nStdout:{e.stdout}"); return None
 
 
 
 
 
 
 
 
 
143
  finally:
144
+ if not isinstance(img_input, str) and 'temp_img_name' in locals() and os.path.exists(temp_img_name):
145
+ try: os.remove(temp_img_name); print(f"Removed temporary file: {temp_img_name}")
146
+ except OSError as e: print(f"Warning: Failed to remove temp file {temp_img_name}: {e}")
 
 
 
 
147
 
148
 
149
  title = "bizarre-pose-estimator"
150
  description = "Gradio demo for Transfer Learning for Pose Estimation of Illustrated Characters. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
151
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.01819' target='_blank'>Transfer Learning for Pose Estimation of Illustrated Characters</a> | <a href='https://github.com/ShuhongChen/bizarre-pose-estimator' target='_blank'>Github Repo</a></p>"
 
 
152
  examples_list = [[example_img_file]] if os.path.exists(example_img_file) else []
153
 
154
+ print("Setting up Gradio Interface...")
155
+ # Check if models are ready before launching
156
+ if models_ready:
157
+ # REMOVE enable_queue=True
158
+ # Also update inputs/outputs to newer syntax if possible, but focus on fixing the error first
159
+ # Using older gr.inputs/gr.outputs syntax for now as it might be what gradio 2.4.6 expects if it somehow overrides
160
+ ui = gr.Interface(
161
+ fn=inference,
162
+ inputs=gr.inputs.Image(type="filepath", label="Input"), # Using filepath to simplify function logic
163
+ outputs=gr.outputs.Image(type="file", label="Output"), # Function returns filepath
164
+ title=title,
165
+ description=description,
166
+ article=article,
167
+ allow_flagging="never",
168
+ examples=examples_list
169
+ # enable_queue=True # REMOVED
170
+ )
171
+ print("Launching Gradio...")
172
+ ui.launch() # queue() method is used in newer Gradio if needed, launch() handles basics
173
+ else:
174
+ print("Gradio launch aborted because model files are not ready.")