Spaces:
Running
Running
File size: 7,244 Bytes
9665c2c a92fdb9 9665c2c a92fdb9 9665c2c a92fdb9 9665c2c a92fdb9 3352b5d a92fdb9 9665c2c 3352b5d 9665c2c 9aff230 9665c2c 9aff230 9665c2c 9aff230 9665c2c b5a5180 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Optional, Tuple
import torch
import numpy as np
from . import logger
from .evaluation.run import resolve_checkpoint_path, pretrained_models
from .models.orienternet import OrienterNet
from .models.voting import fuse_gps, argmax_xyr
from .data.image import resize_image, pad_image, rectify_image
from .osm.raster import Canvas
from .utils.wrappers import Camera
from .utils.io import read_image
from .utils.geo import BoundaryBox, Projection
from .utils.exif import EXIF
# Optional dependency: geopy supplies the geocoder used to turn a textual
# address into coordinates. Without geopy, `geolocator` is None and callers
# are expected to check for that before geocoding.
try:
    from geopy.geocoders import Nominatim

    geolocator = Nominatim(user_agent="orienternet")
except ImportError:
    geolocator = None

# Optional dependency: gradio_client is used to reach the remote
# PerspectiveFields calibration Space. The client is created at import time;
# if the import fails or the constructor raises ValueError, `calibrator` is
# None and calibration is skipped by callers that check for it.
try:
    from gradio_client import Client

    calibrator = Client("https://jinlinyi-perspectivefields.hf.space/")
except (ImportError, ValueError):
    calibrator = None
def image_calibration(image_path):
    """Query the PerspectiveFields calibrator for roll, pitch and vertical FoV.

    The image is sent to the module-level ``calibrator`` client. The second
    element of its response is read as newline-separated ``"<name> <value>"``
    lines, where the value is whatever follows the last space of the line.

    Args:
        image_path: Path of the image, forwarded as-is to the predictor.

    Returns:
        A pair ``((roll, pitch), vertical_fov)`` of floats parsed from the
        ``"roll"``, ``"pitch"`` and ``"vertical fov"`` entries.
        NOTE(review): units are presumably degrees (the FoV is later fed to
        ``np.deg2rad`` by ``camera_from_exif``) -- confirm against the
        calibrator output.
    """
    logger.info("Calling the PerspectiveFields calibrator, this may take some time.")
    response = calibrator.predict(
        image_path, "NEW:Paramnet-360Cities-edina-centered", api_name="/predict"
    )
    # Split each report line on its last space: left part is the key, right
    # part is the numeric value (keys such as "vertical fov" contain spaces).
    report_lines = response[1].split("\n")
    values = dict(line.rsplit(" ", 1) for line in report_lines)
    roll = float(values["roll"])
    pitch = float(values["pitch"])
    vertical_fov = float(values["vertical fov"])
    return (roll, pitch), vertical_fov
def camera_from_exif(exif: EXIF, fov: Optional[float] = None) -> Optional[Camera]:
    """Build a pinhole camera from EXIF metadata, falling back to an FoV guess.

    Args:
        exif: EXIF reader exposing ``extract_image_size()``, unpacked as
            ``(width, height)``, and ``extract_focal()``, whose second value
            is a focal ratio that equals 0 when no focal length is available.
        fov: Vertical field of view in degrees. Only used when the EXIF
            focal ratio is 0.

    Returns:
        A ``SIMPLE_PINHOLE`` camera, or ``None`` when the EXIF carries no
        focal length and no ``fov`` was given.
    """
    w, h = image_size = exif.extract_image_size()
    _, f_ratio = exif.extract_focal()
    if f_ratio != 0:
        # The EXIF focal ratio is scaled by the longest image side.
        f = f_ratio * max(image_size)
    elif fov is not None:
        # This is the vertical FoV, hence the image height.
        f = h / 2 / np.tan(np.deg2rad(fov) / 2)
    else:
        # Nothing to derive a focal length from; the caller decides what to do.
        return None
    return Camera.from_dict(
        dict(
            model="SIMPLE_PINHOLE",
            width=w,
            height=h,
            params=[f, w / 2 + 0.5, h / 2 + 0.5],
        )
    )
