I am working with NetCDF data representing water bodies, and I need to display this data on a map using OpenLayers. The land areas in the dataset are represented by NaNs, which I need to preserve to maintain transparency for these areas.
However, I am encountering issues with interpolation where new NaNs are introduced at the borders (which results in tiles that don´t join together). Additionally, I need to ensure that the interpolated data matches the requested tile dimensions (typically 256×256 pixels) for OpenLayers.
Here is the code written in Python:
from fastapi import APIRouter, Query, HTTPException
from fastapi.responses import StreamingResponse
import scipy.interpolate
import xarray as xr
import dask.array as da
import matplotlib.pyplot as plt
from PIL import Image
import io
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from io import BytesIO
from pyproj import Proj, transform
import numpy as np
from scipy.interpolate import RegularGridInterpolator, griddata
from functools import lru_cache
from scipy.ndimage import map_coordinates
router = APIRouter()
# Fixed color scale
COLOR_SCALE_MIN = 0
COLOR_SCALE_MAX = 1
# Cache for interpolated data
INTERPOLATED_DATA_CACHE = {}
def extract_bbox(bbox_str: str) -> tuple:
"""Extract bounding box coordinates from a string"""
return tuple(map(float, bbox_str.split(",")))
def transform_coordinates(proj_in, proj_out, lon_min, lat_min, lon_max, lat_max):
"""Transform bounding box coordinates to match dataset's coordinate system"""
x_min, y_min = transform(proj_in, proj_out, lon_min, lat_min)
x_max, y_max = transform(proj_in, proj_out, lon_max, lat_max)
return x_min, y_min, x_max, y_max
def normalize_data(data: np.ndarray) -> np.ndarray:
"""Normalize data values to fit within a fixed range"""
return (data - COLOR_SCALE_MIN) / (COLOR_SCALE_MAX - COLOR_SCALE_MIN)
@lru_cache(maxsize=128)
def get_interpolated_data(filepath, var_name, bbox, width, height, crs, padding=1):
"""Extract and interpolate data from a NetCDF file"""
lon_min, lat_min, lon_max, lat_max = bbox
# Define projections (example UTM zone 33, adjust as needed)
x_min, y_min, x_max, y_max = transform_coordinates(
crs, "EPSG:32631", *bbox)
ds = xr.open_dataset(filepath)
data = ds[var_name]
# Select the first wavelength layer
subset = data.isel(wl=0).sel(
x=slice(x_min, x_max),
y=slice(y_max, y_min)
)
data_interp = subset_filled.interp(
x=np.linspace(x_min, x_max, width),
y=np.linspace(y_min, y_max, height),
method="nearest",
)
return data_interp
def create_color_mapped_image(data, width, height, cmap_name='plasma'):
"""Create a color-mapped image from normalized data"""
if data.size == 0:
print("Data is empty")
# Return a transparent image of default size
return Image.new('RGBA', (256, 256), (255, 255, 255, 0))
norm = Normalize(vmin=0, vmax=1)
# print(f"Normalized data min: {norm.vmin}, max: {norm.vmax}")
cmap = plt.get_cmap(cmap_name)
mappable = ScalarMappable(norm=norm, cmap=cmap)
mapped_data = (mappable.to_rgba(data, bytes=True)).astype(np.uint8)
mapped_data = np.flipud(mapped_data)
image = Image.fromarray(mapped_data, mode="RGBA")
return image
@router.get("/wms")
async def get_wms(
SERVICE: str = Query(...),
REQUEST: str = Query(...),
LAYERS: str = Query(...),
BBOX: str = Query(...),
WIDTH: int = Query(...),
HEIGHT: int = Query(...),
CRS: str = Query(...),
FORMAT: str = Query(...),
):
if SERVICE.lower() != "wms" or REQUEST.lower() != "getmap":
raise HTTPException(status_code=400, detail="Invalid WMS request")
try:
bbox = extract_bbox(BBOX)
netcdf_file = "./data/data.nc"
var_name = LAYERS.split(",")[0]
try:
interpolated_data = get_interpolated_data(
netcdf_file, var_name, bbox, WIDTH, HEIGHT, CRS)
normalized_data = normalize_data(interpolated_data)
except ValueError as e:
print(f"Error: {e}")
# Return a fully transparent image if no data is available
normalized_data = np.full([HEIGHT, WIDTH], np.nan)
image = create_color_mapped_image(
normalized_data, WIDTH, HEIGHT)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format=FORMAT.split("/")[1].upper())
img_byte_arr.seek(0)
return StreamingResponse(img_byte_arr, media_type=FORMAT)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))