From 35af0aa84f76daddbb6d6959f9746bd09e306278 Mon Sep 17 00:00:00 2001
From: Héctor Ramón Jiménez <hector@hecrj.dev>
Date: Sat, 30 Mar 2024 13:50:40 +0100
Subject: Fix batched writes logic in `iced_wgpu::buffer`

---
 wgpu/src/buffer.rs | 59 +++++++++++++++++++++++++++++-------------------------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/wgpu/src/buffer.rs b/wgpu/src/buffer.rs
index c9d6b828..463ea24a 100644
--- a/wgpu/src/buffer.rs
+++ b/wgpu/src/buffer.rs
@@ -1,7 +1,12 @@
 use std::marker::PhantomData;
+use std::num::NonZeroU64;
 use std::ops::RangeBounds;
 
-pub const MAX_WRITE_SIZE: usize = 1024 * 100;
+pub const MAX_WRITE_SIZE: usize = 100 * 1024;
+
+#[allow(unsafe_code)]
+const MAX_WRITE_SIZE_U64: NonZeroU64 =
+    unsafe { NonZeroU64::new_unchecked(MAX_WRITE_SIZE as u64) };
 
 #[derive(Debug)]
 pub struct Buffer<T> {
@@ -70,40 +75,40 @@ impl<T: bytemuck::Pod> Buffer<T> {
         contents: &[T],
     ) -> usize {
         let bytes: &[u8] = bytemuck::cast_slice(contents);
+        let mut bytes_written = 0;
 
-        if bytes.len() <= MAX_WRITE_SIZE {
+        // Split write into multiple chunks if necessary
+        while bytes_written + MAX_WRITE_SIZE < bytes.len() {
             belt.write_buffer(
                 encoder,
                 &self.raw,
-                offset as u64,
-                (bytes.len() as u64).try_into().expect("Non-empty write"),
+                (offset + bytes_written) as u64,
+                MAX_WRITE_SIZE_U64,
                 device,
             )
-            .copy_from_slice(bytes);
-        } else {
-            let mut bytes_written = 0;
-
-            let bytes_per_chunk = (bytes.len().min(MAX_WRITE_SIZE) as u64)
-                .try_into()
-                .expect("Non-empty write");
-
-            while bytes_written < bytes.len() {
-                belt.write_buffer(
-                    encoder,
-                    &self.raw,
-                    (offset + bytes_written) as u64,
-                    bytes_per_chunk,
-                    device,
-                )
-                .copy_from_slice(
-                    &bytes[bytes_written
-                        ..bytes_written + bytes_per_chunk.get() as usize],
-                );
-
-                bytes_written += bytes_per_chunk.get() as usize;
-            }
+            .copy_from_slice(
+                &bytes[bytes_written..bytes_written + MAX_WRITE_SIZE],
+            );
+
+            bytes_written += MAX_WRITE_SIZE;
         }
 
+        // There will always be some bytes left, since the previous
+        // loop guarantees `bytes_written < bytes.len()`
+        let bytes_left = ((bytes.len() - bytes_written) as u64)
+            .try_into()
+            .expect("non-empty write");
+
+        // Write them
+        belt.write_buffer(
+            encoder,
+            &self.raw,
+            (offset + bytes_written) as u64,
+            bytes_left,
+            device,
+        )
+        .copy_from_slice(&bytes[bytes_written..]);
+
         self.offsets.push(offset as u64);
 
         bytes.len()
-- 
cgit