Commit 566832c
Parent(s): fdeb0ad
Enhance token handling in FrozenOpenCLIPEmbedder by adding device checks and conversions. This update ensures that input tokens are correctly processed as torch tensors on the appropriate device, improving compatibility and performance across different hardware configurations.
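The failure this guards against is a device mismatch: OpenCLIP's tokenizer returns a CPU LongTensor, and some callers pass plain Python lists of token ids, while the token embedding table may live on a GPU. A minimal sketch of that situation, assuming open_clip is installed and a CUDA device may be available (the variable names are illustrative, not taken from the commit):

import torch
import open_clip

# open_clip.tokenize returns a CPU LongTensor of shape [batch, 77]
tokens = open_clip.tokenize(["a photo of a cat"])

# Feeding CPU tokens into a CUDA-resident embedding raises an
# "Expected all tensors to be on the same device" error; the commit
# performs this move inside encode_with_transformer instead of
# relying on every caller to remember it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokens = tokens.to(device)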
imagedream/ldm/modules/encoders/modules.py
CHANGED
@@ -275,8 +275,16 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
         encoder_states = encoder_states + (x, )
         return encoder_states
 
-    def encode_with_transformer(self, text):
-        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+    def encode_with_transformer(self, tokens):
+        # Debug: print device info
+        print(f"[DEBUG] tokens type: {type(tokens)}, tokens device: {getattr(tokens, 'device', 'N/A')}")
+        print(f"[DEBUG] embedding weight device: {self.model.token_embedding.weight.device}")
+        # Ensure tokens is a torch tensor and on the correct device
+        if not isinstance(tokens, torch.Tensor):
+            tokens = torch.tensor(tokens, device=self.model.token_embedding.weight.device)
+        else:
+            tokens = tokens.to(self.model.token_embedding.weight.device)
+        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
         x = x + self.model.positional_embedding
         x = x.permute(1, 0, 2)  # NLD -> LND
         x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
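With this change, encode_with_transformer accepts token ids that are not yet torch tensors or that live on a different device than the model. A hedged usage sketch, assuming the FrozenOpenCLIPEmbedder constructor accepts a device argument as in upstream Stable Diffusion (the constructor call, token ids, and shapes below are illustrative assumptions, not part of the commit):

import torch
from imagedream.ldm.modules.encoders.modules import FrozenOpenCLIPEmbedder

embedder = FrozenOpenCLIPEmbedder(device="cuda")       # assumed constructor usage
cpu_tokens = torch.randint(0, 49408, (1, 77))          # token ids left on the CPU
states = embedder.encode_with_transformer(cpu_tokens)  # tokens are moved to the embedding weight's device internally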