# eo-fm-bench / src /streamlit_app.py
# Yao-Ting
# Add map-plot into app
# aa38f60
import pandas as pd
import plotly.express as px
import io
import numpy as np
import rasterio
import base64
from PIL import Image
import ast
from shapely import wkt
import streamlit as st
import plotly.express as px
from streamlit.components.v1 import html as st_html
import fsspec
import json
import s3fs
from rasterio.io import MemoryFile
import datetime
## app contents
# Browser-tab title for the demo, combined with Streamlit's wide page layout.
_PAGE_TITLE = "GFM Explainability Demo"
st.set_page_config(page_title=_PAGE_TITLE, layout="wide")
# Background image: inline the JPEG as a base64 data URL so the page CSS can
# reference it directly, then pin it behind the whole app (.stApp).
with open("data/sx_darkened_fields_v2.jpg", "rb") as bg_file:
    bg_bytes = bg_file.read()
b64 = base64.b64encode(bg_bytes).decode("utf-8")
bg_url = f"data:image/jpeg;base64,{b64}"
_bg_css = f"""
<style>
.stApp {{
background-image: url("{bg_url}");
background-attachment: fixed;
background-size: cover;
background-repeat: no-repeat;
}}
</style>
"""
st.markdown(_bg_css, unsafe_allow_html=True)
# `.rainbow-text` paints text with a left-to-right gradient: the gradient is
# the element's background, clipped to the glyphs, with the text itself made
# transparent so the background shows through.
_RAINBOW_TEXT_CSS = """
<style>
.rainbow-text {
display: inline-block;
background: linear-gradient(
to right,
red, orange, yellow, lightblue, violet
);
background-clip: text;
-webkit-background-clip: text;
color: transparent !important;
}
</style>
"""
st.markdown(_RAINBOW_TEXT_CSS, unsafe_allow_html=True)
# Header and introduction, rendered in the middle column of a 1:5:1 layout.
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    # Page title.
    st.markdown(
        "<h1 style='color:white;'>GFM Explainability Demo 🌎</h1>",
        unsafe_allow_html=True,
    )
    # Tagline; the model name uses the `rainbow-text` gradient class.
    st.markdown(
        """
<p style='color:white; font-size:16px;'>
This app extracts t-SNE of Embeddings from image chips using
<span class="rainbow-text"><b>Prithvi-EO-2.0 model</b></span>!
</p>
""",
        unsafe_allow_html=True,
    )
    # Section styling plus the three explanatory sections.
    # Fixes vs. the previous version:
    #  * `text-decoration: non;` was invalid CSS (the declaration was ignored);
    #    the intended value is `none`.
    #  * t-SNE preserves local neighborhood structure, not relative distances,
    #    so the description of the transformation was corrected.
    st.markdown(
        """
<style>
h4 {
color: rgb(204,156,172) !important;
text-decoration: none;
}
a.custom-link {
color: rgb(225,225,225) !important;
text-decoration: underline;
}
p {
color: rgb(225,225,225);
}
code {
color: #00C957 !important;
background-color: rgb(60, 60, 60) !important;
padding: 2px 4px;
border-radius: 3px;
}
</style>
<h4>Prithvi-EO-2.0</h4>
<p>
<code>Prithvi-EO-2.0</code> is the second-generation Geospatial Foundational Model developed by IBM, NASA, and the Jülich Supercomputing Centre. For more details, see the
<a
class="custom-link"
href="https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M"
target="_blank"
>NASA Hugging Face</a>.
</p>
<h4>Chips</h4>
<p>
The input chips for the pretrained Prithvi encoder are <code>224×224</code>, <code>6-band</code>, <code>4 time-step</code> Sentinel-2 imagery. Each chip contains a centered 64×64 patch representing a single pure land cover class, based on the Annual Land Use Land Cover (LULC) dataset. The dataset includes six categories: Water, Trees, Crops, Built Area, Bare Ground, and Rangeland.
Each time series generates an embedding map of shape [785, 1,024], where the first 784 vectors correspond to individual patches, and the 785th vector—the <code>CLS token</code>—represents the entire image. The second dimension (1,024) is the model’s embedding size, selected during pretraining to balance accuracy and computational efficiency.
</p>
<h4>t-SNE Transformation</h4>
<p>
To visualize the high-dimensional embeddings, a <code>t-SNE</code> transformation is applied to the 1,024-dimensional CLS tokens, reducing them to 2D while preserving local neighborhood structure, so that samples with similar embeddings stay close together. Each sample is annotated with its corresponding land cover class.
</p>
""",
        unsafe_allow_html=True,
    )
