Prathm commited on
Commit
d28e895
·
1 Parent(s): d0f46df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -50,8 +50,8 @@ def clip_optimized_latent(text, seed, iterations=25, lr=0.1):
50
  params = [torch.nn.Parameter(latent_vector[i], requires_grad=True) for i in range(len(latent_vector))]
51
  optimizer = Adam(params, lr=lr)
52
 
53
- #with torch.no_grad():
54
- # text_features = clip_model.encode_text(text_input)
55
 
56
  #pbar = tqdm(range(iterations), dynamic_ncols=True)
57
 
@@ -67,16 +67,16 @@ def clip_optimized_latent(text, seed, iterations=25, lr=0.1):
67
  #image = clip_preprocess(Image.fromarray((image_np * 255).astype(np.uint8))).unsqueeze(0).to(device)
68
 
69
  # Extract features from the image
70
- #image_features = clip_model.encode_image(image)
71
 
72
  # Calculate the loss and backpropagate
73
- loss = 1 - clip_model(image, text_input)[0] / 100
74
- #loss = -torch.cosine_similarity(text_features, image_features).mean()
75
  loss.backward()
76
  optimizer.step()
77
 
78
  #pbar.set_description(f"Loss: {loss.item()}") # Update the progress bar to show the current loss
79
- w = [param.detach().cpu().numpy() for param in params]
 
80
 
81
  return w
82
 
 
50
  params = [torch.nn.Parameter(latent_vector[i], requires_grad=True) for i in range(len(latent_vector))]
51
  optimizer = Adam(params, lr=lr)
52
 
53
+ with torch.no_grad():
54
+ text_features = clip_model.encode_text(text_input)
55
 
56
  #pbar = tqdm(range(iterations), dynamic_ncols=True)
57
 
 
67
  #image = clip_preprocess(Image.fromarray((image_np * 255).astype(np.uint8))).unsqueeze(0).to(device)
68
 
69
  # Extract features from the image
70
+ image_features = clip_model.encode_image(image)
71
 
72
  # Calculate the loss and backpropagate
73
+ loss = -torch.cosine_similarity(text_features, image_features).mean()
 
74
  loss.backward()
75
  optimizer.step()
76
 
77
  #pbar.set_description(f"Loss: {loss.item()}") # Update the progress bar to show the current loss
78
+ print(f"Loss: {loss.item()}")
79
+ w = [param.detach().cpu().numpy() for param in params]
80
 
81
  return w
82