Spaces:
Running
Running
import pandas as pd | |
import plotly.express as px | |
import io | |
import numpy as np | |
import rasterio | |
import base64 | |
from PIL import Image | |
import ast | |
from shapely import wkt | |
import streamlit as st | |
import plotly.express as px | |
from streamlit.components.v1 import html as st_html | |
import fsspec | |
import json | |
import s3fs | |
from rasterio.io import MemoryFile | |
import datetime | |
## app contents
# Page title and wide layout; kept as the first Streamlit call in the script.
st.set_page_config(
    page_title="GFM Explainability Demo",
    layout="wide",
)

# Full-page background image: inline the JPEG as a base64 data URI so the
# CSS rule below can reference it directly.
with open("data/sx_darkened_fields_v2.jpg", "rb") as bg_file:
    bg_bytes = bg_file.read()
b64 = base64.b64encode(bg_bytes).decode("utf-8")
bg_url = "data:image/jpeg;base64," + b64

background_css = f"""
<style>
.stApp {{
    background-image: url("{bg_url}");
    background-attachment: fixed;
    background-size: cover;
    background-repeat: no-repeat;
}}
</style>
"""
st.markdown(background_css, unsafe_allow_html=True)
# Global CSS class that paints text with a left-to-right colour gradient
# (the gradient is clipped to the glyphs and the text colour made transparent).
rainbow_css = """
<style>
.rainbow-text {
    display: inline-block;
    background: linear-gradient(
        to right,
        red, orange, yellow, lightblue, violet
    );
    background-clip: text;
    -webkit-background-clip: text;
    color: transparent !important;
}
</style>
"""
st.markdown(rainbow_css, unsafe_allow_html=True)
# Header: title, subtitle and model/data description, centred in the middle
# column of a 1:5:1 layout.
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st.markdown(
        "<h1 style='color:white;'>GFM Explainability Demo 🌎</h1>",
        unsafe_allow_html=True
    )
    # Subtitle; "rainbow-text" is a CSS class injected earlier in the script.
    st.markdown(
        """
        <p style='color:white; font-size:16px;'>
        This app extracts t-SNE of Embeddings from image chips using
        <span class="rainbow-text"><b>Prithvi-EO-2.0 model</b></span>!
        </p>
        """,
        unsafe_allow_html=True
    )
    # Description section. The <style> block styles the headings, links,
    # paragraphs and inline code rendered here (and anywhere else on the page,
    # since the selectors are global).
    st.markdown(
        """
        <style>
        h4 {
            color: rgb(204,156,172) !important;
            text-decoration: none;
        }
        a.custom-link {
            color: rgb(225,225,225) !important;
            text-decoration: underline;
        }
        p {
            color: rgb(225,225,225);
        }
        code {
            color: #00C957 !important;
            background-color: rgb(60, 60, 60) !important;
            padding: 2px 4px;
            border-radius: 3px;
        }
        </style>
        <h4>Prithvi-EO-2.0</h4>
        <p>
        <code>Prithvi-EO-2.0</code> is the second-generation Geospatial Foundation Model developed by IBM, NASA, and the Jülich Supercomputing Centre. For more details, see the
        <a
            class="custom-link"
            href="https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M"
            target="_blank"
            rel="noopener noreferrer"
        >NASA Hugging Face</a>.
        </p>
        <h4>Chips</h4>
        <p>
        The input chips for the pretrained Prithvi encoder are <code>224×224</code>, <code>6-band</code>, <code>4 time-step</code> Sentinel-2 imagery. Each chip contains a centered 64×64 patch representing a single pure land cover class, based on the Annual Land Use Land Cover (LULC) dataset. The dataset includes six categories: Water, Trees, Crops, Built Area, Bare Ground, and Rangeland.
        Each time series generates an embedding map of shape [785, 1024], where the first 784 vectors correspond to individual patches, and the 785th vector—the <code>CLS token</code>—represents the entire image. The second dimension (1,024) is the model’s embedding size, selected during pretraining to balance accuracy and computational efficiency.
        </p>
        <h4>t-SNE Transformation</h4>
        <p>
        To visualize the high-dimensional embeddings, a <code>t-SNE</code> transformation is applied to the 1,024-dimensional CLS tokens, reducing them to 2D while preserving relative distances between samples. Each sample is annotated with its corresponding land cover class.
        </p>
        """,
        unsafe_allow_html=True
    )
