Spaces:
Running
Running
Yao-Ting Yao
committed on
Commit
·
ec0e823
1
Parent(s):
8254c8e
Update streamlit_app.py
Browse files
Pre-generate thumbnails and save them in an AWS S3 bucket. Read the thumbnail URLs in the click function.
- src/streamlit_app.py +17 -105
src/streamlit_app.py
CHANGED
@@ -142,94 +142,6 @@ chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")
|
|
142 |
# set anonymous S3FileSystem to read files from public bucket
|
143 |
s3 = s3fs.S3FileSystem(anon=True)
|
144 |
|
145 |
-
## helper function
|
146 |
-
def gen_chip_urls(row, s3_prefix):
|
147 |
-
'''
|
148 |
-
Generate S3 urls for chips
|
149 |
-
:param row: dictionary with chip_id and dates
|
150 |
-
:param s3_prefix: S3 url prefix
|
151 |
-
:return s3_urls: a list of urls
|
152 |
-
'''
|
153 |
-
s3_urls = []
|
154 |
-
dates = ast.literal_eval(row["dates"])
|
155 |
-
for date in dates:
|
156 |
-
filename = f"s2_{row['chip_id']:06}_{date}.tif"
|
157 |
-
s3_url = f"{s3_prefix}/{filename}"
|
158 |
-
s3_urls.append(s3_url)
|
159 |
-
return s3_urls
|
160 |
-
|
161 |
-
def mask_nodata(band, nodata_values=(-999,)):
|
162 |
-
'''
|
163 |
-
Mask nodata to nan
|
164 |
-
:param band
|
165 |
-
:param nodata_values:nodata values in chips is -999
|
166 |
-
:return band
|
167 |
-
'''
|
168 |
-
band = band.astype(float)
|
169 |
-
for val in nodata_values:
|
170 |
-
band[band == val] = np.nan
|
171 |
-
return band
|
172 |
-
|
173 |
-
def normalize(band):
|
174 |
-
'''
|
175 |
-
Normalize a band to 0-1 range(float)
|
176 |
-
:param band (ndarray)
|
177 |
-
return normalize band (ndarray); when max equals min, returns zeros.
|
178 |
-
'''
|
179 |
-
if np.nanmean(band) >= 4000:
|
180 |
-
band = band / 6000
|
181 |
-
else:
|
182 |
-
band = band / 4000
|
183 |
-
band = np.clip(band, None, 1)
|
184 |
-
|
185 |
-
return band
|
186 |
-
|
187 |
-
def create_thumbnail(url):
|
188 |
-
'''
|
189 |
-
Read S3 file into memory, using rasterio to create a png thumbnail then encode as a base64 string url
|
190 |
-
:param url: chip url
|
191 |
-
:return a base64-encoded png string, returns an empty string when an error occurs
|
192 |
-
'''
|
193 |
-
try:
|
194 |
-
# read raw bytes from s3 file
|
195 |
-
with s3.open(url, "rb") as f:
|
196 |
-
data = f.read()
|
197 |
-
|
198 |
-
# wrap the raw bytes in an in-memory file
|
199 |
-
with MemoryFile(data) as memfile:
|
200 |
-
|
201 |
-
# read memory file with rasterio
|
202 |
-
with memfile.open() as src:
|
203 |
-
# mask nodata so that the normalization is calculated correctly
|
204 |
-
# band1->blue, band2->green, band3->red
|
205 |
-
|
206 |
-
blue = src.read(1).astype(float)
|
207 |
-
green = src.read(2).astype(float)
|
208 |
-
red = src.read(3).astype(float)
|
209 |
-
|
210 |
-
blue = normalize(mask_nodata(blue))
|
211 |
-
green = normalize(mask_nodata(green))
|
212 |
-
red = normalize(mask_nodata(red))
|
213 |
-
|
214 |
-
# stack in RGB
|
215 |
-
rgb = np.dstack((red, green, blue))
|
216 |
-
|
217 |
-
# convert float(0-1) to uint8 (0-255)
|
218 |
-
rgb_8bit = (rgb * 255).astype(np.uint8)
|
219 |
-
|
220 |
-
# convert to png in memory
|
221 |
-
pil_img = Image.fromarray(rgb_8bit)
|
222 |
-
buf = io.BytesIO()
|
223 |
-
pil_img.save(buf, format='PNG')
|
224 |
-
|
225 |
-
# encoded into base64 HTML format
|
226 |
-
encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
|
227 |
-
return f"data:image/png;base64,{encoded}"
|
228 |
-
|
229 |
-
except Exception as e:
|
230 |
-
# return an empty string for Exception
|
231 |
-
return ""
|
232 |
-
|
233 |
def get_lat(geometry):
|
234 |
lat = wkt.loads(geometry).coords.xy[1][0]
|
235 |
|
@@ -240,7 +152,6 @@ def get_lon(geometry):
|
|
240 |
|
241 |
return lon
|
242 |
|
243 |
-
|
244 |
## generate json
|
245 |
# title: plot title
|
246 |
# xaxis_title: x axis title
|
@@ -255,12 +166,6 @@ title_js = json.dumps(config["title"])
|
|
255 |
xaxis_js = json.dumps(config["xaxis_title"])
|
256 |
yaxis_js = json.dumps(config["yaxis_title"])
|
257 |
|
258 |
-
# set prefix
|
259 |
-
s3_prefix="s3://gfm-bench"
|
260 |
-
|
261 |
-
# generate S3 file URLs
|
262 |
-
chips_df["urls"] = chips_df.apply(lambda row: gen_chip_urls(row, s3_prefix), axis=1)
|
263 |
-
|
264 |
# set lc(str) for categorical data for plotting
|
265 |
chips_df["lc"] = chips_df["lc"].astype(str)
|
266 |
# add latitude and longitude
|
@@ -300,13 +205,21 @@ color_dict_label = {
|
|
300 |
'Rangeland': '#f7980a'
|
301 |
}
|
302 |
|
303 |
-
# create thumbnail
|
304 |
-
chips_df["thumbs"] = chips_df["urls"].apply(
|
305 |
-
lambda urls: [create_thumbnail(p) for p in urls]
|
306 |
-
)
|
307 |
# create dates Python list
|
308 |
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
# build a list of points dictionary
|
311 |
points = (
|
312 |
chips_df
|
@@ -314,12 +227,13 @@ points = (
|
|
314 |
"cls_dim1": "x",
|
315 |
"cls_dim2": "y",
|
316 |
"Land Cover": "category"
|
317 |
-
})[["x","y","chip_id", "latitude", "longitude","category","
|
318 |
.assign(
|
319 |
id = chips_df["chip_id"],
|
320 |
lat = chips_df["latitude"],
|
321 |
lon = chips_df["longitude"],
|
322 |
-
color=chips_df["Land Cover"].map(color_dict_label)
|
|
|
323 |
.to_dict(orient="records")
|
324 |
)
|
325 |
|
@@ -380,8 +294,7 @@ plot_html = f"""
|
|
380 |
customdata:pts.map(p=>[
|
381 |
p.id,
|
382 |
p.lat,
|
383 |
-
p.lon
|
384 |
-
p.thumbs
|
385 |
]),
|
386 |
mode: 'markers',
|
387 |
type: 'scatter',
|
@@ -438,8 +351,7 @@ plot_html = f"""
|
|
438 |
gd.on('plotly_click', evt => {{
|
439 |
// grab thumbs and dates through point index
|
440 |
const idx = evt.points[0].pointIndex;
|
441 |
-
const
|
442 |
-
const thumbs = cds[3];
|
443 |
const dates = points[idx].dates_list;
|
444 |
// grab image container and clear out old thumbs
|
445 |
const container = document.getElementById('image-container');
|
|
|
142 |
# set anonymous S3FileSystem to read files from public bucket
|
143 |
s3 = s3fs.S3FileSystem(anon=True)
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def get_lat(geometry):
|
146 |
lat = wkt.loads(geometry).coords.xy[1][0]
|
147 |
|
|
|
152 |
|
153 |
return lon
|
154 |
|
|
|
155 |
## generate json
|
156 |
# title: plot title
|
157 |
# xaxis_title: x axis title
|
|
|
166 |
xaxis_js = json.dumps(config["xaxis_title"])
|
167 |
yaxis_js = json.dumps(config["yaxis_title"])
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
# set lc(str) for categorical data for plotting
|
170 |
chips_df["lc"] = chips_df["lc"].astype(str)
|
171 |
# add latitude and longitude
|
|
|
205 |
'Rangeland': '#f7980a'
|
206 |
}
|
207 |
|
|
|
|
|
|
|
|
|
208 |
# create dates Python list
|
209 |
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
|
210 |
|
211 |
+
# set prefix
|
212 |
+
s3_url="https://gfm-bench.s3.amazonaws.com/thumbnails"
|
213 |
+
|
214 |
+
# create thumb_urls column
|
215 |
+
chips_df["thumb_urls"] = chips_df.apply(
|
216 |
+
lambda r: [
|
217 |
+
f"{s3_url}/s2_{r.chip_id:06}_{date}.png"
|
218 |
+
for date in r.dates_list
|
219 |
+
],
|
220 |
+
axis=1
|
221 |
+
)
|
222 |
+
|
223 |
# build a list of points dictionary
|
224 |
points = (
|
225 |
chips_df
|
|
|
227 |
"cls_dim1": "x",
|
228 |
"cls_dim2": "y",
|
229 |
"Land Cover": "category"
|
230 |
+
})[["x","y","chip_id", "latitude", "longitude","category","dates_list"]]
|
231 |
.assign(
|
232 |
id = chips_df["chip_id"],
|
233 |
lat = chips_df["latitude"],
|
234 |
lon = chips_df["longitude"],
|
235 |
+
color=chips_df["Land Cover"].map(color_dict_label),
|
236 |
+
thumbs = chips_df["thumb_urls"])
|
237 |
.to_dict(orient="records")
|
238 |
)
|
239 |
|
|
|
294 |
customdata:pts.map(p=>[
|
295 |
p.id,
|
296 |
p.lat,
|
297 |
+
p.lon
|
|
|
298 |
]),
|
299 |
mode: 'markers',
|
300 |
type: 'scatter',
|
|
|
351 |
gd.on('plotly_click', evt => {{
|
352 |
// grab thumbs and dates through point index
|
353 |
const idx = evt.points[0].pointIndex;
|
354 |
+
const thumbs = points[idx].thumbs;
|
|
|
355 |
const dates = points[idx].dates_list;
|
356 |
// grab image container and clear out old thumbs
|
357 |
const container = document.getElementById('image-container');
|