def read_input_image(
    image_path: str,
    prior_latlon: Optional[Tuple[float, float]] = None,
    prior_address: Optional[str] = None,
    fov: Optional[float] = None,
    tile_size_meters: int = 64,
):
    """Load an image and gather everything needed to localize it.

    The location prior is taken from ``prior_latlon``, then from geocoding
    ``prior_address`` (which replaces the former when a location is found),
    and finally from the EXIF GPS tags if neither produced a location.

    Args:
        image_path: Path of the image file; also opened in binary mode to
            read its EXIF metadata.
        prior_latlon: Optional (latitude, longitude) prior.
        prior_address: Optional address passed to the geocoder.
        fov: Optional vertical field of view used when the EXIF has no focal
            length. It is replaced by the calibrator estimate whenever the
            calibrator is available.
        tile_size_meters: Amount added to the degenerate bounding box centered
            on the projected prior location.

    Returns:
        The tuple ``(image, camera, roll_pitch, proj, bbox, latlon)``;
        ``roll_pitch`` is ``None`` when the calibrator is unavailable and
        ``latlon`` is a NumPy array.

    Raises:
        ValueError: If an address is given but geopy is unavailable, if no
            location prior can be determined, or if no camera can be built.
    """
    image = read_image(image_path)
    with open(image_path, "rb") as exif_file:
        # The size loader is only a fallback for the EXIF reader.
        exif = EXIF(exif_file, lambda: image.shape[:2])

    latlon = prior_latlon
    if latlon is not None:
        logger.info("Using prior latlon %s.", prior_latlon)

    # NOTE(review): this runs even when prior_latlon was given, so a
    # successful geocode overrides it -- confirm this precedence is intended.
    if prior_address is not None:
        if geolocator is None:
            raise ValueError("geocoding unavailable, install geopy.")
        location = geolocator.geocode(prior_address)
        if location is not None:
            logger.info("Using prior address '%s'", location.address)
            latlon = (location.latitude, location.longitude)
        else:
            logger.info("Could not find any location for address '%s.'", prior_address)

    if latlon is None:
        geo = exif.extract_geo()
        if not geo:
            logger.info("Could not find any prior location in the image EXIF metadata.")
            raise ValueError(
                "No location prior given or found in the image EXIF metadata: "
                "maybe provide the name of a street, building or neighborhood?"
            )
        altitude = geo.get("altitude", 0)  # read if available
        latlon = (geo["latitude"], geo["longitude"], altitude)
        logger.info("Using prior location from EXIF.")
    latlon = np.array(latlon)

    if calibrator is None:
        roll_pitch = None
        logger.info("Could not call PerspectiveFields, maybe install gradio_client?")
    else:
        roll_pitch, fov = image_calibration(image_path)
    if roll_pitch is not None:
        logger.info("Using (roll, pitch) %s.", roll_pitch)

    camera = camera_from_exif(exif, fov)
    if camera is None:
        raise ValueError(
            "No camera intrinsics found in the EXIF, provide an FoV guess."
        )

    # The projection is anchored at the prior, which then becomes the center
    # of the bounding box that is grown by the tile size.
    proj = Projection(*latlon)
    center = proj.project(latlon)
    bbox = BoundaryBox(center, center) + tile_size_meters
    return image, camera, roll_pitch, proj, bbox, latlon
