# -*- coding: utf-8
""" Classes for loading digital elevation models as numeric grids """
import numexpr
import numpy as np
import os
import subprocess
import sys
import matplotlib
import matplotlib.pyplot as plt
from copy import copy
from osgeo import gdal, gdalconst
from rasterio.fill import fillnodata
from scarplet.utils import BoundingBox
sys.setrecursionlimit(10000)
FLOAT32_MIN = np.finfo(np.float32).min
GDAL_DRIVER_NAME = 'GTiff'
[docs]class CalculationMixin(object):
"""Mix-in class for grid calculations"""
def __init__(self):
pass
def _calculate_slope(self):
"""Calculate gradient of grid in x and y directions.
Pads boundary so as to return slope grids of same size as object's
grid data
Returns
-------
slope_x : numpy array
slope in x direction
slope_y : numpy array
slope in y direction
"""
dx = self._georef_info.dx
dy = self._georef_info.dy
PAD_DX = 2
PAD_DY = 2
self._pad_boundary(PAD_DX, PAD_DY)
z_pad = self._griddata
slope_x = (z_pad[1:-1, 2:] - z_pad[1:-1, :-2]) / (2 * dx)
slope_y = (z_pad[2:, 1:-1] - z_pad[:-2, 1:-1]) / (2 * dy)
return slope_x, slope_y
def _calculate_laplacian(self):
"""Calculate curvature of grid in y direction.
"""
return self._calculate_directional_laplacian(0)
def _calculate_directional_laplacian(self, alpha):
"""Calculate curvature of grid in arbitrary direction.
Parameters
----------
alpha : float
direction angle (azimuth) in radians. 0 is north or y-axis.
Returns
-------
del2s : numpy array
grid of curvature values
"""
dx = self._georef_info.dx
dy = self._georef_info.dy
z = self._griddata
nan_idx = np.isnan(z)
z[nan_idx] = 0
dz_dx = np.diff(z, 1, 1)/dx
d2z_dxdy = np.diff(dz_dx, 1, 0)/dx
pad_x = np.zeros((d2z_dxdy.shape[0], 1))
d2z_dxdy = np.hstack([pad_x, d2z_dxdy])
pad_y = np.zeros((1, d2z_dxdy.shape[1]))
d2z_dxdy = np.vstack([pad_y, d2z_dxdy])
d2z_dx2 = np.diff(z, 2, 1)/dx**2
pad_x = np.zeros((d2z_dx2.shape[0], 1))
d2z_dx2 = np.hstack([pad_x, d2z_dx2, pad_x])
d2z_dy2 = np.diff(z, 2, 0)/dy**2
pad_y = np.zeros((1, d2z_dy2.shape[1]))
d2z_dy2 = np.vstack([pad_y, d2z_dy2, pad_y])
del2z = d2z_dx2 * np.cos(alpha) ** 2 - 2 * d2z_dxdy * np.sin(alpha) \
* np.cos(alpha) + d2z_dy2 * np.sin(alpha) ** 2
del2z[nan_idx] = np.nan
return del2z
def _calculate_directional_laplacian_numexpr(self, alpha):
"""Calculate curvature of grid in arbitrary direction.
Optimized with numexpr expressions.
Parameters
----------
alpha : float
direction angle (azimuth) in radians. 0 is north or y-axis.
Returns
-------
del2s : numpyarray
grid of curvature values
"""
dx = self._georef_info.dx
dy = self._georef_info.dy
z = self._griddata
nan_idx = np.isnan(z)
z[nan_idx] = 0
dz_dx = np.diff(z, 1, 1)/dx
d2z_dxdy = np.diff(dz_dx, 1, 0)/dx
pad_x = np.zeros((d2z_dxdy.shape[0], 1))
d2z_dxdy = np.hstack([pad_x, d2z_dxdy])
pad_y = np.zeros((1, d2z_dxdy.shape[1]))
d2z_dxdy = np.vstack([pad_y, d2z_dxdy])
d2z_dx2 = np.diff(z, 2, 1)/dx**2
pad_x = np.zeros((d2z_dx2.shape[0], 1))
d2z_dx2 = np.hstack([pad_x, d2z_dx2, pad_x])
d2z_dy2 = np.diff(z, 2, 0)/dy**2
pad_y = np.zeros((1, d2z_dy2.shape[1]))
d2z_dy2 = np.vstack([pad_y, d2z_dy2, pad_y])
del2z = numexpr.evaluate("d2z_dx2*cos(alpha)**2 - \
2*d2z_dxdy*sin(alpha)*cos(alpha) + d2z_dy2*sin(alpha)**2")
del2z[nan_idx] = np.nan
return del2z
def _estimate_curvature_noiselevel(self):
"""Estimate noise level in curvature of grid as a function of direction.
Returns
-------
angles : numpy array
array of orientations (azimuths) in radians
mean : float
array of mean curvature in correponding direction
sd : float
array of curvature standard deviation in correponding direction
"""
from scipy import ndimage
angles = np.linspace(0, np.pi, num=180)
mean = []
sd = []
for alpha in angles:
del2z = self._calculate_directional_laplacian(alpha)
lowpass = ndimage.gaussian_filter(del2z, 100)
highpass = del2z - lowpass
mean.append(np.nanmean(highpass))
sd.append(np.nanstd(highpass))
return angles, mean, sd
def _pad_boundary(self, dx, dy):
"""Pad grid boundary with reflected boundary conditions.
"""
self._griddata = np.pad(self._griddata, pad_width=(dy, dx),
mode='reflect')
self.padded = True
self.pad_dx = dx
self.pad_dy = dy
ny, nx = self._griddata.shape
self._georef_info.nx = nx
self._georef_info.ny = ny
self._georef_info.xllcenter -= dx
self._georef_info.yllcenter -= dy
[docs]class GDALMixin(object):
pass
[docs]class GeorefInfo(object):
def __init__(self):
self.geo_transform = None
self.projection = None
self.xllcenter = None
self.yllcenter = None
self.dx = None
self.dy = None
self.nx = None
self.ny = None
self.ulx = None
self.uly = None
self.lrx = None
self.lry = None
[docs]class BaseSpatialGrid(GDALMixin):
"""Base class for spatial grid"""
dtype = gdalconst.GDT_Float32
def __init__(self, filename=None):
_georef_info = GeorefInfo()
if filename is not None:
self._georef_info = _georef_info
self.load(filename)
self.filename = filename
else:
self.filename = None
self._georef_info = _georef_info
self._griddata = np.empty((0, 0))
[docs] def is_contiguous(self, grid):
"""Returns true if grids are contiguous or overlap
Parameters
----------
grid : BaseSpatialGrid
"""
return self.bbox.intersects(grid.bbox)
[docs] def merge(self, grid):
"""Merge this grid with another BaseSpatialGrid.
Wrapper argound gdal_merge.py.
Parameters
----------
grid : BaseSpatialGrid
Returns
-------
merged_grid : BaseSpatialGrid
"""
if not self.is_contiguous(grid):
raise ValueError("ValueError: Grids are not contiguous")
# XXX: this is hacky, eventually implement as native GDAL
try:
command = ['gdal_merge.py', self.filename, grid.filename]
subprocess.check_output(command, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
print("Failed to merge grids. gdal_merge may not be installed.")
raise e
merged_grid = BaseSpatialGrid('out.tif')
merged_grid._griddata[merged_grid._griddata == FLOAT32_MIN] = np.nan
os.remove('out.tif')
return merged_grid
[docs] def plot(self, **kwargs):
"""Plot grid data
Keyword args:
Any valid keyword argument for matplotlib.pyplot.imshow
"""
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(self._griddata, **kwargs)
[docs] def save(self, filename):
"""Save grid as georeferenced TIFF
"""
ncols = self._georef_info.nx
nrows = self._georef_info.ny
driver = gdal.GetDriverByName(GDAL_DRIVER_NAME)
out_raster = driver.Create(filename, ncols, nrows, 1, self.dtype)
out_raster.SetGeoTransform(self._georef_info.geo_transform)
out_band = out_raster.GetRasterBand(1)
out_band.WriteArray(self._griddata)
proj = self._georef_info.projection.ExportToWkt()
out_raster.SetProjection(proj)
out_band.FlushCache()
[docs] def load(self, filename):
"""Load grid from file
"""
self.label = filename.split('/')[-1].split('.')[0]
gdal_dataset = gdal.Open(filename)
band = gdal_dataset.GetRasterBand(1)
nodata = band.GetNoDataValue()
self._griddata = band.ReadAsArray().astype(float)
if nodata is not None:
nodata_index = np.where(self._griddata == nodata)
if self.dtype is not np.uint8:
self._griddata[nodata_index] = np.nan
geo_transform = gdal_dataset.GetGeoTransform()
projection = gdal_dataset.GetProjection()
nx = gdal_dataset.RasterXSize
ny = gdal_dataset.RasterYSize
self._georef_info.geo_transform = geo_transform
self._georef_info.projection = projection
self._georef_info.dx = self._georef_info.geo_transform[1]
self._georef_info.dy = self._georef_info.geo_transform[5]
self._georef_info.nx = nx
self._georef_info.ny = ny
self._georef_info.xllcenter = self._georef_info.geo_transform[0] \
+ self._georef_info.dx
self._georef_info.yllcenter = self._georef_info.geo_transform[3] \
- (self._georef_info.ny+1) \
* np.abs(self._georef_info.dy)
self._georef_info.ulx = self._georef_info.geo_transform[0]
self._georef_info.uly = self._georef_info.geo_transform[3]
self._georef_info.lrx = self._georef_info.geo_transform[0] \
+ self._georef_info.dx * self._georef_info.nx
self._georef_info.lry = self._georef_info.geo_transform[3] \
+ self._georef_info.dy * self._georef_info.ny
self.bbox = BoundingBox((self._georef_info.lrx, self._georef_info.lry),
(self._georef_info.ulx, self._georef_info.uly))
[docs]class DEMGrid(CalculationMixin, BaseSpatialGrid):
"""Class representing grid of elevation values"""
def __init__(self, filename=None):
_georef_info = GeorefInfo()
if filename is not None:
self._georef_info = _georef_info
self.load(filename)
self._griddata[self._griddata == FLOAT32_MIN] = np.nan
self.nodata_value = np.nan
self.filename = filename
self.shape = self._griddata.shape
self.is_interpolated = False
else:
self.filename = None
self.label = ''
self.shape = (0, 0)
self._georef_info = _georef_info
self._griddata = np.empty((0, 0))
self.is_interpolated = False
[docs] def plot(self, color=True, **kwargs):
fig, ax = plt.subplots(1, 1, **kwargs)
hs = Hillshade(self)
hs.plot()
if color:
im = ax.imshow(self._griddata, alpha=0.75, cmap='terrain')
plt.colorbar(im, ax=ax, shrink=0.75, label='Elevation')
ax.tick_params(direction='in')
ax.set_xlabel('x')
ax.set_ylabel('y')
def _fill_nodata(self):
"""Fill nodata values in elevation grid by interpolation.
Wrapper around GDAL/rasterio's FillNoData, fillnodata methods
"""
if ~np.isnan(self.nodata_value):
nodata_mask = self._griddata == self.nodata_value
else:
nodata_mask = np.isnan(self._griddata)
self.nodata_mask = nodata_mask
# XXX: GDAL (or rasterio) FillNoData takes mask with 0s at nodata
num_nodata = np.sum(nodata_mask)
prev_nodata = np.nan
while num_nodata > 0 or num_nodata == prev_nodata:
mask = np.isnan(self._griddata)
col_nodata = np.sum(mask, axis=0).max()
row_nodata = np.sum(mask, axis=1).max()
dist = max(row_nodata, col_nodata) / 2
self._griddata = fillnodata(self._griddata,
mask=~mask,
max_search_distance=dist)
prev_nodata = copy(num_nodata)
num_nodata = np.sum(np.isnan(self._griddata))
self.is_interpolated = True
def _fill_nodata_with_edge_values(self):
"""Fill nodata values using swath edge values by row,"""
if ~np.isnan(self.nodata_value):
nodata_mask = self._griddata == self.nodata_value
else:
nodata_mask = np.isnan(self._griddata)
self.nodata_mask = nodata_mask
for row in self._griddata:
idx = np.where(np.isnan(row)).min()
fill_value = row[idx]
row[np.isnan(row)] = fill_value
self.is_interpolated = True
[docs]class Hillshade(BaseSpatialGrid):
"""Class representing hillshade of DEM"""
def __init__(self, dem):
"""Load DEMGrid object as Hillshade
"""
self._georef_info = dem._georef_info
self._griddata = dem._griddata
self._hillshade = None
[docs] def plot(self, az=315, elev=45):
"""Plot hillshade
Paramaters
----------
az : float
azimuth of light source in degrees
elev : float
elevation angle of light source in degrees
"""
ax = plt.gca()
ls = matplotlib.colors.LightSource(azdeg=az, altdeg=elev)
self._hillshade = ls.hillshade(self._griddata, vert_exag=1,
dx=self._georef_info.dx,
dy=self._georef_info.dy)
ax.imshow(self._hillshade, alpha=1, cmap='gray', origin='lower')