Spaces:
Running
Running
import ast
import logging
import os

import numpy as np
import pandas as pd
import s3fs
from PIL import Image
from rasterio.io import MemoryFile
# Load the chip table; its "chip_id" and "dates" columns drive URL generation below
chips_df = pd.read_csv(filepath_or_buffer="../data/embeddings_df_v0.11_test.csv")

# The chips live in a public bucket, so an anonymous (unsigned) S3 client is enough
s3 = s3fs.S3FileSystem(anon=True)
## helper function
def gen_chip_urls(row, s3_prefix):
    '''
    Generate the S3 URLs of every acquisition of one chip.

    :param row: mapping (dict or pandas row) with "chip_id" and "dates".
        "dates" may be the string repr of a list as stored in the CSV,
        e.g. "['2020-01-01', '2020-06-15']", or an already-parsed sequence.
    :param s3_prefix: S3 URL prefix, without a trailing slash
    :return: list of URLs "<prefix>/s2_<chip_id:06>_<date>.tif", one per date,
        in the order the dates are listed
    '''
    dates = row["dates"]
    if isinstance(dates, str):
        # CSV round-trips lists as their Python repr; literal_eval parses it safely
        dates = ast.literal_eval(dates)
    # ":06" zero-pads the id to 6 digits (assumes chip_id is an integer — TODO confirm)
    return [f"{s3_prefix}/s2_{row['chip_id']:06}_{date}.tif" for date in dates]
def mask_nodata(band, nodata_values=(-999,)):
    '''
    Replace nodata pixels with NaN.

    :param band: ndarray of raw band values (any numeric dtype)
    :param nodata_values: sentinel values treated as nodata; the chips use -999
    :return: float copy of band with every nodata pixel set to NaN
        (the caller's array is not modified)
    '''
    # astype copies by default, so writing NaN below never touches the input
    masked = band.astype(float)
    is_nodata = np.zeros(masked.shape, dtype=bool)
    for sentinel in nodata_values:
        is_nodata |= masked == sentinel
    masked[is_nodata] = np.nan
    return masked
def normalize(band):
    '''
    Scale a band to the 0-1 range (float) for display.

    The divisor depends on the band's mean brightness (NaN ignored):
    bands with mean >= 4000 are divided by 6000, all others by 4000.
    The result is clipped to [0, 1]; NaN pixels stay NaN. An all-NaN band
    has a NaN mean (numpy warns) and therefore uses the 4000 divisor.

    :param band: ndarray, presumably raw Sentinel-2 digital numbers with
        nodata already set to NaN (TODO confirm against data source)
    :return: float ndarray with values in [0, 1] or NaN
    '''
    # a larger divisor for bright bands keeps them from saturating at 1
    scale = 6000 if np.nanmean(band) >= 4000 else 4000
    # clip the lower bound too: negative values would wrap around when the
    # result is later multiplied by 255 and cast to uint8
    return np.clip(band / scale, 0, 1)
def _read_rgb(url):
    '''
    Download a chip from S3 and return its normalized bands stacked as RGB.

    :param url: S3 file URL of a GeoTIFF with blue, green, red in bands 1, 2, 3
    :return: float ndarray (rows, cols, 3) in R, G, B order; values in [0, 1] or NaN
    '''
    # read raw bytes from S3, then let rasterio parse them from memory (no temp file)
    with s3.open(url, "rb") as f:
        data = f.read()
    with MemoryFile(data) as memfile:
        with memfile.open() as src:
            # band1 -> blue, band2 -> green, band3 -> red; mask nodata before
            # normalizing so the -999 fill does not skew the band mean
            blue, green, red = [normalize(mask_nodata(src.read(idx))) for idx in (1, 2, 3)]
    return np.dstack((red, green, blue))


def _to_uint8(rgb):
    '''
    Convert a float RGB array in [0, 1] (NaN = nodata) to uint8 in [0, 255].
    '''
    # casting NaN to an integer type is undefined (numpy warns), so render
    # nodata as black explicitly; clip so out-of-range values cannot wrap
    rgb = np.nan_to_num(rgb, nan=0.0)
    return (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)


def create_thumbnail(url, output_dir, max_size=None):
    '''
    Read an S3 GeoTIFF into memory, render its RGB bands and save them as a PNG.

    :param url: S3 file URL
    :param output_dir: directory to save thumbnails; created if missing
    :param max_size: optional int; when given, the image is downscaled
        (aspect ratio preserved) so neither side exceeds max_size pixels.
        None (default) keeps the chip's native resolution.
    :return: saved file path (str), or "" if anything failed (the failure is logged)
    '''
    try:
        os.makedirs(output_dir, exist_ok=True)
        pil_img = Image.fromarray(_to_uint8(_read_rgb(url)))
        if max_size is not None:
            # thumbnail() resizes in place and never enlarges
            pil_img.thumbnail((max_size, max_size))
        # swap only the extension; str.replace(".tif", ...) would also mangle
        # ".tiff" or a ".tif" occurring elsewhere in the name
        stem = os.path.splitext(os.path.basename(url))[0]
        file_path = os.path.join(output_dir, stem + ".png")
        pil_img.save(file_path, format="PNG")
        return file_path
    except Exception as e:
        # best-effort per chip: one bad file must not abort the batch,
        # but the reason must not disappear silently either
        logging.getLogger(__name__).warning("Failed to create thumbnail for %s: %s", url, e)
        return ""
# Bucket that hosts the chip GeoTIFFs
s3_prefix = "s3://gfm-bench"

# One list of S3 URLs per chip row (one URL per listed date)
chips_df["urls"] = chips_df.apply(gen_chip_urls, axis=1, args=(s3_prefix,))

# Render every URL to a local PNG; entries are file paths, or "" where create_thumbnail failed
chips_df["thumbs"] = chips_df["urls"].map(
    lambda url_list: [create_thumbnail(u, output_dir="../data/thumbnails") for u in url_list]
)