# st.markdown(
# """
# ### Prithvi-EO-2.0
# `Prithvi-EO-2.0` is the second generation Geospatial Foundational Model developed by IBM, NASA, and Jülich Supercomputing Centre. More details see [NASA Hugging Face]( https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M).
# ### Chips
# The chips utilized as inputs to a pretrained Prithvi encoder is `224x224`, `6-band`, `4 time step` Sentinel 2 imagery with a `64x64` patch of a pure land cover class in the center, determined by the Annual Land Use Land Cover (LULC) dataset with six categories—Water, Trees, Crops, Built area, Bare ground and Rangeland. The results for each time series are embedding maps of size [785, 1,024], in which the first 784 vectors of the first dimension are the embeddings for each patch, and the 785th vector is the embedding for the overall image, or `CLS token`. The second dimension of 1,024 is the model depth, and was set during pretraining to balance the tradeoffs of model complexity.
# ### t-SNE Transformation
# A `t-SNE` transformation is used to transform the 1,024 dimensions of the CLS token into 2 dimensions which preserve as much of the relative distance between samples as possible. The land cover class of each input sample is recorded and paired with the relevant input.
# """
# )
# Per-chip table; later code reads its lc, geometry, dates, chip_id and
# cls_dim1/cls_dim2 columns.
_EMBEDDINGS_CSV = "data/embeddings_df_v0.11_test.csv"
chips_df = pd.read_csv(_EMBEDDINGS_CSV)
# Unsigned (anonymous) S3 access, sufficient for a public bucket.
s3 = s3fs.S3FileSystem(anon=True)
def get_lat(geometry):
    """Return the y coordinate of the first vertex of a WKT geometry string.

    The y value is treated as latitude, which assumes the geometry is stored
    in lon/lat order (x = lon, y = lat).
    """
    _xs, ys = wkt.loads(geometry).coords.xy
    return ys[0]
def get_lon(geometry):
    """Return the x coordinate of the first vertex of a WKT geometry string.

    The x value is treated as longitude, which assumes the geometry is stored
    in lon/lat order (x = lon, y = lat).
    """
    xs, _ys = wkt.loads(geometry).coords.xy
    return xs[0]