class Demo:
    """Load an OrienterNet checkpoint and localize single images with it."""

    def __init__(
        self,
        experiment_or_path: Optional[str] = "OrienterNet_MGL",
        device=None,
        **kwargs
    ):
        """Restore the model and its configuration from a checkpoint.

        Args:
            experiment_or_path: Either a key of ``pretrained_models`` or a
                value accepted by ``resolve_checkpoint_path``.
            device: Device to run on; when ``None``, CUDA is used if
                available, otherwise the CPU.
            **kwargs: Overrides merged into the model section of the
                checkpoint configuration.
        """
        if experiment_or_path in pretrained_models:
            experiment_or_path, _ = pretrained_models[experiment_or_path]
        checkpoint_path = resolve_checkpoint_path(experiment_or_path)
        # map_location returns each storage unchanged; the model is moved to
        # the target device explicitly further down.
        checkpoint = torch.load(
            checkpoint_path, map_location=(lambda storage, loc: storage)
        )

        config = checkpoint["hyper_parameters"]
        config.model.update(kwargs)
        # All weights come from the checkpoint loaded strictly below.
        config.model.image_encoder.backbone.pretrained = False

        network = OrienterNet(config.model).eval()
        # Strip the first len("model.") characters of every key (assumes all
        # keys of the checkpoint carry that prefix).
        prefix_length = len("model.")
        weights = {
            name[prefix_length:]: tensor
            for name, tensor in checkpoint["state_dict"].items()
        }
        network.load_state_dict(weights, strict=True)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = network.to(device)
        self.config = config
        self.device = device

    def prepare_data(
        self,
        image: np.ndarray,
        camera: Camera,
        canvas: Canvas,
        roll_pitch: Optional[Tuple[float]] = None,
    ):
        """Turn an image, its camera and a map canvas into model inputs.

        The image is optionally rectified with the negated roll and pitch,
        resized so that the focal length becomes half of
        ``config.data.resize_image``, then padded/cropped to a multiple of
        the largest encoder stride.

        Args:
            image: Image array whose first two dimensions must match
                ``camera.size`` reversed; it is permuted as (2, 0, 1) and
                divided by 255.
            camera: Camera associated with ``image``.
            canvas: Canvas whose ``raster`` becomes the ``map`` input.
            roll_pitch: Optional (roll, pitch) pair used for rectification.

        Returns:
            A dict with the keys ``image``, ``map``, ``camera`` and ``valid``.
        """
        assert image.shape[:2][::-1] == tuple(camera.size.tolist())

        target_focal_length = self.config.data.resize_image / 2
        scale = target_focal_length / camera.f
        size = (camera.size * scale).round().int()

        image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)

        valid = None
        if roll_pitch is not None:
            roll, pitch = roll_pitch
            image, valid = rectify_image(
                image, camera.float(), roll=-roll, pitch=-pitch
            )

        image, _, camera, *maybe_valid = resize_image(
            image, size.tolist(), camera=camera, valid=valid
        )
        # NOTE(review): this value is a list and is overwritten by the
        # pad_image result below without being passed to it -- confirm
        # whether the rectification mask should be forwarded to pad_image.
        if valid is not None:
            valid = maybe_valid

        # Round the size up to the next multiple of the largest stride.
        max_stride = max(self.model.image_encoder.layer_strides)
        size = (torch.ceil(size / max_stride) * max_stride).int()
        image, valid, camera = pad_image(
            image, size.tolist(), camera, crop_and_center=True
        )

        return {
            "image": image,
            "map": torch.from_numpy(canvas.raster).long(),
            "camera": camera.float(),
            "valid": valid,
        }

    def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
        """Run the model on one image and fuse its output with the location prior.

        Args:
            image: Image array, forwarded to ``prepare_data``.
            camera: Camera associated with ``image``.
            canvas: Map canvas; the center of its bounding box is used as the
                prior position.
            **kwargs: Extra arguments forwarded to ``prepare_data``.

        Returns:
            The tuple ``(xyr[:2], xyr[2], prob, neural_map, image)`` where
            ``xyr`` is the argmax of the fused log-probabilities, ``prob`` is
            their exponential, ``neural_map`` is the first map feature of the
            prediction and ``image`` is the prepared input image.
        """
        data = self.prepare_data(image, camera, canvas, **kwargs)
        # Move every entry to the device; indexing with None prepends a
        # dimension.
        batch = {
            key: value.to(self.device)[None] for key, value in data.items()
        }
        with torch.no_grad():
            pred = self.model(batch)

        xy_gps = canvas.bbox.center
        uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))

        log_probs = pred["log_probs"].squeeze(0)
        half_extent = canvas.bbox.size.min() / 2
        sigma = half_extent - 20  # 20 meters margin
        log_probs = fuse_gps(
            log_probs,
            uv_gps.to(log_probs),
            self.config.model.pixel_per_meter,
            sigma=sigma,
        )

        xyr = argmax_xyr(log_probs).cpu()
        prob = log_probs.exp().cpu()
        neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
        return xyr[:2], xyr[2], prob, neural_map, data["image"]
|