from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
import uvicorn
from typing import Any, List
from io import BytesIO
import numpy as np
import rasterio
from pydantic import BaseModel
import torch
from huggingface_hub import hf_hub_download
from mmcv import Config
from mmseg.apis import init_segmentor
import gradio as gr
from functools import partial
import tempfile
import time
import os

# Initialize the FastAPI app
app = FastAPI()

# Hugging Face repository that holds both the model config and the checkpoint.
REPO_ID = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"

# Load the model and config. This runs at import time, so importing this module
# downloads both files and builds the segmentor on CPU.
config_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    token=os.environ.get("token"),
)
ckpt = hf_hub_download(
    repo_id=REPO_ID,
    filename="multi_temporal_crop_classification_Prithvi_100M.pth",
    token=os.environ.get("token"),
)
config = Config.fromfile(config_path)
# Clear the backbone's own `pretrained` setting; the full checkpoint `ckpt` is
# handed to init_segmentor below instead.
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device="cpu")

# Use the test pipeline directly
custom_test_pipeline = model.cfg.data.test.pipeline


class PredictionOutput(BaseModel):
    """Response body of ``POST /predict/``.

    Each field holds the ``.tolist()`` conversion of an inference result. For
    multi-dimensional arrays/tensors that is a nested list, so the element
    type is ``Any``. ``List[float]`` rejected any output with more than one
    dimension during response validation.
    """

    t1: List[Any]
    t2: List[Any]
    t3: List[Any]
    prediction: List[Any]


def inference_on_file(file_path, model, custom_test_pipeline):
    """Read a GeoTIFF, run the segmentor on it and post-process the result.

    Args:
        file_path: Path of the raster to read. A non-path object exposing a
            ``name`` attribute (such as the temp-file wrapper some Gradio
            versions pass from ``gr.File``) is resolved to that name.
        model: Segmentor built by ``init_segmentor``.
        custom_test_pipeline: Preprocessing steps forwarded to ``apply_pipeline``.

    Returns:
        ``(rgb1, rgb2, rgb3, output)``: the three post-processed views and the
        raw model output.
    """
    # Only unwrap genuine file objects: pathlib.Path also has `.name`
    # (just the basename), so path-like inputs must be left untouched.
    if not isinstance(file_path, (str, bytes, os.PathLike)):
        file_path = getattr(file_path, "name", file_path)

    with rasterio.open(file_path) as src:
        img = src.read()

    # Apply preprocessing using the custom pipeline
    processed_img = apply_pipeline(custom_test_pipeline, img)

    # NOTE(review): mmseg segmentors' `inference` usually also expects
    # `img_meta` and `rescale` arguments, and `apply_pipeline` is currently a
    # pass-through returning the raw raster array. Confirm this call against
    # the installed mmseg version before relying on it.
    output = model.inference(processed_img)

    # NOTE(review): indexing 0..2 assumes `output` has at least three entries
    # along its first axis. The T1/T2/T3 labels in the UI suggest these were
    # meant to be RGB renderings of the three input time steps — confirm.
    rgb1 = postprocess_output(output[0])
    rgb2 = postprocess_output(output[1])
    rgb3 = postprocess_output(output[2])
    return rgb1, rgb2, rgb3, output


def apply_pipeline(pipeline, img):
    """Preprocess ``img`` for the model (placeholder).

    Currently returns ``img`` unchanged and ignores ``pipeline``. Normalisation,
    resizing, etc. still need to be implemented here.
    """
    return img


def postprocess_output(output):
    """Convert a model output into a displayable form (placeholder).

    Currently returns ``output`` unchanged.
    """
    return output


@app.post("/predict/", response_model=PredictionOutput)
async def predict(file: UploadFile = File(...)):
    """Run inference on an uploaded GeoTIFF and return the results as lists.

    The upload is written to a unique temporary ``.tif`` file, because
    ``rasterio.open`` is given a path. The file is always deleted afterwards.
    """
    contents = await file.read()

    # A per-request temp file: the previous fixed "temp_image.tif" let
    # concurrent requests overwrite each other's input and was never removed.
    fd, tmp_path = tempfile.mkstemp(suffix=".tif")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        # Raster reading and model inference are blocking; run them in a
        # worker thread so the event loop keeps serving other requests.
        rgb1, rgb2, rgb3, output = await run_in_threadpool(
            inference_on_file, tmp_path, model, custom_test_pipeline
        )
    finally:
        os.remove(tmp_path)

    return {
        "t1": rgb1.tolist(),
        "t2": rgb2.tolist(),
        "t3": rgb3.tolist(),
        "prediction": output.tolist(),
    }


def run_gradio_interface(prevent_thread_lock=False):
    """Build and launch the Gradio demo UI.

    Args:
        prevent_thread_lock: Forwarded to ``demo.launch``. The default ``False``
            keeps the blocking behavior. Pass ``True`` to return once the
            server has started, so the caller can go on (e.g. to start uvicorn).
    """
    func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)

    with gr.Blocks() as demo:
        gr.Markdown(value='# Prithvi multi temporal crop classification')
        gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data.
More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n The user needs to provide an HLS geotiff image, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order.''')
        with gr.Row():
            with gr.Column():
                inp = gr.File()
                btn = gr.Button("Submit")
        with gr.Row():
            inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
            inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
            inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
            out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')
        btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
        with gr.Row():
            with gr.Column():
                # `preprocess` is a boolean flag in gr.Examples. The previous
                # value referenced an undefined `preprocess_example` and
                # raised NameError while building the UI.
                gr.Examples(
                    examples=["chip_102_345_merged.tif", "chip_104_104_merged.tif", "chip_109_421_merged.tif"],
                    inputs=inp,
                    outputs=[inp1, inp2, inp3, out],
                    preprocess=True,
                    fn=func,
                    cache_examples=True,
                )
            with gr.Column():
                gr.Markdown(value='### Model prediction legend')
                gr.Image(value='Legend.png', image_mode='RGB', show_label=False)

    demo.launch(prevent_thread_lock=prevent_thread_lock)


if __name__ == "__main__":
    # demo.launch() blocks the main thread by default, so uvicorn.run was never
    # reached while the Gradio server was up. Launch Gradio without locking the
    # thread; uvicorn.run then blocks and keeps both servers alive.
    run_gradio_interface(prevent_thread_lock=True)
    uvicorn.run(app, host="0.0.0.0", port=8000)