# eo-fm-bench / src /thumbnail_process.py
# Yao-Ting Yao
# Create thumbnail_process.py
# 8254c8e
import ast
import logging
import os

import numpy as np
import pandas as pd
import s3fs
from PIL import Image
from rasterio.io import MemoryFile
# Load the chip table; gen_chip_urls reads its "chip_id" and "dates" columns.
# NOTE: the path is relative, so it is resolved against the current working directory.
chips_df = pd.read_csv("../data/embeddings_df_v0.11_test.csv")
# Anonymous (credential-free) S3 filesystem; create_thumbnail uses it to read
# chip files from the public bucket.
s3 = s3fs.S3FileSystem(anon=True)
## helper function
def gen_chip_urls(row, s3_prefix):
    '''
    Build the S3 url of every chip file listed for one row.

    Each url has the form "<s3_prefix>/s2_<chip_id padded to 6>_<date>.tif".
    :param row: mapping with "chip_id" and "dates"; "dates" is the string form
                of a Python list literal and is parsed with ast.literal_eval
    :param s3_prefix: S3 url prefix (no trailing slash)
    :return: list of urls, one per date, in the order the dates are listed
    '''
    acquisition_dates = ast.literal_eval(row["dates"])
    return [
        f"{s3_prefix}/s2_{row['chip_id']:06}_{acq_date}.tif"
        for acq_date in acquisition_dates
    ]
def mask_nodata(band, nodata_values=(-999,)):
    '''
    Return a float copy of a band with nodata pixels replaced by NaN.

    :param band: (ndarray) raw band values; the input array is not modified
    :param nodata_values: values to treat as nodata (default: -999)
    :return: (ndarray of float) band with every nodata pixel set to NaN
    '''
    masked = band.astype(float)
    # one membership test covers all nodata values at once
    is_nodata = np.isin(masked, list(nodata_values))
    masked[is_nodata] = np.nan
    return masked
def normalize(band):
    '''
    Scale a band into the 0-1 range (float).

    The band is divided by 6000 when its NaN-ignoring mean is at least 4000,
    otherwise by 4000, and the result is clipped to [0, 1]. NaN pixels stay NaN.
    :param band: (ndarray) band values; NaN marks masked pixels
    :return: (ndarray) scaled band with every non-NaN value inside [0, 1]
    '''
    band = np.asarray(band)
    # np.nanmean warns on empty / all-NaN input; such bands use the default scale
    has_data = band.size > 0 and not np.all(np.isnan(band))
    is_bright = has_data and np.nanmean(band) >= 4000
    scale = 6000 if is_bright else 4000
    # clip both ends: negative values would otherwise fall outside the 0-1 range
    # and wrap around when later cast to uint8
    return np.clip(band / scale, 0, 1)
def create_thumbnail(url, output_dir):
    '''
    Read an S3 raster into memory and save it as an 8-bit RGB png thumbnail.

    Bands 1, 2 and 3 are read as blue, green and red. Nodata is masked and each
    band is normalized before the image is written at the source resolution
    (no resizing is performed). Masked (NaN) pixels are written as black.
    :param url: S3 file URL
    :param output_dir: directory to save thumbnails (created if missing)
    :return: saved file path (str) or "" if failed; failures are logged
    '''
    try:
        os.makedirs(output_dir, exist_ok=True)
        # read raw bytes from s3 file
        with s3.open(url, "rb") as f:
            data = f.read()
        # wrap the raw bytes into a memory file and read it with rasterio
        with MemoryFile(data) as memfile:
            with memfile.open() as src:
                # mask nodata first so it does not skew the normalization
                # band1->blue, band2->green, band3->red
                blue = normalize(mask_nodata(src.read(1)))
                green = normalize(mask_nodata(src.read(2)))
                red = normalize(mask_nodata(src.read(3)))
        # stack in RGB
        rgb = np.dstack((red, green, blue))
        # casting NaN (masked pixels) or out-of-range values to uint8 is
        # undefined / wraps around, so fill NaN with 0 and clamp to 0-1 first
        rgb = np.clip(np.nan_to_num(rgb, nan=0.0), 0, 1)
        # convert float(0-1) to uint8 (0-255)
        rgb_8bit = (rgb * 255).astype(np.uint8)
        pil_img = Image.fromarray(rgb_8bit)
        # swap only the file extension, not every ".tif" inside the name
        stem, _ = os.path.splitext(os.path.basename(url))
        file_path = os.path.join(output_dir, f"{stem}.png")
        pil_img.save(file_path, format="PNG")
        return file_path
    except Exception as exc:
        # best effort: keep the batch going, but leave a trace of what failed
        logging.getLogger(__name__).warning(
            "Failed to create thumbnail for %s: %s", url, exc
        )
        return ""
# bucket prefix that all chip urls are built on
s3_prefix = "s3://gfm-bench"
# one list of S3 urls per row; extra keyword arguments of DataFrame.apply
# are forwarded to gen_chip_urls
chips_df["urls"] = chips_df.apply(gen_chip_urls, axis=1, s3_prefix=s3_prefix)
# one list of thumbnail paths per row ("" marks a chip that failed)
chips_df["thumbs"] = chips_df["urls"].apply(
    lambda chip_urls: [
        create_thumbnail(chip_url, output_dir="../data/thumbnails")
        for chip_url in chip_urls
    ]
)