Source code for torchsample.coord

import torch
import torch.nn.functional as F

from . import default
from ._nobatch import nobatch
from ._sample import sample

def tensor_to_size(tensor):
    """Field size from tensor.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor whose first two dimensions are skipped (e.g. ``(b, c, h, w)``).

    Returns
    -------
    torch.Size
        The remaining dimensions in reversed order, e.g. ``(w, h)`` for a
        ``(b, c, h, w)`` input. Empty when ``tensor`` has two or fewer dims.
    """
    trailing_dims = tensor.shape[2:]
    return trailing_dims[::-1]
def feat_first(tensor):
    """Move the features (last dim) to dim1.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor whose last dimension should become dimension ``1``.

    Returns
    -------
    torch.Tensor
        ``tensor`` with its last dimension moved to position ``1``; the
        relative order of the other dimensions is unchanged.
    """
    return torch.movedim(tensor, -1, 1)
def feat_last(tensor):
    """Move the features (dim1) to last dim.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor whose dimension ``1`` should become the last dimension.

    Returns
    -------
    torch.Tensor
        ``tensor`` with dimension ``1`` moved to the end; the relative order
        of the other dimensions is unchanged.
    """
    return torch.movedim(tensor, 1, -1)
def unnormalize(coord, size, align_corners=default.align_corners, clip=False):
    """Unnormalize normalized coordinates.

    Adapted from the PyTorch C source.

    See Also
    --------
    normalize

    Parameters
    ----------
    coord : torch.Tensor
        Values in range ``[-1, 1]`` to unnormalize.
    size : int or tuple
        Unnormalized side length in pixels.
        If tuple, order is ``(x, y, ...)`` to match ``coord``.
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.
    clip : bool
        If ``True``, output will be clipped to range ``[0, size-1]``

    Returns
    -------
    torch.Tensor
        Unnormalized coordinates. An empty ``coord`` is returned unchanged.
    """
    if coord.numel() == 0:
        return coord

    # A scalar size broadcasts against every dim of ``coord``; a sequence is
    # aligned with the last (per-axis) dim, so it needs one fewer leading axis.
    n_leading = coord.ndim if isinstance(size, int) else coord.ndim - 1
    size = coord.new_tensor(size)[(None,) * n_leading]

    if align_corners:
        # unnormalize coord from [-1, 1] to [0, size - 1]
        unnorm = ((coord + 1) / 2) * (size - 1)
    else:
        # unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
        unnorm = ((coord + 1) * size - 1) / 2

    if clip:
        # ``torch.clip`` documents bounds that are both numbers or both
        # tensors; the upper bound here is a tensor, so clamp the lower
        # bound as a number and apply the upper bound with a broadcasting
        # element-wise minimum.
        upper = (size - 1).to(unnorm.dtype)
        unnorm = torch.minimum(unnorm.clamp(min=0), upper)

    return unnorm
def normalize(
    coord,
    size,
    align_corners=default.align_corners,
):
    """Normalize unnormalized coordinates.

    See Also
    --------
    unnormalize

    Parameters
    ----------
    coord : torch.Tensor
        Values in range ``[0, size-1]`` to normalize.
    size : int or tuple
        Unnormalized side length in pixels.
        If tuple, order is ``(x, y, ...)`` to match ``coord``.
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.

    Returns
    -------
    torch.Tensor
        Normalized coordinates. An empty ``coord`` is returned unchanged.
    """
    if coord.numel() == 0:
        return coord

    # A scalar size gets one singleton axis per dim of ``coord``; a sequence
    # already supplies the last (per-axis) dim, so it needs one fewer.
    n_leading = coord.ndim if isinstance(size, int) else coord.ndim - 1
    extent = coord.new_tensor(size)[(None,) * n_leading]

    if not align_corners:
        # Pixel centres: [-0.5, size - 0.5] maps onto [-1, 1].
        return (coord * 2 + 1) / extent - 1

    # Corner-aligned: [0, size - 1] maps onto [-1, 1].
    return (coord / (extent - 1)) * 2 - 1
