chongzhou committed
Commit 282a45a · 1 Parent(s): db1d5c6

make predictor global and remove bf16

Files changed (1): app.py (+4 -4)
app.py CHANGED
@@ -72,6 +72,7 @@ examples = [
 OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
+predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
 
 
 def get_video_fps(video_path):
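Note on this hunk: the predictor is now built once at import time, on CPU, and shared by every callback below. A minimal sketch of the resulting module-level pattern (the import path is an assumption from the upstream SAM2 layout; it is not shown in this diff):

from sam2.build_sam import build_sam2_video_predictor  # import path assumed

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# Built on CPU so the app can start without a GPU; propagate_to_all moves it
# to CUDA later (see the last hunk).
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")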
@@ -226,7 +227,6 @@ def preprocess_video_in(
     input_points = []
     input_labels = []
 
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     inference_state = predictor.init_state(
         offload_video_to_cpu=True,
         offload_state_to_cpu=True,
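With the predictor global, preprocess_video_in only rebuilds per-video state instead of reloading the checkpoint on every upload. A hedged sketch of the callback after this change (the video_path parameter and the return shape are assumptions from context, not shown in the hunk):

def preprocess_video_in(video_path):
    input_points = []
    input_labels = []
    # Re-use the shared predictor; only the per-video inference_state is new.
    # The offload flags keep decoded frames and tracking state in host RAM,
    # bounding GPU memory use on long videos.
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        video_path=video_path,
    )
    return inference_state, input_points, input_labels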
@@ -255,7 +255,6 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
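segment_with_points likewise stops paying a model rebuild on every click; it only records evt.index, the clicked (x, y) pixel. For context, a hedged sketch of where those clicks typically go next in the SAM2 API (add_new_points_or_box is the method name in recent upstream SAM2; whether EdgeTAM's vendored copy matches is an assumption, as are frame_idx=0 and the label convention):

import numpy as np

points = np.array(input_points, dtype=np.float32)  # (x, y) pairs from evt.index
labels = np.array(input_labels, dtype=np.int32)    # 1 = foreground, 0 = background

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,       # prompt on the first frame (assumption)
    obj_id=OBJ_ID,
    points=points,
    labels=labels,
)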
@@ -337,12 +336,13 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
+    predictor.to("cuda")
+
     if len(input_points) == 0 or video_in is None or inference_state is None:
         return None
     # run propagation throughout the video and collect the results in a dict
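This last hunk fixes two problems at once. The old code called torch.autocast(...).__enter__() with no matching __exit__, leaking the autocast context, and built a second CUDA predictor on every propagation call; the commit drops bf16 entirely, keeps the TF32 fast paths on Ampere-class GPUs (compute capability >= 8), and moves the one shared predictor to CUDA with .to("cuda"). A sketch of the resulting device handling, with the properly scoped with-block shown in comments should bf16 ever be re-enabled (propagate_in_video is upstream SAM2's API; its presence in EdgeTAM is an assumption):

import torch

# TF32 matmul/conv kernels exist on Ampere or newer (compute capability >= 8).
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

predictor.to("cuda")  # move the shared model once; no per-call rebuild

# If bf16 were re-enabled, a `with` block enters AND exits the context,
# unlike the removed bare .__enter__() call:
# with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#     for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
#         ...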