import pandas as pd
import plotly.express as px
import io
import numpy as np
import rasterio
import base64
from PIL import Image
import ast
from shapely import wkt
import streamlit as st
import plotly.express as px
from streamlit.components.v1 import html as st_html
import fsspec
import json
import s3fs
from rasterio.io import MemoryFile
import datetime

## app contents

# Page title and wide layout.
st.set_page_config(
    page_title="GFM Explainability Demo",
    layout="wide",
)

# Background image: read the JPEG from disk and wrap it in a base64 data URL.
with open("data/sx_darkened_fields_v2.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")
bg_url = f"data:image/jpeg;base64,{b64}"

# NOTE(review): both style strings below are whitespace-only as visible here,
# so bg_url is never interpolated - confirm the markup was not lost.
st.markdown(
    f""" """,
    unsafe_allow_html=True,
)

st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# Title, tagline and model description go into the wide middle column.
# NOTE(review): indentation was reconstructed; this assumes all three
# st.markdown calls belong under `with col2:` - confirm.
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st.markdown(
        "\n\nGFM Explainability Demo 🌎\n\n",
        unsafe_allow_html=True
    )

    st.markdown(
        """

This app extracts t-SNE of Embeddings from image chips using Prithvi-EO-2.0 model!

""",
        unsafe_allow_html=True
    )

    st.markdown(
        """

Prithvi-EO-2.0

Prithvi-EO-2.0 is the second-generation Geospatial Foundational Model developed by IBM, NASA, and the Jülich Supercomputing Centre. For more details, see the NASA Hugging Face.

Chips

The input chips for the pretrained Prithvi encoder are 224×224, 6-band, 4 time-step Sentinel-2 imagery. Each chip contains a centered 64×64 patch representing a single pure land cover class, based on the Annual Land Use Land Cover (LULC) dataset. The dataset includes six categories: Water, Trees, Crops, Built Area, Bare Ground, and Rangeland. Each time series generates an embedding map of shape [785, 1,024], where the first 784 vectors correspond to individual patches, and the 785th vector—the CLS token—represents the entire image. The second dimension (1,024) is the model’s embedding size, selected during pretraining to balance accuracy and computational efficiency.

t-SNE Transformation

To visualize the high-dimensional embeddings, a t-SNE transformation is applied to the 1,024-dimensional CLS tokens, reducing them to 2D while preserving relative distances between samples. Each sample is annotated with its corresponding land cover class.

""",
        unsafe_allow_html=True
    )

# Earlier plain-markdown version of the description, kept commented out.
# st.markdown(
# """
# ### Prithvi-EO-2.0
# `Prithvi-EO-2.0` is the second generation Geospatial Foundational Model developed by IBM, NASA, and Jülich Supercomputing Centre. More details see [NASA Hugging Face]( https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M).
# ### Chips
# The chips utilized as inputs to a pretrained Prithvi encoder is `224x224`, `6-band`, `4 time step` Sentinel 2 imagery with a `64x64` patch of a pure land cover class in the center, determined by the Annual Land Use Land Cover (LULC) dataset with six categories—Water, Trees, Crops, Built area, Bare ground and Rangeland. The results for each time series are embedding maps of size [785, 1,024], in which the first 784 vectors of the first dimension are the embeddings for each patch, and the 785th vector is the embedding for the overall image, or `CLS token`. The second dimension of 1,024 is the model depth, and was set during pretraining to balance the tradeoffs of model complexity.
# ### t-SNE Transformation
# A `t-SNE` transformation is used to transform the 1,024 dimensions of the CLS token into 2 dimensions which preserve as much of the relative distance between samples as possible. The land cover class of each input sample is recorded and paired with the relevant input. 
# """
# )

# Load the per-chip table from the CSV.
chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")

# Anonymous S3 filesystem for reading files from the public bucket.
s3 = s3fs.S3FileSystem(anon=True)


def get_lat(geometry):
    """Return the first y coordinate of a WKT geometry string.

    The value is stored as the chip latitude.
    Presumably the geometry is a POINT - verify against the CSV.
    """
    return wkt.loads(geometry).coords.xy[1][0]


