Yao-Ting Yao commited on
Commit
ec0e823
·
1 Parent(s): 8254c8e

Update streamlit_app.py

Browse files

Pre-generate thumbnails and save in AWS S3 bucket. Read thumbnails' url in click function.

Files changed (1) hide show
  1. src/streamlit_app.py +17 -105
src/streamlit_app.py CHANGED
@@ -142,94 +142,6 @@ chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")
142
  # set anonymous S3FileSystem to read files from public bucket
143
  s3 = s3fs.S3FileSystem(anon=True)
144
 
145
- ## helper function
146
- def gen_chip_urls(row, s3_prefix):
147
- '''
148
- Generate S3 urls for chips
149
- :param row: dictionary with chip_id and dates
150
- :param s3_prefix: S3 url prefix
151
- :return s3_urls: a list of urls
152
- '''
153
- s3_urls = []
154
- dates = ast.literal_eval(row["dates"])
155
- for date in dates:
156
- filename = f"s2_{row['chip_id']:06}_{date}.tif"
157
- s3_url = f"{s3_prefix}/{filename}"
158
- s3_urls.append(s3_url)
159
- return s3_urls
160
-
161
- def mask_nodata(band, nodata_values=(-999,)):
162
- '''
163
- Mask nodata to nan
164
- :param band
165
- :param nodata_values:nodata values in chips is -999
166
- :return band
167
- '''
168
- band = band.astype(float)
169
- for val in nodata_values:
170
- band[band == val] = np.nan
171
- return band
172
-
173
- def normalize(band):
174
- '''
175
- Normalize a band to 0-1 range(float)
176
- :param band (ndarray)
177
- return normalize band (ndarray); when max equals min, returns zeros.
178
- '''
179
- if np.nanmean(band) >= 4000:
180
- band = band / 6000
181
- else:
182
- band = band / 4000
183
- band = np.clip(band, None, 1)
184
-
185
- return band
186
-
187
- def create_thumbnail(url):
188
- '''
189
- Read S3 file into memory, using rasterio to create a png thumbnail then encode as a base64 string url
190
- :param url: chip url
191
- :return a base64-encoded png string, returns an empty string when an error occurs
192
- '''
193
- try:
194
- # read raw bytes from s3 file
195
- with s3.open(url, "rb") as f:
196
- data = f.read()
197
-
198
- # wrap the raw bytes into an memory file
199
- with MemoryFile(data) as memfile:
200
-
201
- # read memory file with rasterio
202
- with memfile.open() as src:
203
- # mask nodata to have correct calculate normalization
204
- # band1->blue, band2->green, band3->red
205
-
206
- blue = src.read(1).astype(float)
207
- green = src.read(2).astype(float)
208
- red = src.read(3).astype(float)
209
-
210
- blue = normalize(mask_nodata(blue))
211
- green = normalize(mask_nodata(green))
212
- red = normalize(mask_nodata(red))
213
-
214
- # stack in RGB
215
- rgb = np.dstack((red, green, blue))
216
-
217
- # convert float(0-1) to uint8 (0-255)
218
- rgb_8bit = (rgb * 255).astype(np.uint8)
219
-
220
- # convert to png in memory
221
- pil_img = Image.fromarray(rgb_8bit)
222
- buf = io.BytesIO()
223
- pil_img.save(buf, format='PNG')
224
-
225
- # encoded into base64 HTML format
226
- encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
227
- return f"data:image/png;base64,{encoded}"
228
-
229
- except Exception as e:
230
- # return an empty string for Exception
231
- return ""
232
-
233
  def get_lat(geometry):
234
  lat = wkt.loads(geometry).coords.xy[1][0]
235
 
@@ -240,7 +152,6 @@ def get_lon(geometry):
240
 
241
  return lon
242
 
243
-
244
  ## generate json
245
  # title: plot title
246
  # xaxis_title: x axis title
@@ -255,12 +166,6 @@ title_js = json.dumps(config["title"])
255
  xaxis_js = json.dumps(config["xaxis_title"])
256
  yaxis_js = json.dumps(config["yaxis_title"])
257
 
258
- # set prefix
259
- s3_prefix="s3://gfm-bench"
260
-
261
- # generate S3 file URLs
262
- chips_df["urls"] = chips_df.apply(lambda row: gen_chip_urls(row, s3_prefix), axis=1)
263
-
264
  # set lc(str) for categorical data for plotting
