dineshsai07 committed
Commit f34225c · verified · 1 Parent(s): a167e83

interface files

Files changed (2):
  1. scripts/app.py +52 -0
  2. scripts/generate.py +167 -0
scripts/app.py ADDED
@@ -0,0 +1,52 @@
+ # app.py
+ import streamlit as st
+ from generate import generate_image
+ import os
+
+ # Set up page config (this must be the first Streamlit call in the script)
+ st.set_page_config(page_title="Visual Reconstruction from Brain", layout="centered")
+
+ # The host and port are set when launching Streamlit from the command line
+ # (see the note at the bottom of this file); reassigning sys.argv inside an
+ # already-running script has no effect
+
+ st.title("🧠 Imagine an Image!")
+
+ # Subject selection (the four NSD subjects with complete sessions)
+ sub = st.selectbox("Select Subject", options=[1, 2, 5, 7], index=0)
+
+ # Image ID input
+ image_id = st.number_input("Enter Image ID", min_value=0, step=1)
+
+ # Show the ground-truth test stimulus, if present
+ original_path = f'data/nsddata_stimuli/test_images/{image_id}.png'
+ if os.path.exists(original_path):
+     st.image(original_path, caption="Original Image", use_column_width=True)
+ else:
+     st.warning("Original image not found.")
+
+ # Text prompt
+ annot = st.text_input("Describe what you imagined", placeholder="e.g., a dog under a tree")
+
+ # Parameters
+ strength = st.slider("Diffusion Strength", 0.0, 1.0, 0.75, 0.05)
+ mixing = st.slider("Mixing Strength", 0.0, 1.0, 0.4, 0.05)
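+ # `strength` controls how far the VDVAE init image is re-noised
+ # (t_enc = strength * ddim_steps in generate.py), and `mixing` shifts the
+ # conditioning between CLIP-vision and the text prompt (generate.py calls
+ # decode_dc with mixed_ratio = 1 - mixing)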
+
+ # Submit button
+ if st.button("Reconstruct Image"):
+     with st.spinner("Reconstructing... please wait"):
+         try:
+             original_path, imagined_path = generate_image(sub, image_id, annot, strength, mixing)
+             # if os.path.exists(original_path):
+             #     st.image(original_path, caption="Original Image", use_column_width=True)
+             # else:
+             #     st.warning("Original image not found.")
+             if os.path.exists(imagined_path):
+                 st.image(imagined_path, caption="Imagined Reconstruction", use_column_width=True)
+             else:
+                 st.warning("Imagined image not found.")
+         except Exception as e:
+             st.error(f"⚠️ Error during generation: {e}")
+
+ # Optional: For cloud users
+ st.markdown("---")
+ # st.markdown("🔗 Access the app at: http://10.192.12.247:8501")
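+ # To launch (the address below is just this deployment's LAN host; adjust
+ # or omit for local use):
+ #   streamlit run scripts/app.py --server.address 10.192.12.247 --server.port 8501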
scripts/generate.py ADDED
@@ -0,0 +1,167 @@
+ # generate.py
+ import sys
+ sys.path.append('versatile_diffusion')
+ import os
+ import os.path as osp
+ import PIL
+ from PIL import Image
+ from pathlib import Path
+ import numpy as np
+ import numpy.random as npr
+
+ import torch
+ import torchvision.transforms as tvtrans
+ from lib.cfg_helper import model_cfg_bank
+ from lib.model_zoo import get_model
+ from lib.model_zoo.ddim_vd import DDIMSampler_VD
+ from lib.experiments.sd_default import color_adjust, auto_merge_imlist
+ from torch.utils.data import DataLoader, Dataset
+
+ from lib.model_zoo.vd import VD
+ from lib.cfg_holder import cfg_unique_holder as cfguh
+ from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml
+ import matplotlib.pyplot as plt
+ from skimage.transform import resize, downscale_local_mean
+
+
+ def regularize_image(x):
+     """Coerce a path, PIL image, ndarray, or tensor into a 512x512 image tensor."""
+     BICUBIC = PIL.Image.Resampling.BICUBIC
+     if isinstance(x, str):
+         x = Image.open(x).resize([512, 512], resample=BICUBIC)
+         x = tvtrans.ToTensor()(x)
+     elif isinstance(x, PIL.Image.Image):
+         x = x.resize([512, 512], resample=BICUBIC)
+         x = tvtrans.ToTensor()(x)
+     elif isinstance(x, np.ndarray):
+         x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
+         x = tvtrans.ToTensor()(x)
+     elif isinstance(x, torch.Tensor):
+         pass
+     else:
+         assert False, 'Unknown image type'
+
+     assert (x.shape[1] == 512) & (x.shape[2] == 512), \
+         'Wrong image size'
+     return x
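+ # e.g. a file path or an (H, W, 3) uint8 array both come back as a
+ # (3, 512, 512) float tensor in [0, 1] (ToTensor scales PIL input)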
+
+ # Load model once globally, at import time, so Streamlit reruns reuse it
+ # (a CUDA GPU is assumed for the fp16 weights below)
+ cfgm_name = 'vd_noema'
+ pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
+ cfgm = model_cfg_bank()(cfgm_name)
+ net = get_model()(cfgm)
+ sd = torch.load(pth, map_location='cpu')
+ net.load_state_dict(sd, strict=False)
+
+ # Ensure proper GPU device assignment, using cuda:0 for all tensor work
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+ # Move models to GPU (cuda:0)
+ net.clip.cuda(0)
+ net.autokl.cuda(0)
+
+ sampler = DDIMSampler_VD(net)
+ sampler.model.model.diffusion_model.device = device
+ sampler.model.model.diffusion_model.half().to(device)
+ batch_size = 1
+
+ # Predicted features are loaded per subject inside generate_image instead:
+ # pred_text = np.load('data/predicted_features/subj{:02d}/nsd_cliptext_predtest_nsdgeneral.npy'.format(sub))
+ # pred_text = torch.tensor(pred_text).half().to(device)
+ # pred_vision = np.load('data/predicted_features/subj{:02d}/nsd_clipvision_predtest_nsdgeneral.npy'.format(sub))
+ # pred_vision = torch.tensor(pred_vision).half().to(device)
+
+ # Sampling hyperparameters
+ n_samples = 1
+ ddim_steps = 50
+ ddim_eta = 0
+ scale = 7.5
+ xtype = 'image'
+ ctype = 'prompt'
+ net.autokl.half()
+
+ torch.manual_seed(0)
+
+ net.clip = net.clip.to(device)
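+ # Note: strict=False lets load_state_dict tolerate missing or unexpected keys
+ # in the fp16 checkpoint, and torch.manual_seed(0) makes sampling repeatable
+ # for identical inputs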
+
+ def generate_image(sub, image_id, annot, strength=0.75, mixing=0.4):
+
+     # Regression-predicted CLIP-vision features for this subject's test set
+     pred_vision = np.load(f'data/predicted_features/subj{sub:02d}/nsd_clipvision_predtest_nsdgeneral.npy')
+     pred_vision = torch.tensor(pred_vision).half().to(device)
+
+     # Low-level VDVAE reconstruction used as the init image, plus the
+     # ground-truth stimulus saved for display
+     zim = Image.open(f'results/vdvae/subj{sub:02d}/{image_id}.png')
+     test_img = Image.open(f'data/nsddata_stimuli/test_images/{image_id}.png')
+     test_img_path = 'scripts/images/original_image.png'
+     test_img.save(test_img_path)
+
+     zim = regularize_image(zim)
+     zin = zim * 2 - 1  # rescale [0, 1] -> [-1, 1]
+     zin = zin.unsqueeze(0).to(device).half()
+     init_latent = net.autokl_encode(zin)
+
+     sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
+     t_enc = int(strength * ddim_steps)
+     z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]).to(device))
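+     # img2img-style partial noising: with the defaults (strength=0.75,
+     # ddim_steps=50) the init latent is noised to step 37 and denoised from
+     # there, so the VDVAE layout survives while details are re-synthesized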
+
+     # Unconditional (negative) embeddings for classifier-free guidance:
+     # an empty string for the text branch, a zero image for the vision branch
+     dummy = ''
+     utx = net.clip_encode_text(dummy).to(device).half()
+     dummy = torch.zeros((1, 3, 224, 224)).to(device)
+     uim = net.clip_encode_vision(dummy).to(device).half()
+
+     z_enc = z_enc.to(device)
+
+     # Latent geometry for a 512x512 output (the VAE downsamples by 8)
+     h, w = 512, 512
+     shape = [n_samples, 4, h // 8, w // 8]
+
+     # Overwrite this image's predicted text feature with an encoding of the
+     # user's own description
+     pred_text = np.load(f'data/predicted_features/subj{sub:02d}/nsd_cliptext_predtest_nsdgeneral.npy')
+     with torch.no_grad():
+         pred_text[image_id] = net.clip_encode_text([annot]).to('cpu').numpy().mean(0)
+     pred_text = torch.tensor(pred_text).half().to(device)
+     ctx = pred_text[image_id].unsqueeze(0).to(device)
+     cim = pred_vision[image_id].unsqueeze(0).to(device)
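+     # mean(0) collapses the singleton batch axis of the one-prompt encoding
+     # so it matches the per-image layout stored in pred_text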
+
+     z = sampler.decode_dc(
+         x_latent=z_enc,
+         first_conditioning=[uim, cim],
+         second_conditioning=[utx, ctx],
+         t_start=t_enc,
+         unconditional_guidance_scale=7.5,
+         xtype='image',
+         first_ctype='vision',
+         second_ctype='prompt',
+         mixed_ratio=(1 - mixing),
+     )
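+     # decode_dc denoises under two conditioning streams at once, CLIP-vision
+     # (uim/cim) and CLIP-text (utx/ctx), blended by mixed_ratio; assuming
+     # mixed_ratio weights the first (vision) stream, a larger `mixing` value
+     # pushes the result toward the text prompt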
+
+     z = z.to(device).half()
+     x = net.autokl_decode(z)
+
+     # Adjust color if needed (inactive as written: color_adj stays 'None' and
+     # ctype is 'prompt'; enabling it would also require defining color_adj_to)
+     color_adj = 'None'
+     color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None)
+     color_adj_simple = (color_adj == 'Simple') or (color_adj == 'simple')
+     color_adj_keep_ratio = 0.5
+
+     if color_adj_flag and (ctype == 'vision'):
+         x_adj = []
+         for xi in x:
+             color_adj_f = color_adjust(ref_from=(xi + 1) / 2, ref_to=color_adj_to)
+             xi_adj = color_adj_f((xi + 1) / 2, keep=color_adj_keep_ratio, simple=color_adj_simple)
+             x_adj.append(xi_adj)
+         x = x_adj
+     else:
+         x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
+         x = [tvtrans.ToPILImage()(xi) for xi in x]
+
+     # Save the reconstruction under a fixed filename and return both paths
+     output_path = 'scripts/images/reconstructed.png'
+     x[0].save(output_path)
+
+     return test_img_path, output_path
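+ # Example call (hypothetical IDs; assumes the subj01 predicted-feature .npy
+ # files and results/vdvae/subj01/0.png exist):
+ #   orig_path, recon_path = generate_image(sub=1, image_id=0,
+ #                                          annot='a dog under a tree',
+ #                                          strength=0.75, mixing=0.4)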