doublecheck inference_state["device"]
Browse files
app.py
CHANGED
@@ -257,7 +257,8 @@ def segment_with_points(
|
|
257 |
evt: gr.SelectData,
|
258 |
):
|
259 |
predictor.to("cpu")
|
260 |
-
inference_state
|
|
|
261 |
input_points.append(evt.index)
|
262 |
print(f"TRACKING INPUT POINT: {input_points}")
|
263 |
|
@@ -344,7 +345,8 @@ def propagate_to_all(
|
|
344 |
torch.backends.cudnn.allow_tf32 = True
|
345 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
346 |
predictor.to("cuda")
|
347 |
-
inference_state
|
|
|
348 |
|
349 |
if len(input_points) == 0 or video_in is None or inference_state is None:
|
350 |
return None
|
|
|
257 |
evt: gr.SelectData,
|
258 |
):
|
259 |
predictor.to("cpu")
|
260 |
+
if inference_state:
|
261 |
+
inference_state["device"] = predictor.device
|
262 |
input_points.append(evt.index)
|
263 |
print(f"TRACKING INPUT POINT: {input_points}")
|
264 |
|
|
|
345 |
torch.backends.cudnn.allow_tf32 = True
|
346 |
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
347 |
predictor.to("cuda")
|
348 |
+
if inference_state:
|
349 |
+
inference_state["device"] = predictor.device
|
350 |
|
351 |
if len(input_points) == 0 or video_in is None or inference_state is None:
|
352 |
return None
|