x10z commited on
Commit
0b350a4
·
verified ·
1 Parent(s): 378bb55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -35,12 +35,15 @@ pipe.unet.eval()
35
 
36
  # UI texts
37
  title = "# End-to-End Fine-Tuned GeoWizard Video"
38
- description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
 
 
39
 
40
  @spaces.GPU
41
  def predict(image: Image.Image, processing_res_choice: int):
42
  """
43
  Single-frame prediction wrapped for GPU execution.
 
44
  """
45
  with torch.no_grad():
46
  return pipe(
@@ -67,7 +70,7 @@ def on_submit_video(video_path: str, processing_res_choice: int):
67
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
68
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
69
 
70
- # Create temporary output video files
71
  tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
72
  tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
73
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -80,13 +83,14 @@ def on_submit_video(video_path: str, processing_res_choice: int):
80
  if not ret:
81
  break
82
 
83
- # Convert BGR to RGB PIL image
84
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
85
  pil_image = Image.fromarray(rgb)
86
 
87
  # Run prediction
88
- time_error
89
- depth_np, depth_colored, normal_np, normal_colored = predict(pil_image, processing_res_choice)
 
90
 
91
  # Write depth frame
92
  depth_frame = np.array(depth_colored)
@@ -103,6 +107,7 @@ time_error
103
  out_depth.release()
104
  out_normal.release()
105
 
 
106
  return tmp_depth.name, tmp_normal.name
107
 
108
  # Build Gradio interface
 
35
 
36
  # UI texts
37
  title = "# End-to-End Fine-Tuned GeoWizard Video"
38
+ description = """
39
+ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details.
40
+ """
41
 
42
  @spaces.GPU
43
  def predict(image: Image.Image, processing_res_choice: int):
44
  """
45
  Single-frame prediction wrapped for GPU execution.
46
+ Returns a DepthNormalPipelineOutput with attributes depth_colored and normal_colored.
47
  """
48
  with torch.no_grad():
49
  return pipe(
 
70
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
71
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
72
 
73
+ # Temporary output files
74
  tmp_depth = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
75
  tmp_normal = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
76
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
83
  if not ret:
84
  break
85
 
86
+ # Convert BGR to RGB and to PIL
87
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
88
  pil_image = Image.fromarray(rgb)
89
 
90
  # Run prediction
91
+ result = predict(pil_image, processing_res_choice)
92
+ depth_colored = result.depth_colored
93
+ normal_colored = result.normal_colored
94
 
95
  # Write depth frame
96
  depth_frame = np.array(depth_colored)
 
107
  out_depth.release()
108
  out_normal.release()
109
 
110
+ # Return paths for download
111
  return tmp_depth.name, tmp_normal.name
112
 
113
  # Build Gradio interface