def get_lon(geometry):
    """Return the first x coordinate of a WKT geometry string.

    The value is stored as the chip longitude.
    Presumably the geometry is a POINT - verify against the CSV.
    """
    return wkt.loads(geometry).coords.xy[0][0]


## generate json
# title: plot title
# xaxis_title: x axis title
# yaxis_title: y axis title
config = {
    "title": "Visualization of EO-FM-Bench Embeddings",
    "xaxis_title": "t-SNE Dimension 1",
    "yaxis_title": "t-SNE Dimension 2",
}

# JSON-encode each label separately.
title_js = json.dumps(config["title"])
xaxis_js = json.dumps(config["xaxis_title"])
yaxis_js = json.dumps(config["yaxis_title"])

# Land cover codes are cast to str so they are treated as categories.
chips_df["lc"] = chips_df["lc"].astype(str)

# Latitude and longitude are derived from the WKT geometry column.
chips_df["latitude"] = chips_df["geometry"].apply(get_lat)
chips_df["longitude"] = chips_df["geometry"].apply(get_lon)

# Colour per land cover code.
color_dict = {
    "1": "#2c41e6",   # Water
    "2": "#04541b",   # Trees
    "5": "#99e0ad",   # Crops
    "7": "#797b85",   # Built area
    "8": "#a68647",   # Bare ground
    "11": "#f7980a",  # Rangeland
}

# Display name per land cover code.
land_cover = {
    "1": "Water",
    "2": "Trees",
    "5": "Crops",
    "7": "Built area",
    "8": "Bare ground",
    "11": "Rangeland",
}

# Legend column: the display name of each chip's land cover class.
chips_df["Land Cover"] = chips_df["lc"].map(land_cover)

# Colour per display name; this is the mapping used for the plotted points.
# NOTE(review): "Built area" is "#111112" here but "#797b85" in color_dict -
# confirm the difference is intended.
color_dict_label = {
    "Water": "#2c41e6",
    "Trees": "#04541b",
    "Crops": "#99e0ad",
    "Built area": "#111112",
    "Bare ground": "#a68647",
    "Rangeland": "#f7980a",
}

# The "dates" column holds Python list literals as text; parse them into lists.
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)

# Base URL of the thumbnails.
s3_url = "https://gfm-bench.s3.amazonaws.com/thumbnails"


def _thumb_urls(row):
    """Return one thumbnail URL per date in the row's dates_list."""
    return [
        f"{s3_url}/s2_{row.chip_id:06}_{date}.png"
        for date in row.dates_list
    ]


chips_df["thumb_urls"] = chips_df.apply(_thumb_urls, axis=1)

# Build the list of per-point dictionaries.
point_columns = [
    "x", "y", "chip_id", "latitude", "longitude", "category", "dates_list",
]
renamed_df = chips_df.rename(
    columns={
        "cls_dim1": "x",
        "cls_dim2": "y",
        "Land Cover": "category",
    }
)
points_df = renamed_df[point_columns].assign(
    id=chips_df["chip_id"],
    lat=chips_df["latitude"],
    lon=chips_df["longitude"],
    color=chips_df["Land Cover"].map(color_dict_label),
    thumbs=chips_df["thumb_urls"],
)
points = points_df.to_dict(orient="records")

# Serialise the points to JSON.
points_json = json.dumps(points)

## build up plot and image container html
# NOTE(review): the template is empty as visible here, so title_js, xaxis_js,
# yaxis_js and points_json are never interpolated - confirm the markup was
# not lost.
plot_html = f"""
"""

# Embed the plot HTML into the Streamlit page.
st_html(plot_html, height=500, width=2000, scrolling=False)

# Footer HTML.
# NOTE(review): `year` is not referenced in the visible footer string - confirm.
year = datetime.datetime.now().year
footer_html = f""" """

# Render the footer in the wide middle column.
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st.markdown(footer_html, unsafe_allow_html=True)