mml.core.visualization.predictions

Functions for plotting model predictions.

render_classification_predictions(raw_images: Tensor, logits: Tensor, targets: Tensor, classes: List[str]) → Figure

Renders model predictions for classification tasks as a matplotlib figure.
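Classification rendering has to turn raw logits into a readable caption per image. A minimal sketch of that step, assuming logits of shape (batch, num_classes) and integer target indices into `classes`; the helper `classification_captions` and its caption format are hypothetical and not part of `mml`:

```python
from typing import List

import numpy as np


def classification_captions(logits: np.ndarray, targets: np.ndarray, classes: List[str]) -> List[str]:
    """Build one caption per sample from logits of shape (batch, num_classes)."""
    # Numerically stable softmax over the class dimension.
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp / exp.sum(axis=1, keepdims=True)
    preds = probs.argmax(axis=1)
    captions = []
    for pred, target, p in zip(preds, targets, probs):
        # Predicted class with its softmax confidence, followed by the ground-truth class.
        captions.append(f"pred: {classes[pred]} ({p[pred]:.2f}) / true: {classes[target]}")
    return captions


captions = classification_captions(np.array([[2.0, 0.0], [0.0, 3.0]]), np.array([0, 0]), ["cat", "dog"])
```

With torch inputs, the tensors would first be converted with `.detach().cpu().numpy()`.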

render_labeled_grid(img_grid: Sequence[Sequence[Tensor | ndarray | str]], col_labels: List[str] | None) → Figure

Takes a grid of images and strings and returns a matplotlib figure.

Parameters:
  • img_grid (Sequence[Sequence[Union[torch.Tensor, np.ndarray, str]]]) – list of lists containing images (NumPy arrays or torch tensors) and, optionally, strings

  • col_labels (Optional[List[str]]) – labels for the grid columns, or None to omit column titles

Returns:

a matplotlib figure with the given column titles and the rendered images and text

Return type:

plt.Figure
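A grid like this maps naturally onto `plt.subplots`. The sketch below shows one way to build it, assuming images are already HWC NumPy arrays (a CHW torch tensor would first need something like `t.permute(1, 2, 0).cpu().numpy()`). The function name `labeled_grid_sketch`, the figure size, and the handling of ragged rows are illustrative assumptions, not `mml`'s implementation:

```python
from typing import List, Optional, Sequence, Union

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def labeled_grid_sketch(img_grid: Sequence[Sequence[Union[np.ndarray, str]]],
                        col_labels: Optional[List[str]] = None) -> plt.Figure:
    n_rows = len(img_grid)
    n_cols = max(len(row) for row in img_grid)
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for r in range(n_rows):
        for c in range(n_cols):
            ax = axes[r][c]
            ax.axis("off")
            if c >= len(img_grid[r]):
                continue  # ragged row: leave this cell empty
            cell = img_grid[r][c]
            if isinstance(cell, str):
                # Text cells are centered using axes coordinates.
                ax.text(0.5, 0.5, cell, ha="center", va="center", transform=ax.transAxes)
            else:
                ax.imshow(cell)
    if col_labels is not None:
        # Column titles go on the top row only.
        for ax, label in zip(axes[0], col_labels):
            ax.set_title(label)
    return fig


fig = labeled_grid_sketch([[np.zeros((4, 4)), "label"], [np.ones((4, 4, 3)), "x"]], ["image", "text"])
```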

render_predictions(raw_images: Tensor, logits: Tensor, targets: Tensor, classes: List[str], task_type: TaskType) → Figure

Wrapper that dispatches to the prediction renderer matching task_type.

Parameters:
  • raw_images (torch.Tensor) – non-normalized but potentially augmented (e.g. rotated) images

  • logits (torch.Tensor) – prediction logits

  • targets (torch.Tensor) – ground-truth targets

  • classes (List[str]) – class strings (order must match target indices)

  • task_type (TaskType) – the task type of the corresponding task

Returns:

a matplotlib figure that shows some model predictions
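One common way to structure such a wrapper is a lookup table from task type to renderer, raising an error for unsupported types. The sketch below assumes that design; `TaskTypeSketch`, its members, and the stub renderers (which return strings instead of figures to stay lightweight) are hypothetical stand-ins, not `mml`'s actual `TaskType` or implementation:

```python
from enum import Enum
from typing import Callable, Dict


class TaskTypeSketch(Enum):
    # Hypothetical stand-in for mml's TaskType; real member names may differ.
    CLASSIFICATION = "classification"
    SEGMENTATION = "segmentation"
    REGRESSION = "regression"  # deliberately has no renderer in this sketch


def render_classification_stub(*args) -> str:
    return "classification figure"


def render_segmentation_stub(*args) -> str:
    return "segmentation figure"


# Lookup table from task type to the renderer responsible for it.
RENDERERS: Dict[TaskTypeSketch, Callable[..., str]] = {
    TaskTypeSketch.CLASSIFICATION: render_classification_stub,
    TaskTypeSketch.SEGMENTATION: render_segmentation_stub,
}


def render_predictions_sketch(raw_images, logits, targets, classes, task_type: TaskTypeSketch):
    try:
        renderer = RENDERERS[task_type]
    except KeyError:
        raise NotImplementedError(f"No renderer for task type {task_type}") from None
    # All renderers share one signature, so the wrapper just forwards its arguments.
    return renderer(raw_images, logits, targets, classes)
```

A table keeps the wrapper unchanged when a new task type gains a renderer; only one entry is added.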

render_segmentation_predictions(raw_images: Tensor, logits: Tensor, targets: Tensor, classes: List[str]) → Figure

Renders model predictions for segmentation tasks as a matplotlib figure.
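Segmentation predictions are usually displayed as per-pixel class maps. A minimal NumPy sketch of that step, assuming logits of shape (batch, num_classes, H, W) and treating class 0 as background; both helper names are hypothetical, and the seeded random palette is an illustrative choice rather than necessarily what `mml` uses:

```python
import numpy as np


def segmentation_masks(logits: np.ndarray) -> np.ndarray:
    """Collapse logits of shape (batch, num_classes, H, W) to class-index masks of shape (batch, H, W)."""
    return logits.argmax(axis=1)


def colorize_mask(mask: np.ndarray, num_classes: int, seed: int = 0) -> np.ndarray:
    """Map an (H, W) class-index mask to an (H, W, 3) uint8 RGB image with one fixed color per class."""
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(num_classes, 3)).astype(np.uint8)
    palette[0] = 0  # assumption: class 0 is background and is drawn black
    # Fancy indexing looks up each pixel's class color.
    return palette[mask]


logits = np.zeros((1, 3, 2, 2))
logits[0, 2, 0, 0] = 5.0
logits[0, 1, 1, 1] = 5.0
masks = segmentation_masks(logits)
rgb = colorize_mask(masks[0], num_classes=3)
```

The resulting RGB image can be passed to `imshow` next to the raw image, or blended with it for an overlay.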