Spaces:
Alexspy19
/
Runtime error

Alexspy19 commited on
Commit
5fd5a4a
·
verified ·
1 Parent(s): ccc184f

Update LHM/models/rendering/gs_renderer.py

Browse files
Files changed (1) hide show
  1. LHM/models/rendering/gs_renderer.py +71 -7
LHM/models/rendering/gs_renderer.py CHANGED
@@ -210,19 +210,79 @@ class GaussianModel:
210
  l.append("rot_{}".format(i))
211
  return l
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  def save_ply(self, path: str):
214
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  xyz = self.xyz.detach().cpu().numpy()
216
  normals = np.zeros_like(xyz)
217
-
218
  if self.use_rgb:
219
  shs = RGB2SH(self.shs)
220
  else:
221
  shs = self.shs
222
-
223
  features_dc = shs[:, :1]
224
  features_rest = shs[:, 1:]
225
-
226
  f_dc = (
227
  features_dc.float().detach().flatten(start_dim=1).contiguous().cpu().numpy()
228
  )
@@ -240,14 +300,14 @@ class GaussianModel:
240
  .cpu()
241
  .numpy()
242
  )
243
-
244
  scale = np.log(self.scaling.detach().cpu().numpy())
245
  rotation = self.rotation.detach().cpu().numpy()
246
-
247
  dtype_full = [
248
  (attribute, "f4") for attribute in self.construct_list_of_attributes()
249
  ]
250
-
251
  elements = np.empty(xyz.shape[0], dtype=dtype_full)
252
  attributes = np.concatenate(
253
  (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
@@ -255,6 +315,10 @@ class GaussianModel:
255
  elements[:] = list(map(tuple, attributes))
256
  el = PlyElement.describe(elements, "vertex")
257
  PlyData([el]).write(path)
 
 
 
 
258
 
259
  def load_ply(self, path):
260
 
 
210
  l.append("rot_{}".format(i))
211
  return l
212
 
213
+ # def save_ply(self, path: str):
214
+
215
+ # xyz = self.xyz.detach().cpu().numpy()
216
+ # normals = np.zeros_like(xyz)
217
+
218
+ # if self.use_rgb:
219
+ # shs = RGB2SH(self.shs)
220
+ # else:
221
+ # shs = self.shs
222
+
223
+ # features_dc = shs[:, :1]
224
+ # features_rest = shs[:, 1:]
225
+
226
+ # f_dc = (
227
+ # features_dc.float().detach().flatten(start_dim=1).contiguous().cpu().numpy()
228
+ # )
229
+ # f_rest = (
230
+ # features_rest.float()
231
+ # .detach()
232
+ # .flatten(start_dim=1)
233
+ # .contiguous()
234
+ # .cpu()
235
+ # .numpy()
236
+ # )
237
+ # opacities = (
238
+ # inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3))
239
+ # .detach()
240
+ # .cpu()
241
+ # .numpy()
242
+ # )
243
+
244
+ # scale = np.log(self.scaling.detach().cpu().numpy())
245
+ # rotation = self.rotation.detach().cpu().numpy()
246
+
247
+ # dtype_full = [
248
+ # (attribute, "f4") for attribute in self.construct_list_of_attributes()
249
+ # ]
250
+
251
+ # elements = np.empty(xyz.shape[0], dtype=dtype_full)
252
+ # attributes = np.concatenate(
253
+ # (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
254
+ # )
255
+ # elements[:] = list(map(tuple, attributes))
256
+ # el = PlyElement.describe(elements, "vertex")
257
+ # PlyData([el]).write(path)
258
  def save_ply(self, path: str):
259
+ """
260
+ Save the Gaussian Splatting model as a PLY file in a location accessible for download.
261
+
262
+ Args:
263
+ path (str): Path where to save the PLY file
264
+
265
+ Returns:
266
+ str: Path to the saved PLY file
267
+ """
268
+ import os
269
+
270
+ # Ensure output directory exists
271
+ output_dir = os.path.dirname(path)
272
+ os.makedirs(output_dir, exist_ok=True)
273
+
274
+ # Save original PLY data
275
  xyz = self.xyz.detach().cpu().numpy()
276
  normals = np.zeros_like(xyz)
277
+
278
  if self.use_rgb:
279
  shs = RGB2SH(self.shs)
280
  else:
281
  shs = self.shs
282
+
283
  features_dc = shs[:, :1]
284
  features_rest = shs[:, 1:]
285
+
286
  f_dc = (
287
  features_dc.float().detach().flatten(start_dim=1).contiguous().cpu().numpy()
288
  )
 
300
  .cpu()
301
  .numpy()
302
  )
303
+
304
  scale = np.log(self.scaling.detach().cpu().numpy())
305
  rotation = self.rotation.detach().cpu().numpy()
306
+
307
  dtype_full = [
308
  (attribute, "f4") for attribute in self.construct_list_of_attributes()
309
  ]
310
+
311
  elements = np.empty(xyz.shape[0], dtype=dtype_full)
312
  attributes = np.concatenate(
313
  (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
 
315
  elements[:] = list(map(tuple, attributes))
316
  el = PlyElement.describe(elements, "vertex")
317
  PlyData([el]).write(path)
318
+
319
+ print(f"Model saved to {path}")
320
+ return path
321
+
322
 
323
  def load_ply(self, path):
324