From 28ec6df8f0ebf96966bee61caf5a325695314b7a Mon Sep 17 00:00:00 2001
From: Héctor Ramón Jiménez <hector@hecrj.dev>
Date: Fri, 8 Nov 2024 18:07:11 +0100
Subject: Fix cross-axis compression in `layout::flex`

---
 core/src/layout/flex.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 52 insertions(+), 1 deletion(-)

diff --git a/core/src/layout/flex.rs b/core/src/layout/flex.rs
index ac80d393..2cff5bfd 100644
--- a/core/src/layout/flex.rs
+++ b/core/src/layout/flex.rs
@@ -79,6 +79,7 @@ where
     let max_cross = axis.cross(limits.max());
 
     let mut fill_main_sum = 0;
+    let mut some_fill_cross = false;
     let (mut cross, cross_compress) = match axis {
         Axis::Vertical if width == Length::Shrink => (0.0, true),
         Axis::Horizontal if height == Length::Shrink => (0.0, true),
@@ -90,6 +91,10 @@ where
     let mut nodes: Vec<Node> = Vec::with_capacity(items.len());
     nodes.resize(items.len(), Node::default());
 
+    // FIRST PASS
+    // We lay out non-fluid elements in the main axis.
+    // If we need to compress the cross axis, then we skip any of these elements
+    // that are also fluid in the cross axis.
     for (i, (child, tree)) in items.iter().zip(trees.iter_mut()).enumerate() {
         let (fill_main_factor, fill_cross_factor) = {
             let size = child.as_widget().size();
@@ -121,6 +126,41 @@ where
             nodes[i] = layout;
         } else {
             fill_main_sum += fill_main_factor;
+            some_fill_cross = some_fill_cross || fill_cross_factor != 0;
+        }
+    }
+
+    // SECOND PASS (conditional)
+    // If we must compress the cross axis and there are fluid elements in the
+    // cross axis, we lay out any of these elements that are also non-fluid in
+    // the main axis (i.e. the ones we deliberately skipped in the first pass).
+    //
+    // We use the maximum cross length obtained in the first pass as the maximum
+    // cross limit.
+    if cross_compress && some_fill_cross {
+        for (i, (child, tree)) in items.iter().zip(trees.iter_mut()).enumerate()
+        {
+            let (fill_main_factor, fill_cross_factor) = {
+                let size = child.as_widget().size();
+
+                axis.pack(size.width.fill_factor(), size.height.fill_factor())
+            };
+
+            if fill_main_factor == 0 && fill_cross_factor != 0 {
+                let (max_width, max_height) = axis.pack(available, cross);
+
+                let child_limits =
+                    Limits::new(Size::ZERO, Size::new(max_width, max_height));
+
+                let layout =
+                    child.as_widget().layout(tree, renderer, &child_limits);
+                let size = layout.size();
+
+                available -= axis.main(size);
+                cross = cross.max(axis.cross(size));
+
+                nodes[i] = layout;
+            }
         }
     }
 
@@ -135,6 +175,9 @@ where
         },
     };
 
+    // THIRD PASS
+    // We only have the elements that are fluid in the main axis left.
+    // We use the remaining space to evenly allocate space based on fill factors.
     for (i, (child, tree)) in items.iter().zip(trees).enumerate() {
         let (fill_main_factor, fill_cross_factor) = {
             let size = child.as_widget().size();
@@ -142,10 +185,16 @@ where
             axis.pack(size.width.fill_factor(), size.height.fill_factor())
         };
 
-        if fill_main_factor != 0 || (cross_compress && fill_cross_factor != 0) {
+        if fill_main_factor != 0 {
             let max_main =
                 remaining * fill_main_factor as f32 / fill_main_sum as f32;
 
+            let max_main = if max_main.is_nan() {
+                f32::INFINITY
+            } else {
+                max_main
+            };
+
             let min_main = if max_main.is_infinite() {
                 0.0
             } else {
@@ -178,6 +227,8 @@ where
     let pad = axis.pack(padding.left, padding.top);
     let mut main = pad.0;
 
+    // FOURTH PASS
+    // We align all the laid out nodes in the cross axis, if needed.
     for (i, node) in nodes.iter_mut().enumerate() {
         if i > 0 {
             main += spacing;
-- 
cgit