cocktailpeanut commited on
Commit
bd5e995
Β·
1 Parent(s): 1c2ea32
Files changed (1) hide show
  1. ldm/modules/encoders/modules.py +9 -1
ldm/modules/encoders/modules.py CHANGED
@@ -13,6 +13,13 @@ from ldm.util import default, instantiate_from_config
13
  from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
14
  import clip
15
 
 
 
 
 
 
 
 
16
  class AbstractEncoder(nn.Module):
17
  def __init__(self):
18
  super().__init__()
@@ -30,7 +37,8 @@ def disabled_train(self, mode=True):
30
 
31
  class FrozenCLIPEmbedder(AbstractEncoder):
32
  """Uses the CLIP transformer encoder for text (from huggingface)"""
33
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
 
34
  super().__init__()
35
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
36
  self.transformer = CLIPTextModel.from_pretrained(version)
 
13
  from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
14
  import clip
15
 
16
+ if torch.cuda.is_available():
17
+ _device = "cuda"
18
+ elif torch.backends.mps.is_available():
19
+ _device = "mps"
20
+ else:
21
+ _device = "cpu"
22
+
23
  class AbstractEncoder(nn.Module):
24
  def __init__(self):
25
  super().__init__()
 
37
 
38
  class FrozenCLIPEmbedder(AbstractEncoder):
39
  """Uses the CLIP transformer encoder for text (from huggingface)"""
40
+ #def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
41
+ def __init__(self, version="openai/clip-vit-large-patch14", device=_device, max_length=77): # clip-vit-base-patch32
42
  super().__init__()
43
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
44
  self.transformer = CLIPTextModel.from_pretrained(version)