alexnasa commited on
Commit
e1a5218
·
verified ·
1 Parent(s): 7130239

Update osediff_sd3.py

Browse files
Files changed (1) hide show
  1. osediff_sd3.py +2 -2
osediff_sd3.py CHANGED
@@ -392,7 +392,7 @@ class OSEDiff_SD3_GEN(torch.nn.Module):
392
  # Add lora to transformer
393
  print('Adding Lora to OSEDiff_SD3_GEN')
394
  self.transformer_gen = copy.deepcopy(self.model.transformer)
395
- self.transformer_gen.to('cuda:1')
396
  # self.transformer_gen = self.transformer_gen.float()
397
 
398
  self.transformer_gen.requires_grad_(False)
@@ -468,7 +468,7 @@ class OSEDiff_SD3_REG(torch.nn.Module):
468
  # Add lora to transformer
469
  print('Adding Lora to OSEDiff_SD3_REG')
470
  self.transformer_reg = copy.deepcopy(self.transformer_org)
471
- self.transformer_reg.to('cuda:1')
472
 
473
  self.transformer_reg.requires_grad_(False)
474
  self.transformer_reg.train()
 
392
  # Add lora to transformer
393
  print('Adding Lora to OSEDiff_SD3_GEN')
394
  self.transformer_gen = copy.deepcopy(self.model.transformer)
395
+ self.transformer_gen.to('cuda')
396
  # self.transformer_gen = self.transformer_gen.float()
397
 
398
  self.transformer_gen.requires_grad_(False)
 
468
  # Add lora to transformer
469
  print('Adding Lora to OSEDiff_SD3_REG')
470
  self.transformer_reg = copy.deepcopy(self.transformer_org)
471
+ self.transformer_reg.to('cuda')
472
 
473
  self.transformer_reg.requires_grad_(False)
474
  self.transformer_reg.train()