# st.markdown( | |
# """ | |
# ### Prithvi-EO-2.0 | |
# `Prithvi-EO-2.0` is the second generation Geospatial Foundational Model developed by IBM, NASA, and Jülich Supercomputing Centre. More details see [NASA Hugging Face]( https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M). | |
# ### Chips | |
# The chips utilized as inputs to a pretrained Prithvi encoder is `224x224`, `6-band`, `4 time step` Sentinel 2 imagery with a `64x64` patch of a pure land cover class in the center, determined by the Annual Land Use Land Cover (LULC) dataset with six categories—Water, Trees, Crops, Built area, Bare ground and Rangeland. The results for each time series are embedding maps of size [785, 1,024], in which the first 784 vectors of the first dimension are the embeddings for each patch, and the 785th vector is the embedding for the overall image, or `CLS token`. The second dimension of 1,024 is the model depth, and was set during pretraining to balance the tradeoffs of model complexity. | |
# ### t-SNE Transformation | |
# A `t-SNE` transformation is used to transform the 1,024 dimensions of the CLS token into 2 dimensions which preserve as much of the relative distance between samples as possible. The land cover class of each input sample is recorded and paired with the relevant input. | |
# """ | |
# ) | |
# read csv: one row per image chip. Columns consumed later in this script:
# cls_dim1/cls_dim2 (plotted on the t-SNE axes), lc (land-cover code),
# geometry (WKT string), chip_id, and dates (a stringified Python list).
CHIPS_CSV_PATH = "data/embeddings_df_v0.11_test.csv"
chips_df = pd.read_csv(CHIPS_CSV_PATH)

# Anonymous (unsigned) S3 filesystem for reading from a public bucket.
# NOTE(review): `s3` is not referenced anywhere else in this file -- the
# thumbnails are loaded by the browser over HTTPS -- so this looks removable.
s3 = s3fs.S3FileSystem(anon=True)
def get_lat(geometry):
    """Return the y ordinate of the first coordinate of a WKT geometry string.

    The value is treated as latitude, i.e. the WKT is assumed to be in
    lon/lat (x/y) order -- presumably POINT geometries; TODO confirm.
    """
    _, ys = wkt.loads(geometry).coords.xy
    return ys[0]
def get_lon(geometry):
    """Return the x ordinate of the first coordinate of a WKT geometry string.

    The value is treated as longitude, i.e. the WKT is assumed to be in
    lon/lat (x/y) order -- presumably POINT geometries; TODO confirm.
    """
    xs, _ = wkt.loads(geometry).coords.xy
    return xs[0]
