|
import os
import tempfile
import threading
import time
from functools import partial
from io import BytesIO
from typing import Any, List

import gradio as gr
import numpy as np
import rasterio
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from huggingface_hub import hf_hub_download
from mmcv import Config
from mmseg.apis import init_segmentor
from pydantic import BaseModel
|
|
|
|
|
app = FastAPI()

# Both the model config and the fine-tuned checkpoint come from the same
# Hugging Face repo; the optional access token is read from the "token"
# environment variable.
_REPO_ID = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"
_HF_TOKEN = os.environ.get("token")

config_path = hf_hub_download(
    repo_id=_REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    token=_HF_TOKEN,
)
ckpt = hf_hub_download(
    repo_id=_REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.pth",
    token=_HF_TOKEN,
)

config = Config.fromfile(config_path)
# Clear the backbone's `pretrained` entry; presumably so that only the
# checkpoint passed to init_segmentor below supplies the weights — confirm.
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device="cpu")

# NOTE(review): `process_test_pipeline` is neither defined nor imported in
# this file, so this line raises NameError as written — confirm where the
# helper is supposed to come from.
custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, None)
|
|
|
|
|
class PredictionOutput(BaseModel):
    """Response schema for the ``/predict/`` endpoint.

    Every field holds the ``.tolist()`` form of an array returned by
    ``inference_on_file``. The same arrays are rendered by ``gr.Image``
    components in the Gradio UI, so they are at least 2-D and ``.tolist()``
    produces nested lists. A flat ``List[float]`` would reject those during
    FastAPI's response validation; ``List[Any]`` accepts nesting of any depth.
    """

    # Image for time step 1, as nested lists.
    t1: List[Any]
    # Image for time step 2, as nested lists.
    t2: List[Any]
    # Image for time step 3, as nested lists.
    t3: List[Any]
    # Model prediction image, as nested lists.
    prediction: List[Any]
|
|
|
@app.post("/predict/", response_model=PredictionOutput)
async def predict(file: UploadFile = File(...)):
    """Run crop-classification inference on an uploaded GeoTIFF.

    ``inference_on_file`` takes a filesystem path, so the upload is written to
    a unique temporary ``.tif`` file. That file is removed once inference
    finishes, whether or not it succeeds.

    Args:
        file: The uploaded GeoTIFF image.

    Returns:
        A dict with keys ``t1``, ``t2``, ``t3`` and ``prediction``. Each value
        is the ``.tolist()`` of the corresponding value returned by
        ``inference_on_file``.
    """
    contents = await file.read()

    # Use a unique temp file rather than a fixed "temp_image.tif" in the CWD.
    # The fixed name is shared by worker processes running from the same
    # directory, is never cleaned up, and requires a writable CWD.
    fd, tmp_path = tempfile.mkstemp(suffix=".tif")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(contents)
        # NOTE(review): inference runs synchronously inside this async
        # handler, so it blocks the event loop while it runs.
        rgb1, rgb2, rgb3, output = inference_on_file(tmp_path, model, custom_test_pipeline)
    finally:
        os.remove(tmp_path)

    return {
        "t1": rgb1.tolist(),
        "t2": rgb2.tolist(),
        "t3": rgb3.tolist(),
        "prediction": output.tolist(),
    }
|
|
|
|
|
def run_gradio_interface():
    """Build and launch the Gradio demo UI for the crop-classification model.

    The UI has these parts:

    * a file input with a Submit button;
    * three RGB image outputs (T1, T2, T3);
    * a model-prediction image;
    * cached example chips;
    * a static legend image (``Legend.png``).

    Both the button and the examples call ``inference_on_file``, with the
    module-level ``model`` and ``custom_test_pipeline`` bound in.

    ``demo.launch()`` is called at the end. Per Gradio's documentation, it
    blocks the calling thread while the server runs, unless
    ``prevent_thread_lock`` is set (it is not set here).
    """
    func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)

    with gr.Blocks() as demo:
        gr.Markdown(value='# Prithvi multi temporal crop classification')
        gr.Markdown(value=(
            'Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the '
            'IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. '
            'This demo showcases how the model was finetuned to classify crop and other '
            'land use categories using multi temporal data. More details can be found '
            '[here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n'
            '\n'
            'The user needs to provide an HLS geotiff image, including 18 bands for 3 '
            'time-step, and each time-step includes the channels described above (Blue, '
            'Green, Red, Narrow NIR, SWIR, SWIR 2) in order.'
        ))

        with gr.Row():
            with gr.Column():
                inp = gr.File()
                btn = gr.Button("Submit")

        with gr.Row():
            inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
            inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
            inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
            out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')

        btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])

        with gr.Row():
            with gr.Column():
                gr.Examples(examples=["chip_102_345_merged.tif",
                                      "chip_104_104_merged.tif",
                                      "chip_109_421_merged.tif"],
                            inputs=inp,
                            outputs=[inp1, inp2, inp3, out],
                            # Gradio's `preprocess` is a bool flag. The original
                            # passed `preprocess_example`, a name defined nowhere
                            # in this file (NameError); True is the value that
                            # any truthy object would have meant.
                            preprocess=True,
                            fn=func,
                            cache_examples=True)
            with gr.Column():
                gr.Markdown(value='### Model prediction legend')
                gr.Image(value='Legend.png', image_mode='RGB', show_label=False)

    demo.launch()
|
|
|
if __name__ == "__main__":
    # run_gradio_interface() ends with demo.launch(), which blocks the main
    # thread when run as a script. A uvicorn.run() placed after it would
    # therefore never start, and the REST API would be unreachable.
    # Serve the API from a daemon thread first, then give the main thread to
    # Gradio. The API thread ends when the process exits.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()
    run_gradio_interface()
|
|