def rand(batch, n_samples, dims=2, dtype=None, device=None):
    """Generate random coordinates in range ``[-1, 1]``.

    Parameters
    ----------
    batch : int
        Batch size. If ``0``, don't return a batch dimension.
    n_samples : int
        Number of coordinates to generate per exemplar.
    dims : int
        ``2`` for generating ``(x, y)`` coordinates.
        ``3`` for generating ``(x, y, z)`` coordinates.
        Defaults to ``2``.
    dtype : torch.dtype
        The desired data type of returned tensor.
    device : torch.device
        The desired device of returned tensor

    Returns
    -------
    torch.Tensor
        ``(batch, n_samples, dims)`` random coordinates in range ``[-1, 1]``,
        or ``(n_samples, dims)`` when ``batch`` is ``0``.
    """
    shape = (n_samples, dims) if batch == 0 else (batch, n_samples, dims)
    uniform = torch.rand(*shape, dtype=dtype, device=device)
    # Rescale the unit interval onto [-1, 1].
    return 2 * uniform - 1
def randint(
    batch,
    n_samples,
    size,
    device=None,
    align_corners=default.align_corners,
    replace=False,
):
    """Generate random pixel coordinates in range ``[-1, 1]``.

    Unlike ``rand``, the resulting coordinates will land exactly on pixels.
    Intended for sampling high resolution GT data during data loading.

    Due to numerical precision, it is suggested to use ``mode="nearest"``
    when sampling featuremaps of resolution ``size`` using these coordinates.

    Parameters
    ----------
    batch : int
        Batch size. If ``0``, don't return a batch dimension.
    n_samples : int
        Number of coordinates to generate per exemplar.
    size : tuple
        Size of field to generate pixel coordinates for. i.e. ``(x, y, ...)``.
    device : torch.device
        The desired device of returned tensor
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.
    replace : bool
        Sample with or without replacement. Defaults to ``False``.

    Returns
    -------
    torch.Tensor
        ``(batch, n_samples, dims)`` random coordinates in range ``[-1, 1]``,
        or ``(n_samples, dims)`` when ``batch`` is ``0``.

    Raises
    ------
    ValueError
        If ``replace`` is ``False`` and ``n_samples`` exceeds the number of
        pixels described by ``size``.
    """
    return_nobatch = batch == 0
    if return_nobatch:
        # Work with a batch of one internally; it is stripped before returning.
        batch = 1

    size = tuple(size)

    if replace:
        # Each axis is drawn independently, so repeats are possible.
        per_axis = [
            normalize(
                torch.randint(s, size=(batch, n_samples), device=device),
                s,
                align_corners,
            )
            for s in size
        ]
        coords = torch.stack(per_axis, dim=-1)
    else:
        # Enumerate every pixel exactly once: (n_elements, dims).
        # ``cartesian_prod`` returns a 1-D tensor for a single axis, hence
        # the reshape.
        unnorm = torch.cartesian_prod(*(torch.arange(x) for x in size))
        unnorm = unnorm.reshape(-1, len(size)).to(device)
        n_elements = unnorm.shape[0]
        if n_samples > n_elements:
            raise ValueError(
                f"Cannot draw {n_samples} samples without replacement "
                f"from a field with only {n_elements} pixels."
            )
        candidates = normalize(unnorm, size, align_corners)

        # One independent permutation per exemplar; the leading n_samples
        # entries of each are distinct pixel indices.
        picks = torch.stack(
            [
                torch.randperm(n_elements, device=device)[:n_samples]
                for _ in range(batch)
            ],
            dim=0,
        )
        # Indexing the shared table with (batch, n_samples) indices yields
        # (batch, n_samples, dims) without replicating the table per batch.
        coords = candidates[picks]

    if return_nobatch:
        return coords[0]
    return coords
