Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
566832c
·
1 Parent(s): fdeb0ad

Enhance token handling in FrozenOpenCLIPEmbedder by adding device checks and conversions. This update ensures that input tokens are correctly processed as torch tensors on the appropriate device, improving compatibility and performance across different hardware configurations.

Browse files
imagedream/ldm/modules/encoders/modules.py CHANGED
@@ -275,8 +275,16 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
275
  encoder_states = encoder_states + (x, )
276
  return encoder_states
277
 
278
- def encode_with_transformer(self, text):
279
- x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
 
 
 
 
 
 
 
 
280
  x = x + self.model.positional_embedding
281
  x = x.permute(1, 0, 2) # NLD -> LND
282
  x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
 
275
  encoder_states = encoder_states + (x, )
276
  return encoder_states
277
 
278
+ def encode_with_transformer(self, tokens):
279
+ # Debug: print device info
280
+ print(f"[DEBUG] tokens type: {type(tokens)}, tokens device: {getattr(tokens, 'device', 'N/A')}")
281
+ print(f"[DEBUG] embedding weight device: {self.model.token_embedding.weight.device}")
282
+ # Ensure tokens is a torch tensor and on the correct device
283
+ if not isinstance(tokens, torch.Tensor):
284
+ tokens = torch.tensor(tokens, device=self.model.token_embedding.weight.device)
285
+ else:
286
+ tokens = tokens.to(self.model.token_embedding.weight.device)
287
+ x = self.model.token_embedding(tokens) # [batch_size, n_ctx, d_model]
288
  x = x + self.model.positional_embedding
289
  x = x.permute(1, 0, 2) # NLD -> LND
290
  x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)