Alesteba commited on
Commit
e352353
·
1 Parent(s): 2e25e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -70,7 +70,9 @@ def get_rays(height, width, focal, pose):
70
 
71
 
72
  def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
 
73
  """Renders the rays and flattens it.
 
74
  Args:
75
  ray_origins: The origin points for rays.
76
  ray_directions: The direction unit vectors for the rays.
@@ -78,13 +80,18 @@ def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=F
78
  far: The far bound of the volumetric scene.
79
  num_samples: Number of sample points in a ray.
80
  rand: Choice for randomising the sampling strategy.
 
81
  Returns:
82
  Tuple of flattened rays and sample points on each rays.
83
  """
 
84
  # Compute 3D query points.
85
  # Equation: r(t) = o+td -> Building the "t" here.
 
86
  t_vals = tf.linspace(near, far, num_samples)
 
87
  if rand:
 
88
  # Inject uniform noise into sample space to make the sampling
89
  # continuous.
90
  shape = list(ray_origins.shape[:-1]) + [num_samples]
@@ -92,6 +99,7 @@ def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=F
92
  t_vals = t_vals + noise
93
 
94
  # Equation: r(t) = o + td -> Building the "r" here.
 
95
  rays = ray_origins[..., None, :] + (
96
  ray_directions[..., None, :] * t_vals[..., None]
97
  )
@@ -101,13 +109,17 @@ def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=F
101
 
102
 
103
  def map_fn(pose):
 
104
  """Maps individual pose to flattened rays and sample points.
 
105
  Args:
106
  pose: The pose matrix of the camera.
 
107
  Returns:
108
  Tuple of flattened rays and sample points corresponding to the
109
  camera pose.
110
  """
 
111
  (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
112
  (rays_flat, t_vals) = render_flat_rays(
113
  ray_origins=ray_origins,
@@ -117,11 +129,14 @@ def map_fn(pose):
117
  num_samples=NUM_SAMPLES,
118
  rand=True,
119
  )
 
120
  return (rays_flat, t_vals)
121
 
122
 
123
  def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
 
124
  """Generates the RGB image and depth map from model prediction.
 
125
  Args:
126
  model: The MLP model that is trained to predict the rgb and
127
  volume density of the volumetric scene.
@@ -130,9 +145,11 @@ def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
130
  t_vals: The sample points for the rays.
131
  rand: Choice to randomise the sampling strategy.
132
  train: Whether the model is in the training or testing phase.
 
133
  Returns:
134
  Tuple of rgb image and depth map.
135
  """
 
136
  # Get the predictions from the nerf model and reshape it.
137
  if train:
138
  predictions = model(rays_flat)
@@ -147,6 +164,7 @@ def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
147
  # Get the distance of adjacent intervals.
148
  delta = t_vals[..., 1:] - t_vals[..., :-1]
149
  # delta shape = (num_samples)
 
150
  if rand:
151
  delta = tf.concat(
152
  [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
@@ -171,7 +189,6 @@ def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
171
  depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
172
  return (rgb, depth_map)
173
 
174
-
175
  def get_translation_t(t):
176
  """Get the translation matrix for movement in t."""
177
  matrix = [
 
70
 
71
 
72
  def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
73
+
74
  """Renders the rays and flattens it.
75
+
76
  Args:
77
  ray_origins: The origin points for rays.
78
  ray_directions: The direction unit vectors for the rays.
 
80
  far: The far bound of the volumetric scene.
81
  num_samples: Number of sample points in a ray.
82
  rand: Choice for randomising the sampling strategy.
83
+
84
  Returns:
85
  Tuple of flattened rays and sample points on each rays.
86
  """
87
+
88
  # Compute 3D query points.
89
  # Equation: r(t) = o+td -> Building the "t" here.
90
+
91
  t_vals = tf.linspace(near, far, num_samples)
92
+
93
  if rand:
94
+
95
  # Inject uniform noise into sample space to make the sampling
96
  # continuous.
97
  shape = list(ray_origins.shape[:-1]) + [num_samples]
 
99
  t_vals = t_vals + noise
100
 
101
  # Equation: r(t) = o + td -> Building the "r" here.
102
+
103
  rays = ray_origins[..., None, :] + (
104
  ray_directions[..., None, :] * t_vals[..., None]
105
  )
 
109
 
110
 
111
  def map_fn(pose):
112
+
113
  """Maps individual pose to flattened rays and sample points.
114
+
115
  Args:
116
  pose: The pose matrix of the camera.
117
+
118
  Returns:
119
  Tuple of flattened rays and sample points corresponding to the
120
  camera pose.
121
  """
122
+
123
  (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
124
  (rays_flat, t_vals) = render_flat_rays(
125
  ray_origins=ray_origins,
 
129
  num_samples=NUM_SAMPLES,
130
  rand=True,
131
  )
132
+
133
  return (rays_flat, t_vals)
134
 
135
 
136
  def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
137
+
138
  """Generates the RGB image and depth map from model prediction.
139
+
140
  Args:
141
  model: The MLP model that is trained to predict the rgb and
142
  volume density of the volumetric scene.
 
145
  t_vals: The sample points for the rays.
146
  rand: Choice to randomise the sampling strategy.
147
  train: Whether the model is in the training or testing phase.
148
+
149
  Returns:
150
  Tuple of rgb image and depth map.
151
  """
152
+
153
  # Get the predictions from the nerf model and reshape it.
154
  if train:
155
  predictions = model(rays_flat)
 
164
  # Get the distance of adjacent intervals.
165
  delta = t_vals[..., 1:] - t_vals[..., :-1]
166
  # delta shape = (num_samples)
167
+
168
  if rand:
169
  delta = tf.concat(
170
  [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
 
189
  depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
190
  return (rgb, depth_map)
191
 
 
192
  def get_translation_t(t):
193
  """Get the translation matrix for movement in t."""
194
  matrix = [