I’m working on a deep learning problem where the goal is to establish an injective and surjective mapping between elements of an input sequence and elements of an output sequence.
The input sequence consists of feature representations of jigsaw puzzle pieces (extracted by EfficientNetV2-S), and the output elements are tuples of row, column, and rotation indices in their respective coordinate systems:

$$(i, j, k), \quad i \in \{0, \dots, N_{\max}-1\}, \quad j \in \{0, \dots, M_{\max}-1\}, \quad k \in \{0, 1, 2, 3\},$$

$N_{\max}$ and $M_{\max}$ being the number of rows and columns of the jigsaw puzzle, and the last label $k$ describing the rotation (4 rotation classes).
- Injective: Each element in the input sequence (jigsaw puzzle piece) must map to a unique (row, col) index in the output sequence, ensuring that no two inputs map to the same (row, col) position.
- Surjective: Every possible (row, col) index in the output sequence must be assigned to an input element. This ensures that the mapping covers the entire range of possible (row, col) positions, meaning all positions in the puzzle are occupied by a piece.
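Fortunately, injectivity and surjectivity go hand in hand for this problem: there are exactly as many pieces as (row, col) positions, and a map between two finite sets of the same size is injective if and only if it is surjective. This means we can ensure both properties by enforcing either one of them!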
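Because the number of pieces equals the number of positions, checking that the predicted (row, col) pairs are distinct and inside the grid is enough to certify a valid assembly. A minimal sketch in PyTorch; the helper `is_valid_assignment` is hypothetical and only illustrates the equivalence:

```python
import torch


def is_valid_assignment(pred_pos: torch.Tensor, n_rows: int, n_cols: int) -> bool:
    """pred_pos: integer tensor [num_pieces, 2] of predicted (row, col) indices.

    Hypothetical helper: with num_pieces == n_rows * n_cols, the mapping is a
    bijection iff it is injective, i.e. all (row, col) pairs are distinct and
    lie inside the grid.
    """
    if pred_pos.shape[0] != n_rows * n_cols:
        return False
    rows, cols = pred_pos[:, 0], pred_pos[:, 1]
    in_range = bool(((rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)).all())
    # Encode each (row, col) pair as one flat cell index and count distinct cells
    flat = rows * n_cols + cols
    return in_range and torch.unique(flat).numel() == pred_pos.shape[0]
```

Injectivity is the cheaper property to test here (one `unique` call); surjectivity then follows for free.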
Note: I'm aware that the jigsaw problem is probably easier to solve with an algorithmic approach; I simply want to solve it by means of DL! The same goes for the use of a Transformer: I simply took the problem as an opportunity to work with Transformers for the first time!
Existing Work:
- JigsawPuzzlePytorch: Uses a much simpler approach: it trains a CNN from scratch with a standard CNN + classifier architecture. I did not seriously consider this approach, since it seems quite naive, though I could of course be wrong, and my current approach may well be too complex.
- Deepzzle: Key contributions and useful approaches from this paper include:
- Pairwise Compatibility: Compute pairwise compatibility scores between puzzle pieces using CNN features.
- Shortest Path Optimization: Formulate the puzzle assembly as a shortest path problem, solved via Dijkstra’s algorithm.
- Hybrid Method: Combine deep learning for feature extraction with graph-based optimization for puzzle assembly, balancing both learning and optimization techniques.
I used piecemaker to generate a dataset based on ImageNet. See a randomly augmented and reconstructed training image below.
Current Problems & Considerations:
With my current architecture, which can be found below, training seems to get stuck in a local minimum at some point during the first epoch, so I'd like to consider alternative approaches. The problem is that the space of possible hyperparameters and architectural changes is far too large to explore manually (or with naive Optuna sweeps) on my RTX 3080 Ti. I'd like to consider novel approaches to the problem, but I'm of course also open to any suggestions on how my current architecture can be improved!
- Ensuring Unique and Valid Outputs: How can we ensure the output sequence always consists of unique and valid coordinates, where no coordinate is predicted more than once?
What architectures are capable of performing an injective and surjective mapping between input and output sequences? Pointer Networks seem to be a good fit, but I'm not sure about their current state-of-the-art implementations. I think that adapting Pointer Networks to a Transformer architecture could be a good approach.
The problem also screams for a graph-based approach, but I'm not familiar with GNNs, so I have no clear idea how this could work.
- Incorporating a Linear Assignment Problem (LAP) Solver: Linear assignment problems aim to optimally assign resources (here, positions) to agents (puzzle pieces) such that the total cost is minimized. In the given context, minimizing the summed negative log-likelihoods of the joint row/col probabilities could be a reasonable way to get unique assignments. A LAP solver like the Hungarian algorithm or `scipy.optimize.linear_sum_assignment` could be used to find the overlap-free assignment that maximizes the overall probability. However, it is not differentiable and is thus hard to incorporate into the training process. While there is a PyTorch implementation of the Hungarian algorithm (hungarian-net), I'm not sure about its efficacy, since it seems to be quite primitive.
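At inference time, the objective described above (minimum summed negative log-likelihood under a one-to-one constraint) maps directly onto `scipy.optimize.linear_sum_assignment`. A sketch, assuming the joint probabilities of a single puzzle as produced by `compute_joint_probabilities` (converted to NumPy) and a puzzle with as many pieces as cells; `lap_decode` is a hypothetical helper name. This does not make training differentiable:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def lap_decode(joint_probs: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    """Assign every piece to a unique (row, col) cell.

    joint_probs: [num_pieces, num_rows, num_cols] for ONE puzzle, with
    num_pieces == num_rows * num_cols.
    Returns an int array [num_pieces, 2] of (row, col) indices.
    """
    num_pieces, num_rows, num_cols = joint_probs.shape
    # Cost of placing piece i into flat cell k is its negative log-likelihood
    cost = -np.log(joint_probs.reshape(num_pieces, num_rows * num_cols) + eps)
    piece_idx, cell_idx = linear_sum_assignment(cost)  # minimizes the summed cost
    rows, cols = np.divmod(cell_idx, num_cols)         # flat cell -> (row, col)
    out = np.empty((num_pieces, 2), dtype=np.int64)
    out[piece_idx, 0] = rows
    out[piece_idx, 1] = cols
    return out
```

For example, with a 2×1 grid where both pieces prefer row 0 (probabilities 0.9/0.1 and 0.6/0.4), a per-piece argmax collides, while the solver keeps piece 0 in row 0 and moves piece 1 to row 1 because that pairing has the lower total cost.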
Objective:
We aim to minimize the combined loss function:
where
Non-Unique Loss:
The loss function that aims to ensure injectivity looks like this:
.
Questions: This loss seemed to dominate the optimization, and the model usually gets stuck in a local minimum because the `MSE-pos` loss is overshadowed by the `non-unique` loss. To deal with this, I implemented a hybrid learning rate scheduler, which combines the `ReduceLROnPlateau` and `CyclicLR` schedulers. However, I have no idea whether this is a good approach. See a plot of a mock loss and the resulting lr below.
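One way to combine the two behaviours without both schedulers overwriting `param_group["lr"]` is a triangular cycle (the `CyclicLR` "triangular" policy) whose amplitude is scaled down whenever a plateau is detected (the `ReduceLROnPlateau` rule in `min`/`rel` mode). This is a self-contained sketch, not the scheduler used in this post; the class name and defaults are assumptions:

```python
import math


class PlateauScaledCyclicLR:
    """Hypothetical triangular cyclic LR whose amplitude shrinks on validation plateaus."""

    def __init__(self, base_lr: float, max_lr: float, step_size: int,
                 factor: float = 0.5, patience: int = 2, threshold: float = 1e-4,
                 min_scale: float = 1e-3):
        self.base_lr, self.max_lr, self.step_size = base_lr, max_lr, step_size
        self.factor, self.patience, self.threshold = factor, patience, threshold
        self.min_scale = min_scale
        self.scale = 1.0        # multiplied by `factor` on every detected plateau
        self.best = math.inf    # best validation loss seen so far
        self.num_bad_epochs = 0

    def lr_at(self, step: int) -> float:
        # Triangular wave: base_lr at step 0, max_lr at step_size, base_lr at 2 * step_size
        x = abs((step % (2 * self.step_size)) / self.step_size - 1.0)
        return self.scale * (self.base_lr + (self.max_lr - self.base_lr) * (1.0 - x))

    def report(self, val_loss: float) -> None:
        # Plateau rule: an improvement must beat best * (1 - threshold)
        if val_loss < self.best * (1.0 - self.threshold):
            self.best, self.num_bad_epochs = val_loss, 0
        else:
            self.num_bad_epochs += 1
            if self.num_bad_epochs > self.patience:
                self.scale = max(self.scale * self.factor, self.min_scale)
                self.num_bad_epochs = 0
```

Usage: on every optimizer step set `for g in optimizer.param_groups: g["lr"] = sched.lr_at(step)`, and call `sched.report(val_loss)` once per validation run.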
To mitigate the impact of the `non-unique` loss, I'm also slowly increasing the standard deviation of the Gaussian in the previous equation from 0.1 to 0.52 over the course of training. Making the standard deviation depend on the `non-unique` loss did not work out.
Current Model Architecture: PatchNet
PATCH-Net: Puzzle Assembly by Transformer and CNN Hybrid Network
PATCH-Net is designed to solve the jigsaw puzzle problem by mapping feature representations of puzzle pieces to their respective positions and rotations. It utilizes Learnable Fourier Embeddings for the predicted row/col/rot indices and the puzzle type.
EfficientNetV2 Backbone:
- Function: Extracts feature representations from input puzzle pieces.
- Input: Tensor of shape [B, num_pieces, 3, H, W]
- Output: Tensor of shape [B, num_pieces, num_features_out]
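```python
from __future__ import annotations

from torch import Tensor, nn
from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s

# HParams is the author's config class (not shown).


class EfficientNetV2(nn.Module):
    def __init__(self, hparams: HParams.Backbone):
        super().__init__()
        self.backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        # Replace the 1000-class ImageNet head with a projection to num_features_out
        self.backbone.classifier[1] = nn.Linear(
            self.backbone.classifier[1].in_features, hparams.num_features_out
        )
        if not hparams.is_trainable:
            # Freeze the pretrained backbone but keep the new projection layer trainable
            for param in self.backbone.parameters():
                param.requires_grad_(False)
            self.backbone.classifier[1].requires_grad_(True)

    def forward(self, x: Tensor) -> Tensor:
        B, num_pieces, C, H, W = x.shape
        # Fold the pieces into the batch dimension: [B * num_pieces, C, H, W]
        features = self.backbone(x.reshape(B * num_pieces, C, H, W))
        # Unfold again: [B, num_pieces, num_features_out]
        return features.reshape(B, num_pieces, -1)
```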
Questions:
– Latent Space Transition: How should we structure the latent space between feature extraction and sequence modeling? Is a linear layer appropriate, or should we just flatten or use a 1×1 Conv2d layer?
– State-of-the-Art Architectures: Would pretrained Vision Transformers or hybrid architectures be more suitable than EfficientNetV2-S? Should we use a more complex backbone? EfficientNetV2-S has around 21.5M parameters, while the full model currently ranges between 6M and 120M trainable parameters.
– Current Input Shapes: Puzzle pieces of size 48×48 are fed into the network. Should we consider larger input sizes for better feature extraction, even though 48×48 seems to be a valid choice given the original sizes of the pieces?
– Training Strategy: How significant could the impact of retraining (fine-tuning) the backbone's weights be on the model's performance? Freezing is, however, much faster than fine-tuning the backbone's weights.
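A common middle ground between a frozen and a fully fine-tuned backbone is to unfreeze only the last few stages. A sketch, assuming torchvision's EfficientNet layout where `model.features` is an `nn.Sequential` of stages and `model.classifier` is the head; `freeze_all_but_last` and `count_trainable` are hypothetical helpers:

```python
from torch import nn


def freeze_all_but_last(stages: nn.Sequential, head: nn.Module, n_trainable: int) -> None:
    """Freeze every stage except the last `n_trainable`; keep the head trainable."""
    stages.requires_grad_(False)
    if n_trainable > 0:  # guard: stages[-0:] would select ALL stages
        for stage in stages[-n_trainable:]:
            stage.requires_grad_(True)
    head.requires_grad_(True)


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
```

Usage (under the layout assumption above): `freeze_all_but_last(self.backbone.features, self.backbone.classifier, n_trainable=2)`. Note that `requires_grad_(False)` does not stop BatchNorm layers in frozen stages from updating their running statistics in training mode; those modules would have to be put in `eval()` mode separately.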
Puzzle Type Classifier:
- Function: Predicts the type of each puzzle piece (e.g., edge, corner, center) and provides type embeddings.
- Input: Tensor of shape [B, num_pieces, num_features_out]
- Output: Logits of shape [B, num_pieces, 3] and embeddings of shape [B, num_pieces, num_features_out]
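```python
from __future__ import annotations

from typing import Tuple

from torch import Tensor, nn

# HParams and LearnableFourierFeatures are the author's own classes (not shown).


class PuzzleTypeClassifier(nn.Module):
    def __init__(self, hparams: HParams.TypeClassifier):
        super().__init__()
        self.fc = nn.Linear(hparams.input_features, 3)  # 3 piece types
        # Embed the predicted type distribution (softmax over the 3 logits)
        self.embedding = nn.Sequential(
            nn.Softmax(dim=-1),
            LearnableFourierFeatures(**hparams.fourier_embedding.model_dump()),
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        x = self.fc(x)  # logits: [B, num_pieces, 3]
        # unsqueeze(-2) adds a group axis before embedding: [B, num_pieces, 1, 3]
        return x, self.embedding(x.unsqueeze(-2))
```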
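Since a piece's type is fully determined by its (row, col) position, the targets for this classifier can be derived from the ground-truth positions. A sketch; the label order 0 = corner, 1 = edge, 2 = interior is an assumption, as is the function name:

```python
import torch


def piece_type_targets(pos: torch.Tensor, n_rows: int, n_cols: int) -> torch.Tensor:
    """pos: [..., 2] integer (row, col). Returns [...] labels: 0=corner, 1=edge, 2=interior."""
    rows, cols = pos[..., 0], pos[..., 1]
    on_row_border = (rows == 0) | (rows == n_rows - 1)
    on_col_border = (cols == 0) | (cols == n_cols - 1)
    labels = torch.full_like(rows, 2)              # interior by default
    labels[on_row_border | on_col_border] = 1      # any border piece is an edge piece ...
    labels[on_row_border & on_col_border] = 0      # ... unless it is on two borders (corner)
    return labels
```

For a 3×4 puzzle this yields 4 corner, 6 edge and 2 interior labels, which could feed `F.cross_entropy` on the flattened type logits.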
Questions:
– Type Embeddings: The type embeddings are added to the feature representations. I would assume that `f_dim = num_features_out` and `h_dim = 12` would be reasonable choices. However, I'm not sure about the impact of the type embeddings on the model's performance.
The hyperparameters for the Fourier Embedding are:
– pos_dim (int): Dimensionality of the input position.
– f_dim (int): Dimensionality of the Fourier features. Default is 768.
– h_dim (int): Dimensionality of the hidden layer in the MLP. Default is 32.
– d_dim (int): Dimensionality of the output embedding. Default is 768.
– g_dim (int): Number of positional groups. Default is 1.
– gamma (float): Variance scaling factor for initializing the Fourier feature weight matrix. Default is 1.0.
– Complexity of the Classifier Head: The task seems to be quite simple, so I assumed that a single classification head would be sufficient. However, could it be appropriate to prepend a small MLP to the classification head to increase the model's capacity?
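The hyperparameters above appear to follow the learnable Fourier positional encoding of Li et al. (2021). The `LearnableFourierFeatures` class used here is not shown, so the following is only a minimal sketch of that technique with the same argument names; the `[..., g_dim, pos_dim] -> [..., d_dim]` shape convention and the initialization scale `std = gamma ** -2` are assumptions:

```python
import math

import torch
from torch import nn


class LearnableFourierFeaturesSketch(nn.Module):
    def __init__(self, pos_dim: int, f_dim: int = 768, h_dim: int = 32,
                 d_dim: int = 768, g_dim: int = 1, gamma: float = 1.0):
        super().__init__()
        assert f_dim % 2 == 0 and d_dim % g_dim == 0
        self.f_dim = f_dim
        # Learnable frequency matrix W_r; gamma sets the initial frequency scale (assumed init)
        self.w_r = nn.Linear(pos_dim, f_dim // 2, bias=False)
        nn.init.normal_(self.w_r.weight, mean=0.0, std=gamma ** -2)
        # MLP shared across groups, mapping Fourier features to d_dim / g_dim outputs
        self.mlp = nn.Sequential(
            nn.Linear(f_dim, h_dim),
            nn.GELU(),
            nn.Linear(h_dim, d_dim // g_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., g_dim, pos_dim]
        proj = self.w_r(x)                                                   # [..., g_dim, f_dim/2]
        fourier = torch.cat([proj.cos(), proj.sin()], dim=-1) / math.sqrt(self.f_dim)
        y = self.mlp(fourier)                                                # [..., g_dim, d_dim/g_dim]
        return y.flatten(-2)                                                 # [..., d_dim]
```

With `g_dim = 1` and `pos_dim = 3`, this would accept the `[B, num_pieces, 1, 3]` softmax output of the type classifier and return `[B, num_pieces, d_dim]`.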
Transformer:
- Function: Encodes and decodes feature representations to predict positions and rotations.
- Input: Source tensor of shape [B, num_pieces, num_features_out], positional encoding of shape [B, num_pieces, num_features_out]
- Output: Decoder output tensor of shape [B, num_pieces, num_features_out] and encoder memory tensor of shape [B, num_pieces, num_features_out]
```python
from __future__ import annotations

from typing import Optional, Tuple

import torch.nn.functional as F
from torch import Tensor, nn

# HParams is the author's config class; nn_utils is shown further below.


class Transformer(nn.Module):
    def __init__(self, hparams: HParams.Transformer):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hparams.d_model,
                nhead=hparams.nhead,
                batch_first=True,
                dim_feedforward=hparams.dim_feedforward,
                activation=F.silu,
            ),
            num_layers=hparams.num_encoder_layers,
            norm=nn.LayerNorm(hparams.d_model),
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hparams.d_model,
                nhead=hparams.nhead,
                batch_first=True,
                dim_feedforward=hparams.dim_feedforward,
                activation=F.silu,
            ),
            num_layers=hparams.num_decoder_layers,
            norm=nn.LayerNorm(hparams.d_model),
        )

    def forward(self, src: Tensor, pos_encoding: Tensor, memory: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Reuse the encoder memory across autoregressive steps if it is passed in
        encoder_memory = self.encoder(src) if memory is None else memory
        # Apply the causal mask in training AND inference: without it, earlier decoder positions
        # attend to later ones at inference, so with >= 2 decoder layers the last step's output
        # no longer matches what the model saw during training.
        tgt_mask = nn_utils.generate_causal_mask(pos_encoding.size(1), pos_encoding.device)
        decoder_output = self.decoder(pos_encoding, encoder_memory, tgt_mask=tgt_mask)
        return decoder_output, encoder_memory
```
Questions:
- Bidirectional Decoder: Currently the decoder sequence is processed unidirectionally, which introduces an unwanted priority for tokens that come earlier in the sequence. Could a bidirectional decoder boost our model's capacity dramatically?
- Complexity of the Transformer: I manually tested a few hyperparameters:
  - `num_encoder_layers` / `num_decoder_layers`: 2 to 6
  - `d_model`: 512 to 1024
  - `dim_feedforward`: 1024 to 2048

  which resulted in 6M to 120M trainable parameters, but I could not see any significant improvements. Does anyone have an intuition for reasonable magnitudes of these hyperparameters?
- Skip Connections: Is it a common practice to use skip connections in the Transformer architecture? Maybe skipping (HiddenLayer <- EncoderOutput + EncoderInput)?
- Remove Encoder or Decoder: Could it make sense to remove either the encoder or decoder? I’m not sure about the necessity of the encoder, since we should already have a good feature representation from the EfficientNetV2 backbone.
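To make the "remove the decoder" option concrete: an encoder-only model treats the pieces as an unordered set, lets every piece attend to every other piece (no causal ordering, so no priority for earlier tokens), and predicts all positions in a single pass; uniqueness would then have to be enforced at inference, for example with a linear assignment solver. A minimal sketch with hypothetical names and sizes, not a tested recipe:

```python
import torch
from torch import nn


class EncoderOnlyPlacer(nn.Module):
    """Hypothetical decoder-free head: all pieces attend to each other, all positions predicted at once."""

    def __init__(self, d_model: int, nhead: int, num_layers: int, max_rows: int, max_cols: int):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=2 * d_model, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.max_rows, self.max_cols = max_rows, max_cols
        self.pos_head = nn.Linear(d_model, max_rows * max_cols)  # joint (row, col) logits
        self.rot_head = nn.Linear(d_model, 4)                    # rotation logits

    def forward(self, feats: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # feats: [B, num_pieces, d_model]. Without positional encodings the encoder is
        # permutation-equivariant: shuffling the pieces shuffles the outputs the same way.
        h = self.encoder(feats)
        pos_logits = self.pos_head(h).reshape(*h.shape[:-1], self.max_rows, self.max_cols)
        return pos_logits, self.rot_head(h)                      # [B, L, R, C], [B, L, 4]
```

Training could then use a per-piece cross-entropy on `pos_logits.flatten(-2)` against the flat cell index `row * max_cols + col`.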
Dynamic Index Classifier:
- Class: DynamicIdxClassifier
- Function: Predicts row, column, and rotation logits for each puzzle piece.
- Input: Tensor of shape [B, num_pieces, num_features_out], actual rows and columns in the puzzle
- Output: Row logits of shape [B, num_pieces, max_rows], column logits of shape [B, num_pieces, max_cols], rotation logits of shape [B, num_pieces, 4]
```python
from __future__ import annotations

from typing import Tuple

from torch import Tensor, nn

# HParams is the author's config class (not shown).


class DynamicIdxClassifier(nn.Module):
    def __init__(self, hparams: HParams.IdxClassifier):
        super().__init__()
        self.max_rows = hparams.max_rows
        self.max_cols = hparams.max_cols
        self.fc_rows = nn.Linear(hparams.input_features, self.max_rows)
        self.fc_cols = nn.Linear(hparams.input_features, self.max_cols)
        self.fc_rot = nn.Linear(hparams.input_features, 4)  # 4 rotation classes

    def forward(self, x: Tensor, actual_rows: int, actual_cols: int) -> Tuple[Tensor, Tensor, Tensor]:
        row_logits = self.fc_rows(x)
        col_logits = self.fc_cols(x)
        rot_logits = self.fc_rot(x)
        # Mask rows/cols that do not exist in the current puzzle (zero probability after softmax)
        row_logits[..., actual_rows:].fill_(float("-inf"))
        col_logits[..., actual_cols:].fill_(float("-inf"))
        return row_logits, col_logits, rot_logits
```
Questions:
– So far I’ve only used puzzles of shape 3×4, though I’d like to extend the model to handle puzzles of arbitrary shape. Is the current approach of masking the logits for the actual rows and columns a good way to handle this?
– Could it make sense to embed the actual number of rows and columns into the network somehow?
– Would it make more sense to predict the row and column indices jointly, using a single classification head `fc_joint = nn.Linear(hparams.input_features, self.max_rows * self.max_cols)`, or even to use both approaches and combine the joint probabilities / logits at the end?
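The joint head proposed above can be combined with the same masking idea for puzzles smaller than the maximum grid. A sketch with a hypothetical class name; unlike the factorized row/col heads, whose outer product assumes row and column are independent per piece, a joint head can represent correlations between them:

```python
import torch
from torch import nn


class JointIdxHead(nn.Module):
    """Hypothetical joint (row, col) head with masking for puzzles smaller than the max grid."""

    def __init__(self, input_features: int, max_rows: int, max_cols: int):
        super().__init__()
        self.max_rows, self.max_cols = max_rows, max_cols
        self.fc_joint = nn.Linear(input_features, max_rows * max_cols)

    def forward(self, x: torch.Tensor, actual_rows: int, actual_cols: int) -> torch.Tensor:
        logits = self.fc_joint(x).reshape(*x.shape[:-1], self.max_rows, self.max_cols)
        # Cells outside the actual (actual_rows x actual_cols) grid get -inf, i.e. zero probability
        invalid = torch.ones(self.max_rows, self.max_cols, dtype=torch.bool, device=x.device)
        invalid[:actual_rows, :actual_cols] = False
        return logits.masked_fill(invalid, float("-inf"))  # [..., max_rows, max_cols]
```

A softmax over the flattened last two axes then gives a distribution over valid cells only.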
PatchNet:
- Function: Integrates all the components to solve the jigsaw puzzle problem.
- Workflow:
- Feature Extraction: The input tensor of shape `[B, num_pieces, 3, H, W]` is passed through the EfficientNetV2 backbone to obtain features of shape `[B, num_pieces, num_features_out]`.
- Type Classification: The features are passed through the PuzzleTypeClassifier to predict the type and get type embeddings.
- Embedding Addition: The type embeddings are added to the feature representations.
- Position Sequence Initialization: For training, the true position sequence is embedded and concatenated with the start-of-sequence token.
- Transformer Encoding and Decoding: The initialized position sequence and feature representations are processed by the Transformer encoder and decoder.
- Decoder Sequence: The decoder sequence consists of the start-of-sequence token followed by the Fourier-embedded true position sequence (during training) or the autoregressively generated sequence (during inference). This sequence is used as input to the Transformer decoder.
- Classification: The output from the Transformer is passed to the DynamicIdxClassifier to predict row, column, and rotation logits.
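For reference, the usual teacher-forcing setup shifts the target sequence right by one step, so that decoder step t sees the start token and targets 0..t-1 and is trained to predict target t. This matches autoregressive inference, where step t has only produced tokens 0..t-1. A sketch with a hypothetical helper name:

```python
import torch


def shift_right(target_emb: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
    """target_emb: [B, L, D] embedded targets; sos: [1, 1, D] start-of-sequence token.

    Returns the decoder input [B, L, D] = [SOS, target_0, ..., target_{L-2}].
    """
    sos = sos.expand(target_emb.size(0), 1, -1)
    # Drop the last target: it is never used as input, only predicted
    return torch.cat([sos, target_emb[:, :-1]], dim=1)
```

With this input, decoder output t is compared against target t, and the sequence length is L in both training and inference.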
```python
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import Tensor, nn

# HParams, LearnableFourierFeatures and nn_utils are the author's own code;
# EfficientNetV2, PuzzleTypeClassifier, Transformer and DynamicIdxClassifier are defined above.

Logits = Tuple[Tensor, Tensor, Tensor]  # (row_logits, col_logits, rot_logits)


class PatchNet(nn.Module):
    def __init__(self, hparams: HParams):
        super().__init__()
        self.hparams = hparams
        self.backbone = EfficientNetV2(hparams.backbone)
        self.puzzle_type_classifier = PuzzleTypeClassifier(hparams.type_classifier)
        self.transformer = Transformer(hparams.transformer)
        self.classifier = DynamicIdxClassifier(hparams.idx_classifier)
        self.spatial_embedding = LearnableFourierFeatures(**hparams.fourier_embedding_pos.model_dump())
        self.rotation_embedding = LearnableFourierFeatures(**hparams.fourier_embedding_rot.model_dump())
        self.start_of_seq_token = nn.Parameter(torch.randn(1, 1, hparams.transformer.d_model))

    def forward(self, x: Tensor, true_pos_seq: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Logits]:
        x = self.backbone(x)  # [B, L, F]
        puzzle_type_logits, puzzle_type_embedding = self.puzzle_type_classifier(x)
        x = x + puzzle_type_embedding
        if self.training:
            assert true_pos_seq is not None
            # Teacher forcing: the decoder input is [SOS, p_0, ..., p_{L-2}], so step t only sees
            # the targets before p_t. Feeding all L targets and then dropping the SOS step would
            # let each step see its own target.
            sos = self.start_of_seq_token.expand(x.size(0), 1, -1)
            prev = self._embedd_pos_seq(true_pos_seq[:, :-1].to(x))
            pos_seq, logits = self._soft_forward_step(x, torch.cat([sos, prev], dim=1))
        else:
            pos_seq, logits = self._autoregressive_decode(x)
        return puzzle_type_logits, pos_seq, logits

    def _soft_forward_step(self, x: Tensor, pos_seq: Tensor, encoder_memory: Optional[Tensor] = None) -> Tuple[Tensor, Logits]:
        x, encoder_memory = self.transformer(x, pos_seq, encoder_memory)
        logits = self.classifier(x, *self.hparams.puzzle_shape)
        joint_probs = nn_utils.compute_joint_probabilities(*logits[:2])
        joint_probs = nn_utils.apply_penalties(joint_probs)  # returns probabilities, not logits
        pos_seq = torch.cat(
            [nn_utils.softargmax2d(joint_probs), nn_utils.softargmax1d(logits[2]).unsqueeze(-1)], dim=-1
        )  # [B, L, 3] = soft (row, col, rot)
        return pos_seq, logits

    def _embedd_pos_seq(self, pos: Tensor) -> Tensor:
        # pos: [B, L, 3] = (row, col, rot); unsqueeze(-2) adds the group axis for the embeddings
        spatial_encoding = self.spatial_embedding(pos[:, :, :2].unsqueeze(-2))
        rotation_encoding = self.rotation_embedding(pos[:, :, 2:].unsqueeze(-2))
        return spatial_encoding + rotation_encoding

    def _autoregressive_decode(self, x: Tensor) -> Tuple[Tensor, Logits]:
        B, L, _ = x.shape
        pos_seq = self.start_of_seq_token.expand(B, 1, -1)
        encoder_memory = None
        logits_list: list[Logits] = []
        for _ in range(L):
            decoder_output, encoder_memory = self.transformer(x, pos_seq, encoder_memory)
            # Classify only the last step; slicing with -1: keeps the sequence axis ([B, 1, F])
            step_logits = self.classifier(decoder_output[:, -1:, :], *self.hparams.puzzle_shape)
            logits_list.append(step_logits)
            row_logits, col_logits, rot_logits = step_logits
            joint_probs = nn_utils.compute_joint_probabilities(row_logits, col_logits)
            next_token = torch.cat(
                [nn_utils.argmax2d(joint_probs), rot_logits.argmax(dim=-1, keepdim=True)], dim=-1
            ).float()  # [B, 1, 3]
            pos_seq = torch.cat([pos_seq, self._embedd_pos_seq(next_token)], dim=1)
        # Concatenate the per-step logits along the sequence axis: [B, L, ...]
        logits = tuple(torch.cat(t, dim=1) for t in zip(*logits_list))
        row_logits, col_logits, rot_logits = logits
        joint_probs = nn_utils.apply_penalties(nn_utils.compute_joint_probabilities(row_logits, col_logits))
        pos_seq = torch.cat(
            [nn_utils.softargmax2d(joint_probs), nn_utils.softargmax1d(rot_logits).unsqueeze(-1)], dim=-1
        )
        return pos_seq, logits
```
Questions:
- `_autoregressive_decode`: `nn_utils.apply_penalties(joint_probs)` does not really work the way it is used here, since it operates on the joint probabilities of the full sequence (it takes maxima across tokens), so it cannot influence the greedy choices made step by step during decoding.
- How could we incorporate the non-differentiable `scipy.optimize.linear_sum_assignment` into the training process?
The `PatchNet` model uses some utility functions from `nn_utils.py`:
```python
import torch
import torch.nn.functional as F
from torch import Tensor


class nn_utils:
    @staticmethod
    def compute_joint_probabilities(row_logits: Tensor, col_logits: Tensor) -> Tensor:
        """
        Args:
            row_logits (Tensor[B, num_pieces, num_rows])
            col_logits (Tensor[B, num_pieces, num_cols])
        Returns:
            Tensor[B, num_pieces, num_rows, num_cols]: Joint probabilities
        """
        # Compute probabilities within each token over all classes
        row_probs = F.softmax(row_logits, dim=-1)
        col_probs = F.softmax(col_logits, dim=-1)
        # Outer product per token, i.e. rows and cols are treated as independent
        joint_probs = row_probs[:, :, :, None] * col_probs[:, :, None, :]
        return joint_probs

    @staticmethod
    def apply_penalties(joint_probs: Tensor) -> Tensor:
        """
        Aims to penalize high probabilities for the same coordinates.
        Args:
            joint_probs: Tensor [B, L, num_rows, num_cols] of joint probabilities.
        Returns:
            Tensor [B, L, num_rows, num_cols] of adjusted probabilities (renormalized per token).
        Case distinctions:
            - max_per_class & max_per_token -> assign
            - ~max_per_class & max_per_token -> penalize, subsidize others
            - max_per_class & ~max_per_token -> subsidize
        """
        flat_probs = joint_probs.reshape(*joint_probs.shape[:2], -1)  # [B, L, R*C]
        max_probs_per_token, _ = flat_probs.max(dim=1, keepdim=True)   # max over tokens (dim=1)
        max_probs_per_class, _ = flat_probs.max(dim=-1, keepdim=True)  # max over cells (dim=-1)
        max_per_token = (flat_probs == max_probs_per_token).float()
        max_per_class = (flat_probs == max_probs_per_class).float()
        # Apply a dynamic penalty for non-max elements and a boost for max elements where they aren't the max token-wise
        # Penalizes non-max tokens
        penalties = (
            (1 - max_per_class)
            * max_per_token
            * ((max_probs_per_class - flat_probs) / max_probs_per_class)
        )
        # Boost max-class elements that are not max-token
        incentives = (
            (1 - max_per_token)
            * max_per_class
            * ((max_probs_per_token - flat_probs) / max_probs_per_token)
        )
        adjusted_probs = flat_probs * (1 + incentives - penalties)
        return (
            (adjusted_probs + torch.finfo(torch.float32).eps).log().softmax(-1)
        ).view_as(joint_probs)

    @staticmethod
    def softargmax1d(input: Tensor, beta: float = 100) -> Tensor:
        # https://github.com/david-wb/softargmax
        *_, n = input.shape
        input = F.softmax(beta * input, dim=-1)
        indices = torch.linspace(0, 1, n, device=input.device)
        result = torch.sum((n - 1) * input * indices, dim=-1)
        return result

    @staticmethod
    def softargmax2d(input: Tensor, beta: float = 100) -> Tensor:
        # https://github.com/david-wb/softargmax
        *batch_dims, h, w = input.shape
        input = F.softmax(beta * input.reshape(*batch_dims, h * w), dim=-1)
        # indexing="xy" yields [h, w] grids whose row-major flattening matches the reshape above.
        # torch.meshgrid defaults to "ij" (unlike np.meshgrid's "xy"), which mixes up rows and cols here.
        indices_c, indices_r = torch.meshgrid(
            torch.linspace(0, 1, w, device=input.device),
            torch.linspace(0, 1, h, device=input.device),
            indexing="xy",
        )
        indices_r = indices_r.reshape(h * w)
        indices_c = indices_c.reshape(h * w)
        result_r = torch.sum((h - 1) * input * indices_r, dim=-1)
        result_c = torch.sum((w - 1) * input * indices_c, dim=-1)
        return torch.stack([result_r, result_c], dim=-1)

    @staticmethod
    def argmax2d(joint_logits: Tensor) -> Tensor:
        """
        Computes the argmax over 2D logits.
        Args:
            joint_logits: Tensor [B, L, num_rows, num_cols]
        Returns:
            Tensor [B, L, 2] - Indices for rows and columns
        """
        if joint_logits.ndim == 3:
            joint_logits = joint_logits.unsqueeze(1)
        B, L, _, num_cols = joint_logits.shape
        flat_indices = torch.argmax(joint_logits.reshape(B, L, -1), dim=-1)
        row_indices = flat_indices // num_cols
        col_indices = flat_indices % num_cols
        return torch.stack([row_indices, col_indices], dim=-1)

    @staticmethod
    def generate_causal_mask(size: int, device: torch.device) -> Tensor:
        # Float mask: 0 on/below the diagonal, -inf above it (blocks attention to future positions)
        mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
        mask[mask == 1] = float("-inf")
        return mask
```
Questions:
- `softargmax1d` and `softargmax2d`: Would Gumbel-Softmax be a better choice for the soft-argmax functions?
- `apply_penalties`: The function is quite primitive, and I'm not sure whether it's an appropriate way to reduce the likelihood of non-unique assignments and to boost the likelihood of class assignments that are not the max token-wise. Would a more sophisticated approach be beneficial?
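For comparison with the soft-argmax functions, PyTorch ships `torch.nn.functional.gumbel_softmax`; with `hard=True` it returns a one-hot sample in the forward pass and uses the gradient of the soft sample in the backward pass (straight-through estimator). A sketch of a differentiable index built on it; the function name is hypothetical and `tau` would need tuning:

```python
import torch
import torch.nn.functional as F


def gumbel_index1d(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Straight-through sampled index over the last dim: integer-valued forward, soft gradient."""
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
    idx = torch.arange(logits.size(-1), dtype=logits.dtype, device=logits.device)
    return (one_hot * idx).sum(dim=-1)
```

Unlike `softargmax1d`, the forward value is always a valid integer index, but it is a random sample and its gradient is a biased estimate; lowering `tau` makes the soft samples sharper at the cost of higher-variance gradients.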
Final Notes
Thanks for taking the time to read through my lengthy and potentially unstructured post. I appreciate any feedback or suggestions you may have. Sorry for any confusion, and I hope I’ve provided enough context for my questions!