import ast
import base64
import datetime
import io
import json
import logging

import fsspec
import numpy as np
import pandas as pd
import plotly.express as px
import rasterio
import s3fs
import streamlit as st
from PIL import Image
from rasterio.io import MemoryFile
from shapely import wkt
from streamlit.components.v1 import html as st_html

logger = logging.getLogger(__name__)

## app contents
# set page title and layout
st.set_page_config(
    page_title="GFM Explainability Demo",
    layout="wide",
)

# create background image: read image and base64-encode it
with open("data/sx_darkened_fields_v2.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")
bg_url = f"data:image/jpeg;base64,{b64}"

# NOTE(review): the markup inside the two templates below is empty in this
# copy of the file; presumably the original CSS (using bg_url) was stripped
# -- confirm against the deployed version.
st.markdown(
    f""" """,
    unsafe_allow_html=True,
)
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# centre the header in the middle column of a 1:5:1 layout
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st.markdown(
        "\n\nGFM Explainability Demo 🌎\n\n",
        unsafe_allow_html=True,
    )
    # NOTE(review): original indentation was lost; confirm that this
    # description belongs inside the column.
    st.markdown(
        "\n\nThis app extracts t-SNE of Embeddings from image chips using Prithvi-EO-2.0 model!\n\n",
        unsafe_allow_html=True,
    )

# read csv
chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")

# set anonymous S3FileSystem to read files from public bucket
s3 = s3fs.S3FileSystem(anon=True)


## helper function
def gen_chip_urls(row, s3_prefix):
    '''
    Generate S3 urls for chips
    :param row: mapping with ``chip_id`` and ``dates``; ``dates`` is the
        string form of a Python list literal and is parsed with
        ``ast.literal_eval``
    :param s3_prefix: S3 url prefix
    :return s3_urls: a list of urls, one per date, of the form
        ``<s3_prefix>/s2_<chip_id zero-padded to 6>_<date>.tif``
    '''
    dates = ast.literal_eval(row["dates"])
    return [
        f"{s3_prefix}/s2_{row['chip_id']:06}_{date}.tif"
        for date in dates
    ]


def mask_nodata(band, nodata_values=(-999,)):
    '''
    Mask nodata to nan
    :param band: array of pixel values; it is copied and converted to float
        so that NaN can be stored
    :param nodata_values: values treated as nodata (-999 by default)
    :return band: float copy of the input with nodata replaced by NaN
    '''
    band = band.astype(float)
    for val in nodata_values:
        band[band == val] = np.nan
    return band


def normalize(band):
    '''
    Scale a band into the 0-1 range (float)
    :param band (ndarray): pixel values, NaN entries are ignored when the
        mean is computed and stay NaN in the result
    :return normalized band (ndarray): the band divided by 6000 when its
        NaN-aware mean is at least 4000, otherwise by 4000, then clipped
        to [0, 1]
    '''
    scale = 6000 if np.nanmean(band) >= 4000 else 4000
    # Clip on both sides: values below 0 would otherwise wrap around when the
    # result is later converted to uint8.
    return np.clip(band / scale, 0, 1)


# NOTE(review): Streamlit re-executes the script on each interaction, so all
# thumbnails are downloaded again every time; consider caching (e.g.
# st.cache_data) once the installed Streamlit version is confirmed.
def create_thumbnail(url):
    '''
    Read S3 file into memory, using rasterio to create a png thumbnail
    then encode as a base64 string url
    :param url: chip url
    :return a base64-encoded png data URL; an empty string when any error
        occurs (the error is logged, not raised)
    '''
    try:
        # read raw bytes from s3 file
        with s3.open(url, "rb") as f:
            data = f.read()

        # wrap the raw bytes into a memory file and read it with rasterio
        with MemoryFile(data) as memfile:
            with memfile.open() as src:
                # mask nodata so the normalization mean is computed correctly
                # band1->blue, band2->green, band3->red
                blue = normalize(mask_nodata(src.read(1)))
                green = normalize(mask_nodata(src.read(2)))
                red = normalize(mask_nodata(src.read(3)))

        # stack in RGB
        rgb = np.dstack((red, green, blue))
        # Masked nodata pixels are NaN; casting NaN to uint8 is undefined,
        # so render them as black instead.
        rgb = np.nan_to_num(rgb, nan=0.0)
        # convert float(0-1) to uint8 (0-255)
        rgb_8bit = (rgb * 255).astype(np.uint8)

        # convert to png in memory
        pil_img = Image.fromarray(rgb_8bit)
        buf = io.BytesIO()
        pil_img.save(buf, format='PNG')

        # encoded into base64 HTML format
        encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f"data:image/png;base64,{encoded}"
    except Exception:
        # Best effort: a missing or unreadable chip must not break the page,
        # but leave a trace so the failure can be diagnosed.
        logger.warning("Could not create thumbnail for %s", url, exc_info=True)
        return ""