## generate json
# Labels for the t-SNE scatter plot:
#   title:       plot title
#   xaxis_title: x-axis title
#   yaxis_title: y-axis title
config = dict(
    title="Visualization of EO-FM-Bench Embeddings",
    xaxis_title="t-SNE Dimension 1",
    yaxis_title="t-SNE Dimension 2",
)
# json.dumps turns each label into a quoted, escaped string literal that can
# be interpolated verbatim into the JavaScript of the plot page.
title_js, xaxis_js, yaxis_js = (
    json.dumps(config[key]) for key in ("title", "xaxis_title", "yaxis_title")
)
# Land-cover code as a string so it matches the string keys ('1', '2', ...)
# of the lookup dictionaries defined below.
# NOTE(review): if `lc` were parsed as float (e.g. because of missing values),
# astype(str) would give '1.0' and those lookups would miss -- confirm the
# CSV column is integer-typed.
chips_df["lc"] = chips_df["lc"].astype(str)
# add latitude and longitude, each parsed from the WKT geometry column
geometry_wkt = chips_df["geometry"]
chips_df["latitude"] = geometry_wkt.map(get_lat)
chips_df["longitude"] = geometry_wkt.map(get_lon)
# Land-cover classes: (code as string, display name, hex colour).
_LC_CLASSES = (
    ("1", "Water", "#2c41e6"),
    ("2", "Trees", "#04541b"),
    ("5", "Crops", "#99e0ad"),
    ("7", "Built area", "#797b85"),
    ("8", "Bare ground", "#a68647"),
    ("11", "Rangeland", "#f7980a"),
)
# color dictionary: code -> colour
# NOTE(review): not referenced elsewhere in this file; the plots use
# color_dict_label instead.
color_dict = {code: colour for code, _, colour in _LC_CLASSES}
# land cover dictionary: code -> display name
land_cover = {code: name for code, name, _ in _LC_CLASSES}
# add the legend column: display name per chip from its land-cover code
# (codes missing from land_cover map to NaN).
chips_df = chips_df.assign(**{"Land Cover": chips_df["lc"].map(land_cover)})
# color dictionary with label: display name -> marker colour. This is the
# mapping used to colour the points in the map and scatter plots.
# NOTE(review): 'Built area' is "#111112" here but '#797b85' in color_dict --
# confirm which colour is intended.
color_dict_label = dict([
    ("Water", "#2c41e6"),
    ("Trees", "#04541b"),
    ("Crops", "#99e0ad"),
    ("Built area", "#111112"),
    ("Bare ground", "#a68647"),
    ("Rangeland", "#f7980a"),
])
# create dates Python list: the `dates` column holds the text form of a Python
# list; ast.literal_eval parses it without executing arbitrary code.
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
# set prefix: HTTPS location of the thumbnail PNGs (the browser loads them
# directly via <img src=...>).
s3_url = "https://gfm-bench.s3.amazonaws.com/thumbnails"
# create thumb_urls column: one URL per date, named
# s2_<chip_id zero-padded to 6 digits>_<date>.png.
# Built column-wise with zip instead of DataFrame.apply(axis=1): that avoids
# constructing a Series per row, and it still works when chips_df is empty
# (apply(axis=1) on an empty frame returns a DataFrame, which cannot be
# assigned to a single column).
chips_df["thumb_urls"] = pd.Series(
    [
        [f"{s3_url}/s2_{chip_id:06}_{date}.png" for date in dates]
        for chip_id, dates in zip(chips_df["chip_id"], chips_df["dates_list"])
    ],
    index=chips_df.index,
    dtype=object,
)
# build a list of points dictionary: one JSON-serialisable record per chip,
# embedded into the JavaScript of the plot page.
# NOTE(review): chip_id/latitude/longitude duplicate id/lat/lon; the page's
# JavaScript only reads the short names, so the long ones could be dropped
# to shrink the embedded payload.
renamed_df = chips_df.rename(
    columns={"cls_dim1": "x", "cls_dim2": "y", "Land Cover": "category"}
)
point_columns = [
    "x", "y", "chip_id", "latitude", "longitude", "category", "dates_list",
]
extra_fields = {
    "id": chips_df["chip_id"],
    "lat": chips_df["latitude"],
    "lon": chips_df["longitude"],
    "color": chips_df["Land Cover"].map(color_dict_label),
    "thumbs": chips_df["thumb_urls"],
}
points = renamed_df[point_columns].assign(**extra_fields).to_dict(orient="records")
# convert dictionary to json (interpolated verbatim as a JS array literal)
points_json = json.dumps(points)
## build up plot and image container html
# Self-contained page rendered inside a Streamlit iframe (st_html below):
#   - map-plot:        Mapbox map of chip locations (lat/lon)
#   - scatter-plot:    t-SNE scatter of the chips (x/y)
#   - image-container: thumbnails + dates of the last clicked chip
# Both plots build one trace per category from the same `cats` list with the
# same filter, so (curveNumber, pointIndex) identifies the same chip in either
# plot; the click handler relies on that.
#
# This is an f-string: literal CSS/JS braces are doubled ({{ }}), and single
# braces interpolate points_json, title_js, xaxis_js and yaxis_js.
plot_html = f"""
<script src="https://cdn.plot.ly/plotly-2.18.1.min.js"></script>
<style>
  html, body {{ margin:0; padding:0; height:100%; }}
  #container {{ display:flex; width:100%; height:100%; }}
  #map-plot, #scatter-plot {{ flex:1; padding:4px; box-sizing:border-box;margin-right:16px;}}
  #scatter-plot {{
    margin-right:6px;
  }}
  #image-container {{
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    flex: 0 0 400px;
    height: 100%;
    box-sizing: border-box;
    padding: 4px;
    grid-auto-rows: auto;
    gap: 4px;
    overflow: hidden;
  }}
  #image-container img {{ width:100%; height:auto; max-height:200px; }}
</style>
<div id="container">
  <div id="map-plot"></div>
  <div id="scatter-plot"></div>
  <div id="image-container"></div>
</div>
<script>
  const points = {points_json};
  // Categories in first-seen order; trace i of each plot holds category i.
  const cats = Array.from(new Set(points.map(p => p.category)));
  const pointsByCat = cats.map(cat => points.filter(p => p.category === cat));
  const traceIndices = cats.map((_, i) => i);

  // 1) map traces
  // `selectedpoints` is deliberately left unset: an empty array means an
  // empty selection, which would draw every point with the `unselected`
  // style (opacity 0.2) before anything has been clicked.
  // scattermapbox markers have no `line` (outline) attribute, so none is set.
  const mapTraces = cats.map((cat, i) => {{
    const pts = pointsByCat[i];
    return {{
      type: 'scattermapbox',
      mode: 'markers',
      name: cat,
      lat: pts.map(p => p.lat),
      lon: pts.map(p => p.lon),
      // [chip id, t-SNE x, t-SNE y] for the hover label and the click lookup
      customdata: pts.map(p => [p.id, p.x, p.y]),
      marker: {{
        color: pts.map(p => p.color),
        symbol: 'circle',
        size: 8
      }},
      hovertemplate:
        `<b>ID:</b> %{{customdata[0]}}<br>` +
        `<b>x:</b> %{{customdata[1]:.4f}}<br>` +
        `<b>y:</b> %{{customdata[2]:.4f}}<br>` +
        `<b>lat:</b> %{{lat:.2f}}<br>` +
        `<b>lon:</b> %{{lon:.2f}}<extra></extra>`,
      selected: {{ marker: {{ color:'lightblue', size: 10 }} }},
      unselected: {{ marker: {{ opacity:0.2 }} }}
    }};
  }});
  const mapLayout = {{
    mapbox: {{
      style: 'mapbox://styles/mapbox/satellite-streets-v11',
      center: {{ lon: 0, lat: 0 }},
      zoom: 0,
      // Mapbox "pk." token; it is visible to anyone who views this page.
      accesstoken: 'pk.eyJ1IjoiY2xhcmtjZ2EteWF5YW8iLCJhIjoiY21jdDl0MDZoMDM3cjJscHBmcWpjbnhkaiJ9.YkEYejNsY5-r3DtESJ46kQ'
    }},
    clickmode: 'event+select',
    margin: {{ l:0,r:0,t:0,b:0 }},
  }};
  Plotly.newPlot('map-plot', mapTraces, mapLayout, {{responsive:true}});

  // 2) scatter traces: one per category, in the same order as the map
  const scatterTraces = cats.map((cat, i) => {{
    const pts = pointsByCat[i];
    return {{
      type: 'scatter',
      mode: 'markers',
      name: cat,
      x: pts.map(p => p.x),
      y: pts.map(p => p.y),
      // [chip id, lat, lon] for the hover label and the click lookup
      customdata: pts.map(p => [p.id, p.lat, p.lon]),
      marker: {{
        color: pts.map(p => p.color),
        size: 5,
        line: {{ color:'black', width:1 }}
      }},
      hovertemplate:
        `<b>ID:</b> %{{customdata[0]}}<br>` +
        `<b>x:</b> %{{x:.2f}}<br>` +
        `<b>y:</b> %{{y:.2f}}<br>` +
        `<b>lat:</b> %{{customdata[1]:.4f}}<br>` +
        `<b>lon:</b> %{{customdata[2]:.4f}}<extra></extra>`,
      selected: {{ marker: {{ color:'lightblue', size: 10}} }},
      unselected: {{ marker: {{ opacity:0.2 }} }}
    }};
  }});
  const scatterLayout = {{
    hovermode: 'closest',
    title: {title_js},
    xaxis: {{
      title: {xaxis_js},
      range: [-110, 110],
      showgrid: true,
      gridcolor: 'rgb(255,255,255)',
      gridwidth: 1,
      showline: false,
      zeroline: false,
      showticklabels: true,
      ticks: 'outside',
      tickcolor: 'rgb(127,127,127)'
    }},
    yaxis: {{
      title: {yaxis_js},
      range: [-110, 110],
      showgrid: true,
      gridcolor: 'rgb(255,255,255)',
      gridwidth: 1,
      showline: false,
      zeroline: false,
      showticklabels: true,
      ticks: 'outside',
      tickcolor: 'rgb(127,127,127)'
    }},
    paper_bgcolor: 'rgb(255,255,255)',
    plot_bgcolor: 'rgb(234,234,242)',
    autosize: true,
    margin: {{ l:40, r:40, t:40, b:40 }},
    clickmode:'event+select',
    legend: {{ font:{{ size:12 }}, x:1.01, y:0.5 }}
  }};
  Plotly.newPlot('scatter-plot', scatterTraces, scatterLayout, {{responsive:true}});

  // Restyle payload selecting point `idx` in trace `traceNo` and nothing in
  // every other trace, so a highlight left by an earlier click in another
  // category is cleared. Built fresh per call (one object per plot).
  function selectionFor(traceNo, idx) {{
    return {{ selectedpoints: traceIndices.map(i => (i === traceNo ? [idx] : [])) }};
  }}

  // 3) click handler: highlight the clicked chip in both plots and show its
  // thumbnails with their dates.
  function onPointClick(evt) {{
    if (!evt || !evt.points || evt.points.length === 0) return;
    const pt = evt.points[0];
    const idx = pt.pointIndex;
    const traceNo = pt.curveNumber;
    // highlight in both plots (every trace is restyled, not just the clicked one)
    Plotly.restyle('map-plot', selectionFor(traceNo, idx), traceIndices);
    Plotly.restyle('scatter-plot', selectionFor(traceNo, idx), traceIndices);
    // customdata[0] is the chip id in both plots
    const clickedId = Array.isArray(pt.customdata) ? pt.customdata[0] : pt.customdata;
    const record = points.find(p => p.id === clickedId);
    if (!record) return;
    // show thumbnails
    const dates = record.dates_list;
    const cont = document.getElementById('image-container');
    cont.innerHTML = '';
    record.thumbs.forEach((url, i) => {{
      if (!url) return;
      const card = document.createElement('div');
      card.style.textAlign = 'center';
      card.style.marginBottom = '8px';
      const img = document.createElement('img');
      img.src = url;
      img.alt = String(dates[i]);
      img.style.width = '100%';
      img.style.maxHeight = '180px';
      const lbl = document.createElement('p');
      lbl.textContent = dates[i];
      lbl.style.color = 'white';
      lbl.style.margin = '4px 0 0';
      lbl.style.fontSize = '0.9em';
      card.appendChild(img);
      card.appendChild(lbl);
      cont.appendChild(card);
    }});
  }}
  document.getElementById('map-plot').on('plotly_click', onPointClick);
  document.getElementById('scatter-plot').on('plotly_click', onPointClick);
</script>
"""
# embed into Streamlit
st_html(plot_html, height=500, width=2000, scrolling=False)
# build up footer html: image credit plus a copyright line stamped with the
# current year (server-local clock, evaluated each time the script runs).
year = datetime.datetime.now().year

footer_style = """
<style>
#footer {
    margin-top: 1rem;
    color: rgb(204,156,172);
}
#footer a {
    color: rgb(204,156,172);
    text-decoration: underline;
}
</style>
"""
footer_markup = (
    '<div id="footer">\n'
    "Background image credit: <b>Sitian Xiong</b>; image source: "
    '<a href="https://visibleearth.nasa.gov/images/152732/golden-fields-in-romania/152734l">'
    "<b>NASA Earth Observatory</b></a><br>\n"
    f"<b>Copyright © {year} - Clark Center for Geospatial Analytics</b>\n"
    "</div>\n"
)
footer_html = footer_style + footer_markup

# embed into Streamlit, centred in the middle column of a 1:5:1 layout
col1, col2, col3 = st.columns([1, 5, 1])
with col2:
    st.markdown(footer_html, unsafe_allow_html=True)