@nobatch
def randint_like(n_samples, tensor, align_corners=default.align_corners, replace=False):
    """Generate random pixel coordinates in range ``[-1, 1]``.

    See ``randint``. The batch size is read from ``tensor.shape[0]`` and the
    field size from ``tensor_to_size(tensor)``.

    Parameters
    ----------
    n_samples : int
        Number of coordinates to generate per exemplar.
    tensor : torch.Tensor
        Tensor to generate random coordinates for.
        Coordinates will be placed on same device as ``tensor``.
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.
    replace : bool
        Sample with or without replacement. Defaults to ``False``.

    Returns
    -------
    torch.Tensor
        ``(batch, n_samples, dims)`` random coordinates in range ``[-1, 1]``.
    """
    return randint(
        tensor.shape[0],
        n_samples,
        tensor_to_size(tensor),
        device=tensor.device,
        align_corners=align_corners,
        replace=replace,
    )
@nobatch
def full(batch, size, device=None, align_corners=default.align_corners):
    """Generate 2D or 3D coordinates to fully sample an image.

    Parameters
    ----------
    batch : int
        Batch size. If ``0``, don't return a batch dimension.
    size : tuple
        Size of field to generate pixel coordinates for. i.e. ``(x, y, ...)``.
    device : torch.device
        The desired device of returned tensor
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.

    Returns
    -------
    coords : torch.Tensor
        ``(n, h, w, 2)`` or ``(n, d, h, w, 3)`` normalized coordinates for
        sampling functions. The last dim is ordered ``(x, y, ...)``.
    """
    size = tuple(size)

    # "ij" indexing over the reversed axes lays the grids out as (..., y, x).
    # "xy" indexing only swaps the first two axes, which is right for 2D but
    # would produce an (h, w, d) layout for a 3D ``size``.
    grids = torch.meshgrid(
        *[torch.arange(s) for s in reversed(size)], indexing="ij"
    )
    # Put the per-axis grids back into (x, y, ...) order to pair with ``size``.
    axes = grids[::-1]

    normalized = [
        normalize(grid, s, align_corners=align_corners)
        for grid, s in zip(axes, size)
    ]
    normalized = torch.stack(normalized, -1).to(device)

    if batch:
        # Prepend a batch axis and replicate the grid for each exemplar.
        ones = [1] * (len(size) + 1)
        normalized = normalized[None].repeat(batch, *ones)

    return normalized
@nobatch
def full_like(tensor, *args, **kwargs):
    """Generate coordinates that fully sample ``tensor``.

    See ``full``. The batch size is read from ``tensor.shape[0]``, the field
    size from ``tensor_to_size(tensor)``, and the coordinates are placed on
    the same device as ``tensor``.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to generate coordinates for.
    *args, **kwargs
        Forwarded to ``full``.
        NOTE(review): ``device`` is already passed by keyword here, so a
        positional extra (which would land on ``full``'s ``device``
        parameter) or an explicit ``device=`` looks like it would collide;
        pass ``align_corners`` by keyword.

    Returns
    -------
    torch.Tensor
        Normalized coordinates as returned by ``full``.
    """
    return full(
        tensor.shape[0],
        tensor_to_size(tensor),
        *args,
        device=tensor.device,
        **kwargs,
    )
