Spaces:
Runtime error
Runtime error
File size: 2,301 Bytes
e202b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from xformers.components.attention import AttentionMask
def test_mask_creation():
    """Check AttentionMask construction, conversion and addition on CPU tensors.

    Covers ``from_bool`` (2D and batched 3D inputs), ``from_multiplicative``,
    ``make_causal`` (square, implicit-square and rectangular shapes) and
    adding two masks together. None of these checks move data to a GPU, so
    they run on every machine; device transfer is exercised separately by
    ``test_mask_device_transfer``, which is the only part that needs CUDA.
    """
    # A private, seeded generator makes the random masks reproducible without
    # touching torch's global RNG state (which other tests may rely on).
    gen = torch.Generator().manual_seed(0)

    # Check that we can create from boolean
    bool_mask = torch.rand((256, 256), generator=gen) > 0.5
    additive_mask = AttentionMask.from_bool(bool_mask)
    assert (bool_mask == additive_mask.to_bool()).all()

    bool_mask = torch.rand((2, 256, 256), generator=gen) > 0.5
    additive_mask = AttentionMask.from_bool(bool_mask)
    assert (bool_mask == additive_mask.to_bool()).all()
    assert additive_mask.ndim == bool_mask.ndim

    # Check that we can create from multiplicative
    ref_mask = torch.randint(0, 2, (256, 256), generator=gen)
    mul_mask = ref_mask.float()
    additive_mask = AttentionMask.from_multiplicative(mul_mask)
    assert (ref_mask.bool() == additive_mask.to_bool()).all()

    # Check the causal mask: explicit square, implicit square, rectangular
    causal_mask = AttentionMask.make_causal(256, 256)
    assert (torch.tril(torch.ones(256, 256)).bool() == causal_mask.to_bool()).all()
    assert causal_mask.is_causal

    causal_mask = AttentionMask.make_causal(256)
    assert (torch.tril(torch.ones(256, 256)).bool() == causal_mask.to_bool()).all()

    causal_mask = AttentionMask.make_causal(256, 128)
    assert (torch.tril(torch.ones(256, 128)).bool() == causal_mask.to_bool()).all()

    # Check that we can add masks: the sum keeps only positions allowed by both
    bool_mask_1 = torch.rand((256, 256), generator=gen) > 0.5
    add_mask_1 = AttentionMask.from_bool(bool_mask_1)
    bool_mask_2 = torch.rand((256, 256), generator=gen) > 0.5
    add_mask_2 = AttentionMask.from_bool(bool_mask_2)
    assert ((add_mask_1 + add_mask_2).to_bool() == (bool_mask_1 & bool_mask_2)).all()


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_mask_device_transfer():
    """Check that ``AttentionMask.to`` moves masks to CUDA and keeps metadata.

    Skipped when no CUDA device is available.
    """
    gen = torch.Generator().manual_seed(0)

    # Check type handling: the underlying tensor must actually live on CUDA
    additive_mask = AttentionMask.from_bool(torch.rand((256, 256), generator=gen) > 0.5)
    additive_mask = additive_mask.to(device=torch.device("cuda"))
    assert additive_mask.values.device.type == "cuda"

    # Check that the causal flag is maintained across the device move
    additive_mask = AttentionMask.make_causal(256, 256)
    additive_mask = additive_mask.to(device=torch.device("cuda"))
    assert additive_mask.is_causal
|