Spaces:
Running
Running
import pandas as pd | |
import plotly.express as px | |
import io | |
import numpy as np | |
import rasterio | |
import base64 | |
from PIL import Image | |
import ast | |
from shapely import wkt | |
import streamlit as st | |
import plotly.express as px | |
from streamlit.components.v1 import html as st_html | |
import fsspec | |
import json | |
import s3fs | |
from rasterio.io import MemoryFile | |
import datetime | |
## app contents
# Browser-tab title and full-width layout; Streamlit requires this to be the
# first st.* call of the script.
st.set_page_config(page_title="GFM Explainability Demo", layout="wide")

# Background image: inline the JPEG as a base64 data URI so the CSS rule below
# works without a separately served static file.
with open("data/sx_darkened_fields_v2.jpg", "rb") as image_file:
    encoded_background = base64.b64encode(image_file.read()).decode("utf-8")
bg_url = f"data:image/jpeg;base64,{encoded_background}"

# Stylesheet that pins the image behind the whole app. CSS braces are doubled
# because this is an f-string.
background_css = f"""
<style>
.stApp {{
background-image: url("{bg_url}");
background-attachment: fixed;
background-size: cover;
background-repeat: no-repeat;
}}
</style>
"""
st.markdown(background_css, unsafe_allow_html=True)

# Stylesheet for the gradient-filled text used in the subtitle.
rainbow_css = """
<style>
.rainbow-text {
display: inline-block;
background: linear-gradient(
to right,
red, orange, yellow, green, lightblue, blue, violet, pink
);
background-clip: text;
-webkit-background-clip: text;
color: transparent !important;
}
</style>
"""
st.markdown(rainbow_css, unsafe_allow_html=True)

heading_html = "<h1 style='color:white;'>GFM Explainability Demo π</h1>"
subtitle_html = """
<p style='color:white; font-size:16px;'>
This app extracts t-SNE of Embeddings from image chips using
<span class="rainbow-text"><b>Prithvi-EO-2.0 model</b></span>!
</p>
"""

# Heading and subtitle go in the wide middle column of a 1:5:1 split.
# NOTE(review): indentation was lost in this copy of the file; both markdown
# calls are assumed to sit inside the column context - confirm.
_left_pad, header_col, _right_pad = st.columns([1, 5, 1])
with header_col:
    st.markdown(heading_html, unsafe_allow_html=True)
    st.markdown(subtitle_html, unsafe_allow_html=True)

# Table of chips, read from the local CSV.
chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")

# Anonymous S3 handle: the chips are read from a public bucket, so no
# credentials are supplied.
s3 = s3fs.S3FileSystem(anon=True)
## helper function | |
def gen_chip_urls(row, s3_prefix):
    '''
    Generate the S3 url of every dated chip image for one row
    :param row: dictionary-like (dict or DataFrame row) with chip_id and dates;
        dates is either a list of date strings or the string repr of such a
        list (e.g. "['20200101', '20200202']") as read back from a CSV
    :param s3_prefix: S3 url prefix the file names are appended to
    :return s3_urls: a list of urls, one per date, in the order of dates
    '''
    dates = row["dates"]
    # A CSV round-trip stores the list as text, so parse it back with the safe
    # literal parser. An already-parsed list is used as it is instead of
    # failing inside ast.literal_eval, which only accepts strings.
    if isinstance(dates, str):
        dates = ast.literal_eval(dates)
    chip_id = row["chip_id"]
    # File names follow s2_<chip id zero-padded to 6 digits>_<date>.tif
    return [f"{s3_prefix}/s2_{chip_id:06}_{date}.tif" for date in dates]
def mask_nodata(band, nodata_values=(-999,)):
    '''
    Return a float copy of a band in which every nodata pixel is NaN
    :param band: ndarray of pixel values; the input array is not modified
    :param nodata_values: values to treat as nodata (the chips use -999)
    :return: float ndarray of the same shape with nodata replaced by NaN
    '''
    # astype always copies here, so the caller's array is left untouched
    masked = band.astype(float)
    for sentinel in nodata_values:
        masked = np.where(masked == sentinel, np.nan, masked)
    return masked
