chongzhou committed
Commit 7209747 · 1 Parent(s): 7563789

doublecheck inference_state["device"]

Files changed (1)
  1. app.py +4 -2
app.py CHANGED
@@ -257,7 +257,8 @@ def segment_with_points(
     evt: gr.SelectData,
 ):
     predictor.to("cpu")
-    inference_state["device"] = predictor.device
+    if inference_state:
+        inference_state["device"] = predictor.device
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
@@ -344,7 +345,8 @@ def propagate_to_all(
     torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         predictor.to("cuda")
-        inference_state["device"] = predictor.device
+        if inference_state:
+            inference_state["device"] = predictor.device
 
         if len(input_points) == 0 or video_in is None or inference_state is None:
             return None
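
For context: `inference_state` can still be `None` when the predictor is moved between devices (the same diff checks `inference_state is None` before propagating), so the commit guards the `["device"]` assignment instead of assuming the dict exists. Below is a minimal sketch of that guarded pattern only; `_Predictor` and the surrounding setup are illustrative stand-ins, not code from app.py.

```python
# Minimal, self-contained sketch of the guard added in this commit.
# Assumptions (not from app.py): _Predictor stands in for the real video
# predictor, and inference_state starts as None until a video is loaded.
import torch


class _Predictor:
    def __init__(self):
        self.device = torch.device("cpu")

    def to(self, device):
        # Mirror torch.nn.Module.to(): move the model and remember the device.
        self.device = torch.device(device)
        return self


predictor = _Predictor()
inference_state = None  # no video loaded yet

predictor.to("cpu")
# Before this commit: inference_state["device"] = predictor.device
# would raise TypeError ('NoneType' object does not support item assignment).
if inference_state:
    inference_state["device"] = predictor.device
```

Note that the truthiness check also skips the assignment when the dict exists but is empty, not only when it is `None`.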