chongzhou commited on
Commit
3cc85d3
·
1 Parent(s): 4d4080b

remove torch.cuda.is_available

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -226,14 +226,13 @@ def preprocess_video_in(
226
  input_labels = []
227
 
228
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
229
- if torch.cuda.is_available():
230
- predictor.to("cuda")
231
- if inference_state:
232
- inference_state["device"] = "cuda"
233
- if torch.cuda.get_device_properties(0).major >= 8:
234
- torch.backends.cuda.matmul.allow_tf32 = True
235
- torch.backends.cudnn.allow_tf32 = True
236
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
237
  inference_state = predictor.init_state(video_path=video_path)
238
 
239
  return [
@@ -259,14 +258,13 @@ def segment_with_points(
259
  evt: gr.SelectData,
260
  ):
261
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
262
- if torch.cuda.is_available():
263
- predictor.to("cuda")
264
- if inference_state:
265
- inference_state["device"] = "cuda"
266
- if torch.cuda.get_device_properties(0).major >= 8:
267
- torch.backends.cuda.matmul.allow_tf32 = True
268
- torch.backends.cudnn.allow_tf32 = True
269
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
270
 
271
  input_points.append(evt.index)
272
  print(f"TRACKING INPUT POINT: {input_points}")
@@ -351,14 +349,13 @@ def propagate_to_all(
351
  inference_state,
352
  ):
353
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
354
- if torch.cuda.is_available():
355
- predictor.to("cuda")
356
- if inference_state:
357
- inference_state["device"] = "cuda"
358
- if torch.cuda.get_device_properties(0).major >= 8:
359
- torch.backends.cuda.matmul.allow_tf32 = True
360
- torch.backends.cudnn.allow_tf32 = True
361
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
362
 
363
  if len(input_points) == 0 or video_in is None or inference_state is None:
364
  return None
 
226
  input_labels = []
227
 
228
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
229
+ predictor.to("cuda")
230
+ if inference_state:
231
+ inference_state["device"] = "cuda"
232
+ if torch.cuda.get_device_properties(0).major >= 8:
233
+ torch.backends.cuda.matmul.allow_tf32 = True
234
+ torch.backends.cudnn.allow_tf32 = True
235
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
236
  inference_state = predictor.init_state(video_path=video_path)
237
 
238
  return [
 
258
  evt: gr.SelectData,
259
  ):
260
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
261
+ predictor.to("cuda")
262
+ if inference_state:
263
+ inference_state["device"] = "cuda"
264
+ if torch.cuda.get_device_properties(0).major >= 8:
265
+ torch.backends.cuda.matmul.allow_tf32 = True
266
+ torch.backends.cudnn.allow_tf32 = True
267
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
268
 
269
  input_points.append(evt.index)
270
  print(f"TRACKING INPUT POINT: {input_points}")
 
349
  inference_state,
350
  ):
351
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
352
+ predictor.to("cuda")
353
+ if inference_state:
354
+ inference_state["device"] = "cuda"
355
+ if torch.cuda.get_device_properties(0).major >= 8:
356
+ torch.backends.cuda.matmul.allow_tf32 = True
357
+ torch.backends.cudnn.allow_tf32 = True
358
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
359
 
360
  if len(input_points) == 0 or video_in is None or inference_state is None:
361
  return None