def rand_biased(
    batch,
    n_samples,
    pred,
    k=3.0,
    beta=0.75,
    align_corners=default.align_corners,
):
    """Uncertainty-based point sampling.

    Described in training part of section 3.1 of:

        PointRend: Image Segmentation as Rendering

    Parameters
    ----------
    batch : int
        Batch size. If ``0``, don't return a batch dimension.
    n_samples : int
        Number of coordinates to generate per exemplar.
    pred : torch.Tensor
        ``(b, c, h, w)`` or ``(b, c, d, h, w)`` prediction probabilities.
    k : float
        Multiplier for oversampling for biased coordinates.
    beta : float
        Portion of coordinates to be biased towards uncertain regions.
        Lower values of beta are more similar to a uniform random
        distribution, ignoring ``pred``.
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.
        NOTE(review): this value is currently not forwarded to ``sample``;
        confirm whether it should be.

    Returns
    -------
    torch.Tensor
        ``(b, n_samples, dim)`` normalized coordinates, or
        ``(n_samples, dim)`` when ``batch`` is ``0``.

    Raises
    ------
    ValueError
        If ``pred`` is not 4- or 5-dimensional, ``k < 1``, or ``beta`` is
        outside ``[0, 1]``.
    """
    if pred.ndim == 4:
        dim = 2
    elif pred.ndim == 5:
        dim = 3
    else:
        raise ValueError(f"Cannot handle pred dimensionality {pred.ndim}.")

    return_nobatch = batch == 0
    if return_nobatch:
        # Work with a batch of one internally; it is stripped before returning.
        batch = 1

    if k < 1:
        raise ValueError('"k" must be >=1')

    if not (0 <= beta <= 1):
        raise ValueError('"beta" must be in range [0, 1]')

    # Output buffer shares dtype and device with ``pred``.
    coords = pred.new_empty((batch, n_samples, dim))

    n_biased = int(beta * n_samples)
    n_unbiased = n_samples - n_biased

    if n_biased:
        # Oversample uniformly, then keep the points with the lowest
        # maximum prediction value (the least certain ones).
        oversampled_coords = rand(batch, int(k * n_samples), dim, device=coords.device)
        pred_sampled = sample(oversampled_coords, pred)
        # Reduce over the last dim -- presumably features, giving
        # (b, k*n_samples); verify against ``sample``.
        certainty = torch.max(pred_sampled, dim=-1).values
        indices = torch.sort(certainty, dim=1).indices[:, :n_biased]

        # Build the batch index on the same device as ``indices`` so the
        # advanced indexing below does not mix CPU and device index tensors.
        batch_index = torch.arange(batch, device=indices.device)[:, None]
        batch_index = batch_index.expand(-1, n_biased)
        coords[:, :n_biased] = oversampled_coords[batch_index, indices]

    if n_unbiased:
        # Skipped only when every sample is biased (``n_biased == n_samples``).
        coords[:, n_biased:] = rand(batch, n_unbiased, dim, device=coords.device)

    if return_nobatch:
        return coords[0]
    return coords
def uncertain(
    pred,
    threshold=0.5,
    align_corners=default.align_corners,
):
    """Report all coordinates with uncertain scores.

    Used for inference with models trained with ``rand_biased``.

    .. code-block:: python

        features = feature_extractor(batch["image"])
        ss_pred = decoder(batch["image"])
        index_coords, norm_coords = ts.coord.uncertain(ss_pred)
        features_sampled = ts.sample(norm_coords, features)
        refined = mlp(features_sampled)
        ss_pred_refined = ss_pred.clone()
        ss_pred_refined[index_coords] = ts.feat_first(refined).flatten()

    Parameters
    ----------
    pred : torch.Tensor
        ``(1, c, h, w)`` or ``(1, c, d, h, w)`` prediction probabilities.
        Can NOT handle ``>1`` batch size; would result in ragged tensors.
    threshold : float
        Coords with a max probability lower than this threshold
        will have coordinates generated.
    align_corners : bool
        if ``True``, the corner pixels of the input and output tensors are
        aligned, and thus preserving the values at those pixels.

    Returns
    -------
    index_coords : tuple
        Integer index tensors, one per dimension of ``pred``, that can be
        used to directly index into ``pred``.
    norm_coords : torch.Tensor
        ``(1, n_coords, dims)`` normalized ``(x, y, ...)`` coordinates to be
        used with ``ts.sample``.
    """
    field_size = tensor_to_size(pred)

    # Per-pixel maximum over dim 1, kept as a singleton axis.
    peak = pred.max(1, keepdim=True).values
    is_uncertain = peak < threshold

    # Broadcast over dim 1 so the indices select every feature of each
    # uncertain pixel.
    index_coords = is_uncertain.expand(pred.shape).nonzero(as_tuple=True)

    # Pixel locations only: (batch, ..., y, x) index tensors.
    pixel_index = is_uncertain[:, 0].nonzero(as_tuple=True)

    # Drop the batch index and flip the rest into (x, y, ...) order, then
    # add a leading batch axis.
    xy_first = pixel_index[:0:-1]
    pixel_coords = torch.stack(xy_first, -1)[None]

    norm_coords = normalize(pixel_coords, field_size, align_corners)
    return index_coords, norm_coords