def get_lat(geometry):
    '''
    Return the y value of the first coordinate of a WKT geometry
    :param geometry: WKT string
    '''
    lat = wkt.loads(geometry).coords.xy[1][0]
    return lat


def get_lon(geometry):
    '''
    Return the x value of the first coordinate of a WKT geometry
    :param geometry: WKT string
    '''
    lon = wkt.loads(geometry).coords.xy[0][0]
    return lon


## generate json
# title: plot title
# xaxis_title: x axis title
# yaxis_title: y axis title
config = {
    "title": "t-SNE Visualization of EO-FM-Bench Embeddings for Prithvi-EO-2.0",
    "xaxis_title": "t-SNE Dimension 1",
    "yaxis_title": "t-SNE Dimension 2",
}

# convert to json
title_js = json.dumps(config["title"])
xaxis_js = json.dumps(config["xaxis_title"])
yaxis_js = json.dumps(config["yaxis_title"])

# set prefix
s3_prefix = "s3://gfm-bench"

# generate S3 file URLs
chips_df["urls"] = chips_df.apply(lambda row: gen_chip_urls(row, s3_prefix), axis=1)

# set lc(str) for categorical data for plotting
chips_df["lc"] = chips_df["lc"].astype(str)

# add latitude and longitude
chips_df["latitude"] = chips_df["geometry"].apply(get_lat)
chips_df["longitude"] = chips_df["geometry"].apply(get_lon)

# color dictionary
color_dict = {
    '1': '#2c41e6',   # Water
    '2': '#04541b',   # Trees
    '5': '#99e0ad',   # Crops
    '7': '#797b85',   # Built area
    '8': '#a68647',   # Bare ground
    '11': '#f7980a',  # Rangeland
}

# land cover dictionary
land_cover = {
    '1': 'Water',
    '2': 'Trees',
    '5': 'Crops',
    '7': 'Built area',
    '8': 'Bare ground',
    '11': 'Rangeland'
}

# add the legend column
chips_df['Land Cover'] = chips_df['lc'].map(land_cover)

# color dictionary with label
color_dict_label = {
    'Water': '#2c41e6',
    'Trees': '#04541b',
    'Crops': '#99e0ad',
    'Built area': '#797b85',
    'Bare ground': '#a68647',
    'Rangeland': '#f7980a'
}

# create thumbnail
chips_df["thumbs"] = chips_df["urls"].apply(
    lambda urls: [create_thumbnail(p) for p in urls]
)

# create dates Python list
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)

# build a list of points dictionary
points = (
    chips_df
    .rename(columns={
        "cls_dim1": "x",
        "cls_dim2": "y",
        "Land Cover": "category"
    })[["x", "y", "category", "thumbs", "dates_list"]]
    .assign(color=chips_df["Land Cover"].map(color_dict_label))
    .to_dict(orient="records")
)

# convert dictionary to json
points_json = json.dumps(points)

## build up plot and image container html
# NOTE(review): the template body is empty in this copy of the file;
# presumably it interpolated points_json / title_js / xaxis_js / yaxis_js
# -- confirm against the deployed version.
plot_html = f"""
"""

# build up footer html
year = datetime.datetime.now().year
footer_html = f""" """

# embed into Streamlit
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st_html(plot_html, height=500, width=1000, scrolling=True)
    # NOTE(review): original indentation was lost; confirm whether the footer
    # is meant to be rendered inside this column or at page level.
    st.markdown(footer_html, unsafe_allow_html=True)