mml.core.visualization.predictions
Functions for plotting model predictions.
- render_classification_predictions(raw_images: Tensor, logits: Tensor, targets: Tensor, classes: List[str]) → Figure
Renders model predictions for classification tasks.
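A central step in rendering classification predictions is turning raw logits into readable per-sample captions. The sketch below shows one way to do this with NumPy. It does not reproduce mml's implementation. The helper name `classification_captions` and the caption format are hypothetical, and torch tensors are assumed to have been converted to NumPy arrays beforehand.

```python
from typing import List

import numpy as np


def classification_captions(logits: np.ndarray, targets: np.ndarray, classes: List[str]) -> List[str]:
    """Hypothetical helper: one caption per sample with prediction, confidence and ground truth."""
    # Numerically stable softmax over the class dimension (axis 1).
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    preds = probs.argmax(axis=1)
    captions = []
    for i, (pred, target) in enumerate(zip(preds, targets)):
        # Class names are looked up by index, so `classes` must follow the target index order.
        captions.append(f"pred: {classes[pred]} ({probs[i, pred]:.2f}) | true: {classes[target]}")
    return captions
```

For example, `classification_captions(np.array([[2.0, 0.0], [0.0, 0.0]]), np.array([0, 1]), ["cat", "dog"])` yields one correct and one incorrect caption. Such captions could then populate the text cells of a labeled grid.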
- render_labeled_grid(img_grid: Sequence[Sequence[Tensor | ndarray | str]], col_labels: List[str] | None) → Figure
Takes a grid of images and strings and returns a matplotlib figure.
- Parameters:
img_grid (Sequence[Sequence[Union[torch.Tensor, np.ndarray, str]]]) – list of lists containing images (as NumPy arrays or torch tensors) and, optionally, strings
col_labels (Optional[List[str]]) – (optional) labels for the columns of the grid
- Returns:
a matplotlib figure with given column titles and rendered images / text
- Return type:
plt.Figure
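The grid rendering described above can be sketched with matplotlib subplots: one axis per cell, with images drawn via `imshow` and strings drawn as centered text. This is an illustrative sketch, not mml's code. The name `labeled_grid_sketch` is hypothetical. It assumes images are HxW or HxWxC NumPy arrays with values in [0, 1], so torch tensors in CxHxW layout would first need converting, for example with `tensor.permute(1, 2, 0).cpu().numpy()`.

```python
from typing import List, Optional, Sequence, Union

import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def labeled_grid_sketch(
    img_grid: Sequence[Sequence[Union[np.ndarray, str]]],
    col_labels: Optional[List[str]] = None,
) -> plt.Figure:
    """Hypothetical re-implementation: render images and strings in a grid with column titles."""
    n_rows = len(img_grid)
    n_cols = max(len(row) for row in img_grid)
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for r, row in enumerate(img_grid):
        for c in range(n_cols):
            ax = axes[r][c]
            ax.axis("off")  # hide ticks and spines; titles remain visible
            if c >= len(row):
                continue  # ragged rows leave trailing cells empty
            cell = row[c]
            if isinstance(cell, str):
                # Strings are drawn as centered text in axes coordinates.
                ax.text(0.5, 0.5, cell, ha="center", va="center", wrap=True, transform=ax.transAxes)
            else:
                # Assumed layout: HxW (grayscale) or HxWxC with values in [0, 1].
                ax.imshow(cell, cmap="gray" if cell.ndim == 2 else None)
    if col_labels is not None:
        # Column labels become titles of the first row's axes.
        for c, label in enumerate(col_labels[:n_cols]):
            axes[0][c].set_title(label)
    fig.tight_layout()
    return fig
```

Using `squeeze=False` avoids special-casing grids with a single row or column, which is a common source of indexing bugs with `plt.subplots`.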
- render_predictions(raw_images: Tensor, logits: Tensor, targets: Tensor, classes: List[str], task_type: TaskType) → Figure
Wrapper that gives access to the task-specific prediction renderers, selected by task_type.
- Parameters:
raw_images (torch.Tensor) – non-normalized but potentially augmented (e.g. rotated) images
logits (torch.Tensor) – prediction logits
targets (torch.Tensor) – underlying targets
classes (List[str]) – class strings (order must match target indices)
task_type (TaskType) – the task type of the corresponding task
- Returns:
a matplotlib figure that shows some model predictions
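One common way to build such a wrapper is a registry that maps each task type to its renderer. The sketch below is hypothetical. `TaskTypeStub` is a local stand-in for mml's `TaskType`, and its member names are illustrative only. The stub renderer returns a string instead of a figure so the dispatch logic can be shown in isolation. Raising `NotImplementedError` for unsupported task types is a choice made for this sketch, not documented behavior of `render_predictions`.

```python
from enum import Enum
from typing import Callable, Dict, List


class TaskTypeStub(Enum):
    # Hypothetical stand-in for mml's TaskType; member names are illustrative only.
    CLASSIFICATION = "classification"
    SEGMENTATION = "segmentation"


def render_classification_stub(raw_images, logits, targets, classes: List[str]) -> str:
    # Placeholder renderer; a real one would build and return a matplotlib figure.
    return f"classification figure for {len(classes)} classes"


# Registry mapping each supported task type to its renderer.
RENDERERS: Dict[TaskTypeStub, Callable] = {
    TaskTypeStub.CLASSIFICATION: render_classification_stub,
}


def render_predictions_sketch(raw_images, logits, targets, classes: List[str], task_type: TaskTypeStub):
    """Hypothetical dispatcher: look up the renderer for `task_type` and forward all arguments."""
    try:
        renderer = RENDERERS[task_type]
    except KeyError:
        raise NotImplementedError(f"No prediction renderer for task type {task_type}") from None
    return renderer(raw_images, logits, targets, classes)
```

With a registry, supporting a new task type only requires writing its renderer and adding one dictionary entry, without touching the wrapper itself.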