## generate json
# Labels for the scatter plot:
#   title       - plot title
#   xaxis_title - x axis title
#   yaxis_title - y axis title
config = {
    "title": "Visualization of EO-FM-Bench Embeddings",
    "xaxis_title": "t-SNE Dimension 1",
    "yaxis_title": "t-SNE Dimension 2",
}
# JSON-encode each label so it can be dropped into the embedded JavaScript
# as a quoted string literal.
title_js, xaxis_js, yaxis_js = (
    json.dumps(config[key]) for key in ("title", "xaxis_title", "yaxis_title")
)
# Treat land-cover codes as strings so they act as categorical labels.
chips_df["lc"] = chips_df["lc"].astype(str)
# Chip location derived from each row's WKT geometry.
chips_df["latitude"] = chips_df["geometry"].apply(get_lat)
chips_df["longitude"] = chips_df["geometry"].apply(get_lon)
# Land-cover classes as (lc code as string, display name, marker color).
_LAND_COVER_CLASSES = (
    ("1", "Water", "#2c41e6"),
    ("2", "Trees", "#04541b"),
    ("5", "Crops", "#99e0ad"),
    ("7", "Built area", "#797b85"),
    ("8", "Bare ground", "#a68647"),
    ("11", "Rangeland", "#f7980a"),
)
# lc code -> color
color_dict = {code: color for code, _name, color in _LAND_COVER_CLASSES}
# lc code -> display name
land_cover = {code: name for code, name, _color in _LAND_COVER_CLASSES}
# Human-readable legend column derived from the lc code.
chips_df['Land Cover'] = chips_df['lc'].map(land_cover)
# Display name -> color; within this file, this is the mapping used for point
# colors.
# NOTE(review): 'Built area' is "#111112" here but '#797b85' in color_dict;
# confirm which color is intended.
color_dict_label = {
    'Water': '#2c41e6',
    'Trees': '#04541b',
    'Crops': '#99e0ad',
    'Built area': "#111112",
    'Bare ground': '#a68647',
    'Rangeland': '#f7980a',
}
# Each row's `dates` value is a Python-literal list stored as text; parse it.
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
# Public bucket holding one PNG thumbnail per (chip, date).
s3_url="https://gfm-bench.s3.amazonaws.com/thumbnails"
# Thumbnail URLs per chip, in the same order as its dates; chip ids are
# zero-padded to six digits in the file names.
chips_df["thumb_urls"] = [
    [f"{s3_url}/s2_{chip_id:06}_{date}.png" for date in dates]
    for chip_id, dates in zip(chips_df["chip_id"], chips_df["dates_list"])
]
# Records consumed by the embedded JavaScript: the t-SNE coordinates become
# x/y, the legend label becomes `category`, and id/lat/lon/color/thumbs are
# added as the short field names the script reads.
_POINT_COLUMNS = ["x", "y", "chip_id", "latitude", "longitude", "category", "dates_list"]
_renamed_df = chips_df.rename(columns={
    "cls_dim1": "x",
    "cls_dim2": "y",
    "Land Cover": "category",
})
points = (
    _renamed_df[_POINT_COLUMNS]
    .assign(
        id=chips_df["chip_id"],
        lat=chips_df["latitude"],
        lon=chips_df["longitude"],
        color=chips_df["Land Cover"].map(color_dict_label),
        thumbs=chips_df["thumb_urls"],
    )
    .to_dict(orient="records")
)
# convert dictionary to json (inlined into the page script below)
points_json = json.dumps(points)
## build up plot and image container html
# Page layout is three panes: a Mapbox map of chip locations, the t-SNE
# scatter plot, and a 2-column grid of the clicked chip's thumbnails.
# Both plots build one trace per category from the same `cats` array, so a
# click's (curveNumber, pointIndex) identifies the same chip in either plot.
# Fixes vs. the previous version:
#  * the click handler now sets `selectedpoints` on EVERY trace, so a highlight
#    left on another category by an earlier click is cleared in both plots;
#  * dropped attributes the trace types do not define (x/y/id on
#    scattermapbox, id/lat/lon on scatter); they were silently ignored.
# NOTE(review): the Mapbox access token is hard-coded into the page source;
# consider supplying it from configuration/secrets instead.
plot_html = f"""
<script src="https://cdn.plot.ly/plotly-2.18.1.min.js"></script>
<style>
html, body {{ margin:0; padding:0; height:100%; }}
#container {{ display:flex; width:100%; height:100%; }}
#map-plot, #scatter-plot {{ flex:1; padding:4px; box-sizing:border-box;margin-right:16px;}}
#scatter-plot {{
margin-right:6px;
}}
#image-container {{
display: grid;
grid-template-columns: repeat(2, 1fr);
flex: 0 0 400px;
height: 100%;
box-sizing: border-box;
padding: 4px;
grid-auto-rows: auto;
gap: 4px;
overflow: hidden;
}}
#image-container img {{ width:100%; height:auto; max-height:200px; }}
</style>
<div id="container">
<div id="map-plot"></div>
<div id="scatter-plot"></div>
<div id="image-container"></div>
</div>
<script>
const points = {points_json};
const cats = Array.from(new Set(points.map(p=>p.category)));
// 1) map traces: one per category. The t-SNE x/y ride along in customdata
// for the hover text (scattermapbox has no x/y attributes).
const mapTraces = cats.map(cat => {{
  const pts = points.filter(p=>p.category===cat);
  return {{
    type: 'scattermapbox',
    mode: 'markers',
    name: cat,
    lat: pts.map(p=>p.lat),
    lon: pts.map(p=>p.lon),
    customdata: pts.map(p=>[
      p.id,
      p.x,
      p.y
    ]),
    marker: {{
      color: pts.map(p=>p.color),
      symbol: 'circle',
      size: 8,
      line: {{ color:'red', width:2 }}
    }},
    hovertemplate:
      `<b>ID:</b> %{{customdata[0]}}<br>` +
      `<b>x:</b> %{{customdata[1]:.4f}}<br>` +
      `<b>y:</b> %{{customdata[2]:.4f}}<br>` +
      `<b>lat:</b> %{{lat:.2f}}<br>` +
      `<b>lon:</b> %{{lon:.2f}}<extra></extra>`,
    selectedpoints: [],
    selected: {{ marker: {{ color:'lightblue', size: 10 }} }},
    unselected: {{ marker: {{ opacity:0.2 }} }}
  }};
}});
const mapLayout = {{
  mapbox: {{
    style: 'mapbox://styles/mapbox/satellite-streets-v11',
    center: {{ lon: 0, lat: 0 }},
    zoom: 0,
    accesstoken: 'pk.eyJ1IjoiY2xhcmtjZ2EteWF5YW8iLCJhIjoiY21jdDl0MDZoMDM3cjJscHBmcWpjbnhkaiJ9.YkEYejNsY5-r3DtESJ46kQ'
  }},
  clickmode: 'event+select',
  margin: {{ l:0,r:0,t:0,b:0 }},
}};
Plotly.newPlot('map-plot', mapTraces, mapLayout, {{responsive:true}});
// 2) scatter traces: one per category, same order as the map traces.
// lat/lon ride along in customdata for the hover text.
const scatterTraces = cats.map(cat => {{
  const pts = points.filter(p=>p.category===cat);
  return {{
    x: pts.map(p=>p.x),
    y: pts.map(p=>p.y),
    customdata: pts.map(p=>[
      p.id,
      p.lat,
      p.lon
    ]),
    mode: 'markers',
    type: 'scatter',
    name: cat,
    marker: {{
      color: pts.map(p=>p.color),
      size: 5,
      line: {{ color:'black', width:1 }}
    }},
    hovertemplate:
      `<b>ID:</b> %{{customdata[0]}}<br>` +
      `<b>x:</b> %{{x:.2f}}<br>` +
      `<b>y:</b> %{{y:.2f}}<br>` +
      `<b>lat:</b> %{{customdata[1]:.4f}}<br>` +
      `<b>lon:</b> %{{customdata[2]:.4f}}<extra></extra>`,
    selectedpoints: [],
    selected: {{ marker: {{ color:'lightblue', size: 10}} }},
    unselected: {{ marker: {{ opacity:0.2 }} }}
  }};
}});
const scatterLayout = {{
  hovermode: 'closest',
  title: {title_js},
  xaxis: {{
    title: {xaxis_js},
    range: [-110, 110],
    showgrid: true,
    gridcolor: 'rgb(255,255,255)',
    gridwidth: 1,
    showline: false,
    zeroline: false,
    showticklabels: true,
    ticks: 'outside',
    tickcolor: 'rgb(127,127,127)'
  }},
  yaxis: {{
    title: {yaxis_js},
    range: [-110, 110],
    showgrid: true,
    gridcolor: 'rgb(255,255,255)',
    gridwidth: 1,
    showline: false,
    zeroline: false,
    showticklabels: true,
    ticks: 'outside',
    tickcolor: 'rgb(127,127,127)'
  }},
  paper_bgcolor: 'rgb(255,255,255)',
  plot_bgcolor: 'rgb(234,234,242)',
  autosize: true,
  margin: {{ l:40, r:40, t:40, b:40 }},
  clickmode:'event+select',
  legend: {{ font:{{ size:12 }}, x:1.01, y:0.5 }}
}};
Plotly.newPlot('scatter-plot', scatterTraces, scatterLayout, {{responsive:true}});
// 3) click handler: sync the highlight across both plots and show the
// clicked chip's thumbnails.
function onPointClick(evt) {{
  const pt = evt.points[0];
  const idx = pt.pointIndex;
  const traceNo = pt.curveNumber;
  // Highlight in both plots. selectedpoints is set on EVERY trace: the clicked
  // point on its own trace and nothing elsewhere, so a highlight left on a
  // different category by an earlier click is cleared too.
  const sel = cats.map((_, i) => (i === traceNo ? [idx] : []));
  Plotly.restyle('map-plot', {{ selectedpoints: sel }});
  Plotly.restyle('scatter-plot', {{ selectedpoints: sel }});
  // the chip id is the first customdata entry in both plots
  const clickedId = Array.isArray(pt.customdata)
    ? pt.customdata[0]
    : pt.customdata;
  // find the record
  const record = points.find(p=>p.id===clickedId);
  if (!record) return;
  // show thumbnails, each labelled with its date
  const thumbs = record.thumbs;
  const dates = record.dates_list;
  const cont = document.getElementById('image-container');
  cont.innerHTML = '';
  thumbs.forEach((url,i) => {{
    if (!url) return;
    const card = document.createElement('div');
    card.style.textAlign = 'center';
    card.style.marginBottom = '8px';
    const img = document.createElement('img');
    img.src = url;
    img.style.width = '100%';
    img.style.maxHeight = '180px';
    const lbl = document.createElement('p');
    lbl.textContent = dates[i];
    lbl.style.color = 'white';
    lbl.style.margin = '4px 0 0';
    lbl.style.fontSize = '0.9em';
    card.appendChild(img);
    card.appendChild(lbl);
    cont.appendChild(card);
  }});
}}
document.getElementById('map-plot').on('plotly_click', onPointClick);
document.getElementById('scatter-plot').on('plotly_click', onPointClick);
</script>
"""
# embed into Streamlit
st_html(plot_html, height=500, width=2000, scrolling=False)
# Footer: background-image credit plus a copyright line whose year is taken
# from the clock each time the script runs.
year = datetime.datetime.now().year
footer_html = f"""
<style>
#footer {{
margin-top: 1rem;
color: rgb(204,156,172);
}}
#footer a {{
color: rgb(204,156,172);
text-decoration: underline;
}}
</style>
<div id="footer">
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; Background image credit: <b>Sitian Xiong</b>; image source: <a href="https://visibleearth.nasa.gov/images/152732/golden-fields-in-romania/152734l"><b>NASA Earth Observatory</b></a><br>
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; <b>Copyright &copy; {year} - Clark Center for Geospatial Analytics</b>
</div>
"""
# Render into the middle column of a 1:5:1 layout, like the page header.
col1, col2, col3 = st.columns([1, 5, 1])
col2.markdown(footer_html, unsafe_allow_html=True)