def normalize(band):
    '''
    Scale a band to the 0-1 range (float) for display
    The band is divided by a fixed value, 6000 when its NaN-ignoring mean is
    at least 4000 and 4000 otherwise, and the result is clipped to [0, 1].
    NaN pixels stay NaN.
    :param band (ndarray)
    :return normalized band (ndarray)
    '''
    # Bands with a high mean get the larger divisor so they are not washed out
    scale = 6000 if np.nanmean(band) >= 4000 else 4000
    # Clip both ends: negative values would otherwise wrap around when the
    # caller casts the scaled band to uint8.
    return np.clip(band / scale, 0, 1)
def create_thumbnail(url):
    '''
    Read an S3 chip into memory, render bands 3/2/1 as an RGB PNG with rasterio
    and PIL, and return it as a base64 data url
    :param url: chip url
    :return a base64-encoded png data url string, returns an empty string when
        an error occurs (the error is logged as a warning)
    '''
    # Imported here because the file's import block does not provide it.
    import logging

    try:
        # read raw bytes from s3 file
        with s3.open(url, "rb") as f:
            data = f.read()
        # wrap the raw bytes into a memory file and open it with rasterio
        with MemoryFile(data) as memfile, memfile.open() as src:
            # band1->blue, band2->green, band3->red
            # nodata is masked first so it does not skew the mean used by
            # normalize; mask_nodata already converts the band to float
            blue, green, red = (
                normalize(mask_nodata(src.read(band_no))) for band_no in (1, 2, 3)
            )
        # stack in RGB
        rgb = np.dstack((red, green, blue))
        # Masked pixels are NaN and casting NaN to uint8 is undefined, so draw
        # them as black; the clip keeps every value inside 0-1 before scaling.
        rgb = np.clip(np.nan_to_num(rgb, nan=0.0), 0, 1)
        # convert float(0-1) to uint8 (0-255)
        rgb_8bit = (rgb * 255).astype(np.uint8)
        # convert to png in memory
        buf = io.BytesIO()
        Image.fromarray(rgb_8bit).save(buf, format='PNG')
        # encode as a base64 data url usable as an <img> src
        encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f"data:image/png;base64,{encoded}"
    except Exception as exc:
        # Best effort: one unreadable chip must not stop the app, but record
        # why it failed instead of discarding the error.
        logging.getLogger(__name__).warning(
            "Could not create thumbnail for %s: %s", url, exc
        )
        return ""
def get_lat(geometry):
    '''
    Return the y value of the first coordinate of a WKT geometry
    (used as the latitude; assumes x/y are lon/lat - confirm against the data)
    :param geometry: WKT string
    :return: y of the first coordinate
    '''
    _xs, ys = wkt.loads(geometry).coords.xy
    return ys[0]
def get_lon(geometry):
    '''
    Return the x value of the first coordinate of a WKT geometry
    (used as the longitude; assumes x/y are lon/lat - confirm against the data)
    :param geometry: WKT string
    :return: x of the first coordinate
    '''
    xs, _ys = wkt.loads(geometry).coords.xy
    return xs[0]
## plot labels and point data
# Labels shown on the chart:
#   title: plot title
#   xaxis_title: x axis title
#   yaxis_title: y axis title
config = {
    "title": "t-SNE Visualization of EO-FM-Bench Embeddings for Prithvi-EO-2.0",
    "xaxis_title": "t-SNE Dimension 1",
    "yaxis_title": "t-SNE Dimension 2",
}
# JSON-encode each label so it can be dropped into the JavaScript as a quoted
# string literal.
title_js, xaxis_js, yaxis_js = (
    json.dumps(config[key]) for key in ("title", "xaxis_title", "yaxis_title")
)