265
  chips_df["lc"] = chips_df["lc"].astype(str)
266
  # add latitude and longitude
@@ -300,13 +205,21 @@ color_dict_label = {
300
  'Rangeland': '#f7980a'
301
  }
302
 
303
- # create thumbnail
304
- chips_df["thumbs"] = chips_df["urls"].apply(
305
- lambda urls: [create_thumbnail(p) for p in urls]
306
- )
307
  # create dates Python list
308
  chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
309
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # build a list of points dictionary
311
  points = (
312
  chips_df
@@ -314,12 +227,13 @@ points = (
314
  "cls_dim1": "x",
315
  "cls_dim2": "y",
316
  "Land Cover": "category"
317
- })[["x","y","chip_id", "latitude", "longitude","category","thumbs","dates_list"]]
318
  .assign(
319
  id = chips_df["chip_id"],
320
  lat = chips_df["latitude"],
321
  lon = chips_df["longitude"],
322
- color=chips_df["Land Cover"].map(color_dict_label))
 
323
  .to_dict(orient="records")
324
  )
325
 
@@ -380,8 +294,7 @@ plot_html = f"""
380
  customdata:pts.map(p=>[
381
  p.id,
382
  p.lat,
383
- p.lon,
384
- p.thumbs
385
  ]),
386
  mode: 'markers',
387
  type: 'scatter',
@@ -438,8 +351,7 @@ plot_html = f"""
438
  gd.on('plotly_click', evt => {{
439
  // grab thumbs and dates through point index
440
  const idx = evt.points[0].pointIndex;
441
- const cds = evt.points[0].customdata;
442
- const thumbs = cds[3];
443
  const dates = points[idx].dates_list;
444
  // grab image container and clear out old thumbs
445
  const container = document.getElementById('image-container');
 
142
  # set anonymous S3FileSystem to read files from public bucket
143
  s3 = s3fs.S3FileSystem(anon=True)
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def get_lat(geometry):
146
  lat = wkt.loads(geometry).coords.xy[1][0]
147
 
 
152
 
153
  return lon
154
 
 
155
  ## generate json
156
  # title: plot title
157
  # xaxis_title: x axis title
 
166
  xaxis_js = json.dumps(config["xaxis_title"])
167
  yaxis_js = json.dumps(config["yaxis_title"])
168
 
 
 
 
 
 
 
169
  # set lc(str) for categorical data for plotting
170
  chips_df["lc"] = chips_df["lc"].astype(str)
171
  # add latitude and longitude
 
205
  'Rangeland': '#f7980a'
206
  }
207
 
 
 
 
 
208
  # create dates Python list
209
  chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
210
 
211
+ # set prefix
212
+ s3_url="https://gfm-bench.s3.amazonaws.com/thumbnails"
213
+
214
+ # create thumb_urls column
215
+ chips_df["thumb_urls"] = chips_df.apply(
216
+ lambda r: [
217
+ f"{s3_url}/s2_{r.chip_id:06}_{date}.png"
218
+ for date in r.dates_list
219
+ ],
220
+ axis=1
221
+ )
222
+
223
  # build a list of points dictionary
224
  points = (
225
  chips_df
 
227
  "cls_dim1": "x",
228
  "cls_dim2": "y",
229
  "Land Cover": "category"
230
+ })[["x","y","chip_id", "latitude", "longitude","category","dates_list"]]
231
  .assign(
232
  id = chips_df["chip_id"],
233
  lat = chips_df["latitude"],
234
  lon = chips_df["longitude"],
235
+ color=chips_df["Land Cover"].map(color_dict_label),
236
+ thumbs = chips_df["thumb_urls"])
237
  .to_dict(orient="records")
238
  )
239
 
 
294
  customdata:pts.map(p=>[
295
  p.id,
296
  p.lat,
297
+ p.lon
 
298
  ]),
299
  mode: 'markers',
300
  type: 'scatter',
 
351
  gd.on('plotly_click', evt => {{
352
  // grab thumbs and dates through point index
353
  const idx = evt.points[0].pointIndex;
354
+ const thumbs = points[idx].thumbs;
 
355
  const dates = points[idx].dates_list;
356
  // grab image container and clear out old thumbs
357
  const container = document.getElementById('image-container');