# Bucket that holds the chip images
s3_prefix = "s3://gfm-bench"

# One list of S3 file URLs per chip (one URL per date)
chips_df["urls"] = chips_df.apply(
    lambda chip_row: gen_chip_urls(chip_row, s3_prefix), axis=1
)

# Land-cover codes are handled as strings so they act as categories and match
# the string keys of the dictionaries below.
chips_df["lc"] = chips_df["lc"].astype(str)

# Latitude and longitude taken from the WKT geometry column
chips_df["latitude"] = chips_df["geometry"].map(get_lat)
chips_df["longitude"] = chips_df["geometry"].map(get_lon)

# Land-cover code -> display name
land_cover = {
    '1': 'Water',
    '2': 'Trees',
    '5': 'Crops',
    '7': 'Built area',
    '8': 'Bare ground',
    '11': 'Rangeland',
}
# Land-cover code -> marker colour
color_dict = {
    '1': '#2c41e6',   # Water
    '2': '#04541b',   # Trees
    '5': '#99e0ad',   # Crops
    '7': '#797b85',   # Built area
    '8': '#a68647',   # Bare ground
    '11': '#f7980a',  # Rangeland
}
# Display name -> marker colour, derived from the two tables above
color_dict_label = {
    land_cover[code]: colour for code, colour in color_dict.items()
}

# Legend column: the display name of each chip's land cover
chips_df['Land Cover'] = chips_df['lc'].map(land_cover)

# One base64 PNG thumbnail per URL ("" where a chip could not be rendered)
chips_df["thumbs"] = chips_df["urls"].apply(
    lambda chip_urls: [create_thumbnail(chip_url) for chip_url in chip_urls]
)

# The dates column holds the text form of a list; parse it into a real list
chips_df["dates_list"] = chips_df["dates"].map(ast.literal_eval)

# One record per chip with exactly the fields the JavaScript reads
points_df = pd.DataFrame({
    "x": chips_df["cls_dim1"],
    "y": chips_df["cls_dim2"],
    "category": chips_df["Land Cover"],
    "thumbs": chips_df["thumbs"],
    "dates_list": chips_df["dates_list"],
    "color": chips_df["Land Cover"].map(color_dict_label),
})
points = points_df.to_dict(orient="records")

# Serialise the records for embedding in the page script
points_json = json.dumps(points)
## build up plot and image container html
# Self-contained page for the Streamlit HTML component: Plotly draws one
# scatter trace per land-cover category on the left, and clicking a point
# fills the panel on the right with that chip's dated thumbnails.
# Braces that belong to CSS/JavaScript are doubled because this is an f-string;
# single braces interpolate points_json, title_js, xaxis_js and yaxis_js.
plot_html = f"""
<script src="https://cdn.plot.ly/plotly-3.0.1.min.js"></script>
<style>
  html, body {{
    margin:0; padding:0; height:100%;
  }}
  #container {{
    display: flex;
    width: 100%;
    height: 100%;
  }}
  #scatter-plot {{
    flex: 1 1 auto;
    min-width: 0;
    height: 100%;
  }}
  #image-container {{
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    flex: 0 0 400px;
    height: 100%;
    box-sizing: border-box;
    padding: 4px;
    grid-auto-rows: auto;
    gap: 4px;
    overflow: hidden;
  }}
  #image-container img {{
    width: 100%;
    height: auto;
    max-height: 200px;
  }}
</style>
<div id="container">
  <div id="scatter-plot"></div>
  <div id="image-container"></div>
</div>
<script>
  const points = {points_json};
  const cats = Array.from(new Set(points.map(p => p.category)));
  // build one trace per category
  const traces = cats.map(cat => {{
    // Positions of this category's records in the full points array.
    // Plotly reports a clicked point by its position inside its own trace,
    // so the global position travels with each point as customdata.
    const idxs = [];
    points.forEach((p, i) => {{
      if (p.category === cat) idxs.push(i);
    }});
    return {{
      x: idxs.map(i => points[i].x),
      y: idxs.map(i => points[i].y),
      customdata: idxs,
      mode: 'markers',
      type: 'scatter',
      name: cat,
      marker: {{ color: idxs.map(i => points[i].color), size: 5 }}
    }};
  }});
  // plotly layout
  // Titles use the object form with a text key: Plotly.js 3 no longer
  // accepts a plain string for title attributes.
  const layout = {{
    paper_bgcolor: "rgb(255,255,255)",
    plot_bgcolor: "rgb(234, 234, 242)",
    title: {{ text: {title_js} }},
    xaxis: {{ title: {{ text: {xaxis_js} }},
      range: [-110, 110],
      gridcolor: "rgb(255,255,255)",
      showgrid: true,
      showline: false,
      showticklabels: true,
      tickcolor: "rgb(127,127,127)",
      ticks: "outside",
      zeroline: false,
      gridwidth: 1 }},
    yaxis: {{ title: {{ text: {yaxis_js} }},
      range: [-110, 110],
      gridcolor: "rgb(255,255,255)",
      showgrid: true,
      showline: false,
      showticklabels: true,
      tickcolor: "rgb(127,127,127)",
      ticks: "outside",
      zeroline: false,
      gridwidth: 1 }},
    autosize: true,
    margin: {{ l:40, r:40, t:40, b:40 }},
    clickmode: 'event+select',
    legend: {{ font: {{ size:12 }}, x:1.01, y:0.5 }}
  }};
  // select the scatter-plot div to render the chart into
  const gd = document.getElementById('scatter-plot');
  // click event
  Plotly.newPlot(gd, traces, layout, {{ responsive: true }}).then(() => {{
    gd.on('plotly_click', evt => {{
      // look the record up through the global index stored in customdata;
      // pointIndex alone is only valid inside the clicked trace
      const record = points[evt.points[0].customdata];
      if (!record) return;
      const thumbs = record.thumbs;
      const dates = record.dates_list;
      // grab image container and clear out old thumbs
      const container = document.getElementById('image-container');
      container.innerHTML = '';
      // append each thumbnail and date; empty urls mark chips that failed
      thumbs.forEach((url, i) => {{
        if (url) {{
          // create card to bundle image and label content
          const card = document.createElement('div');
          card.style.textAlign = 'center';
          card.style.marginBottom = '8px';
          // image
          const img = document.createElement('img');
          img.src = url;
          img.style.width = '100%';
          img.style.maxHeight = '180px';
          // label
          const label = document.createElement('p');
          label.textContent = dates[i];
          label.style.color = 'white';
          label.style.margin = '4px 0 0 0';
          label.style.fontSize = '0.9em';
          // append
          card.appendChild(img);
          card.appendChild(label);
          container.appendChild(card);
        }}
      }});
    }});
  }});
</script>
"""
# Footer: image credit plus a copyright line stamped with the current year
year = datetime.date.today().year
footer_html = f"""
<style>
#footer {{
margin-top: 1rem;
color: rgb(204,156,172);
}}
#footer a {{
color: rgb(204,156,172);
text-decoration: underline;
}}
</style>
<div id="footer">
Background image credit: <b>Sitian Xiong</b>; image source: <a href="https://visibleearth.nasa.gov/images/152732/golden-fields-in-romania/152734l"><b>NASA Earth Observatory</b></a><br>
<b>Copyright &copy {year} - Clark Center for Geospatial Analytics</b>
</div>
"""

# Render the plot component and the footer in the wide middle column of a
# 1:5:1 split.
# NOTE(review): indentation was lost in this copy of the file; the footer is
# assumed to sit inside the column context together with the plot - confirm.
_left_pad, content_col, _right_pad = st.columns([1, 5, 1])
with content_col:
    st_html(plot_html, height=500, width=1000, scrolling=True)
    st.markdown(footer_html, unsafe_allow_html=True)