~raph/vk-toy

5de59b9a69f96243dfb535f14bd5e1eccc15ac34 — Raph Levien 1 year, 10 months ago 7dc5fe0 master
Improve structure and generality

Make the GPU structures more generic, as a step towards being able to run
diverse workloads.
3 files changed, 233 insertions(+), 135 deletions(-)

M src/base.rs
A src/error.rs
M src/main.rs
M src/base.rs => src/base.rs +179 -131
@@ 1,49 1,12 @@
//! Common base for running compute workloads.

use std::ffi::CString;
use std::sync::Arc;

use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
use ash::{vk, Device, Entry, Instance};

/// Error type for this module's fallible operations.
///
/// Each payload-carrying variant wraps an error value from `ash`; the `From`
/// impls below let `?` convert those values automatically.
#[derive(Debug)]
pub enum Error {
    /// Wraps an `ash::LoadingError`.
    // NOTE(review): presumably produced while loading the Vulkan library — confirm at the call site.
    LoadingError(ash::LoadingError),
    /// Wraps an `ash::InstanceError`.
    // NOTE(review): presumably produced by instance creation — confirm at the call site.
    InstanceError(ash::InstanceError),
    /// Wraps a raw Vulkan result code returned by a failed call.
    VkResult(ash::vk::Result),
    /// No payload.
    // NOTE(review): presumably returned when device selection finds nothing usable — confirm.
    NoSuitableDevice,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "vk error")
    }
}

// Marker impl: all `std::error::Error` methods keep their default
// implementations, so `Error` can be used as a boxed/dyn error.
impl std::error::Error for Error {}

impl From<ash::LoadingError> for Error {
    /// Wrap a loading error in [`Error::LoadingError`], enabling `?` on
    /// calls that return `ash::LoadingError`.
    fn from(err: ash::LoadingError) -> Self {
        Self::LoadingError(err)
    }
}

impl From<ash::InstanceError> for Error {
    /// Wrap an instance error in [`Error::InstanceError`], enabling `?` on
    /// calls that return `ash::InstanceError`.
    fn from(err: ash::InstanceError) -> Self {
        Self::InstanceError(err)
    }
}

impl From<vk::Result> for Error {
    /// Wrap a raw Vulkan result code in [`Error::VkResult`], enabling `?` on
    /// calls that return `vk::Result` as their error.
    fn from(code: vk::Result) -> Self {
        Self::VkResult(code)
    }
}

impl From<(Vec<vk::Pipeline>, vk::Result)> for Error {
    fn from(error: (Vec<vk::Pipeline>, vk::Result)) -> Error {
        error.1.into()
    }
}
use crate::error::Error;

/// A base for allocating resources and dispatching work.
///


@@ 57,12 20,16 @@ pub struct Base {
    #[allow(unused)]
    instance: Instance,

    device: Device,
    device: Arc<RawDevice>,
    device_mem_props: vk::PhysicalDeviceMemoryProperties,
    queue: vk::Queue,
    qfi: u32,
}

/// Thin wrapper around the `ash` logical device.
///
/// It is held behind an `Arc` so that both `Base` and each `CmdBuf` can
/// keep a handle to the same device.
struct RawDevice {
    /// The logical device on which all Vulkan calls in this module are made.
    device: Device,
}

/// A handle to a buffer.
///
/// There is no lifetime tracking at this level; the caller is responsible


@@ 70,14 37,24 @@ pub struct Base {
pub struct Buffer {
    /// The Vulkan buffer handle.
    buffer: vk::Buffer,
    /// The device memory allocation bound to `buffer` at offset 0.
    buffer_memory: vk::DeviceMemory,
    /// The size that was requested when the buffer was created.
    size: u64,
}

/// A compute pipeline together with the layouts needed to use it.
pub struct Pipeline {
    /// The compute pipeline handle, bound when recording a dispatch.
    pipeline: vk::Pipeline,
    /// Layout used when allocating descriptor sets for this pipeline.
    descriptor_set_layout: vk::DescriptorSetLayout,
    /// Layout passed when binding descriptor sets for a dispatch.
    pipeline_layout: vk::PipelineLayout,
}

/// A descriptor set that binds buffers for use by a pipeline.
pub struct DescriptorSet {
    /// The Vulkan descriptor set handle, bound at set index 0 on dispatch.
    descriptor_set: vk::DescriptorSet,
}

/// A command buffer plus the device needed to record into it.
pub struct CmdBuf {
    /// The Vulkan command buffer handle.
    cmd_buf: vk::CommandBuffer,
    /// Shared handle to the device, so commands can be recorded without
    /// going back through `Base`.
    device: Arc<RawDevice>,
}

impl Base {
    /// Create a new instance.
    ///


@@ 118,6 95,8 @@ impl Base {
            let queue_index = 0;
            let queue = device.get_device_queue(qfi, queue_index);

            let device = Arc::new(RawDevice { device });

            Ok(Base {
                entry,
                instance,


@@ 135,99 114,71 @@ impl Base {
        mem_flags: vk::MemoryPropertyFlags,
    ) -> Result<Buffer, Error> {
        unsafe {
            let buffer = self.device.create_buffer(
            let device = &self.device.device;
            let buffer = device.create_buffer(
                &vk::BufferCreateInfo::builder()
                    .size(size)
                    .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
                    .sharing_mode(vk::SharingMode::EXCLUSIVE),
                None,
            )?;
            let mem_requirements = self.device.get_buffer_memory_requirements(buffer);
            let mem_requirements = device.get_buffer_memory_requirements(buffer);
            let mem_type = find_memory_type(
                mem_requirements.memory_type_bits,
                mem_flags,
                &self.device_mem_props,
            )
            .unwrap(); // TODO: proper error
            let buffer_memory = self.device.allocate_memory(
            let buffer_memory = device.allocate_memory(
                &vk::MemoryAllocateInfo::builder()
                    .allocation_size(mem_requirements.size)
                    .memory_type_index(mem_type),
                None,
            )?;
            self.device.bind_buffer_memory(buffer, buffer_memory, 0)?;
            device.bind_buffer_memory(buffer, buffer_memory, 0)?;
            Ok(Buffer {
                buffer,
                buffer_memory,
                size,
            })
        }
    }

    /// This creates a pipeline that runs over the buffer.
    ///
    /// This will be split up into finer grains.
    /// The code is included from "../comp.spv", and the descriptor set layout is just some
    /// number of buffers.
    pub unsafe fn create_simple_compute_pipeline(
        &self,
        buffer: &Buffer,
        code: &[u8],
        n_buffers: u32,
    ) -> Result<Pipeline, Error> {
        let descriptor_set_layout = self.device.create_descriptor_set_layout(
        let device = &self.device.device;
        let descriptor_set_layout = device.create_descriptor_set_layout(
            &vk::DescriptorSetLayoutCreateInfo::builder().bindings(&[
                vk::DescriptorSetLayoutBinding::builder()
                    .binding(0)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .descriptor_count(1)
                    .descriptor_count(n_buffers)
                    .stage_flags(vk::ShaderStageFlags::COMPUTE)
                    .build(),
            ]),
            None,
        )?;

        let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(1)
            .build()];
        let descriptor_pool = self.device.create_descriptor_pool(
            &vk::DescriptorPoolCreateInfo::builder()
                .pool_sizes(&descriptor_pool_sizes)
                .max_sets(1),
            None,
        )?;
        let descriptor_set_layouts = [descriptor_set_layout];
        let descriptor_sets = self
            .device
            .allocate_descriptor_sets(
                &vk::DescriptorSetAllocateInfo::builder()
                    .descriptor_pool(descriptor_pool)
                    .set_layouts(&descriptor_set_layouts),
            )
            .unwrap();
        self.device.update_descriptor_sets(
            &[vk::WriteDescriptorSet::builder()
                .dst_set(descriptor_sets[0])
                .dst_binding(0)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .buffer_info(&[vk::DescriptorBufferInfo::builder()
                    .buffer(buffer.buffer)
                    .offset(0)
                    .range(1024)
                    .build()])
                .build()],
            &[],
        );

        // Create compute pipeline.
        let code = include_bytes!("../comp.spv");
        let code_u32 = convert_u32_vec(code);
        let compute_shader_module = self
            .device
        let compute_shader_module = device
            .create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(&code_u32), None)?;
        let entry_name = CString::new("main").unwrap();
        let pipeline_layout = self.device.create_pipeline_layout(
        let pipeline_layout = device.create_pipeline_layout(
            &vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts),
            None,
        )?;

        let pipeline = self.device.create_compute_pipelines(
        let pipeline = device.create_compute_pipelines(
            vk::PipelineCache::null(),
            &[vk::ComputePipelineCreateInfo::builder()
                .stage(


@@ 244,60 195,100 @@ impl Base {
        Ok(Pipeline {
            pipeline,
            pipeline_layout,
            descriptor_set: descriptor_sets[0],
            descriptor_set_layout,
        })
    }

    pub unsafe fn run_compute_pipeline(&self, pipeline: &Pipeline) -> Result<(), Error> {
        // Create command buffer.
        let command_pool = self.device.create_command_pool(
            &vk::CommandPoolCreateInfo::builder()
                .flags(vk::CommandPoolCreateFlags::empty())
                .queue_family_index(self.qfi),
    pub unsafe fn create_descriptor_set(
        &self,
        pipeline: &Pipeline,
        bufs: &[&Buffer],
    ) -> Result<DescriptorSet, Error> {
        let device = &self.device.device;
        let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(bufs.len() as u32)
            .build()];
        let descriptor_pool = device.create_descriptor_pool(
            &vk::DescriptorPoolCreateInfo::builder()
                .pool_sizes(&descriptor_pool_sizes)
                .max_sets(1),
            None,
        )?;
        let command_buffer = self.device.allocate_command_buffers(
            &vk::CommandBufferAllocateInfo::builder()
                .command_pool(command_pool)
                .level(vk::CommandBufferLevel::PRIMARY)
                .command_buffer_count(1),
        )?[0];

        // Record commands into command buffer.
        self.device.begin_command_buffer(
            command_buffer,
            &vk::CommandBufferBeginInfo::builder()
                .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
        )?;
        self.device.cmd_bind_pipeline(
            command_buffer,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline,
        );
        self.device.cmd_bind_descriptor_sets(
            command_buffer,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline_layout,
            0,
            &[pipeline.descriptor_set],
        let descriptor_set_layouts = [pipeline.descriptor_set_layout];
        let descriptor_sets = device
            .allocate_descriptor_sets(
                &vk::DescriptorSetAllocateInfo::builder()
                    .descriptor_pool(descriptor_pool)
                    .set_layouts(&descriptor_set_layouts),
            )
            .unwrap();
        let buf_infos = bufs
            .iter()
            .map(|buf| {
                vk::DescriptorBufferInfo::builder()
                    .buffer(buf.buffer)
                    .offset(0)
                    .range(vk::WHOLE_SIZE)
                    .build()
            })
            .collect::<Vec<_>>();
        device.update_descriptor_sets(
            &[vk::WriteDescriptorSet::builder()
                .dst_set(descriptor_sets[0])
                .dst_binding(0)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .buffer_info(&buf_infos)
                .build()],
            &[],
        );
        self.device.cmd_dispatch(command_buffer, 256, 1, 1);
        self.device.end_command_buffer(command_buffer)?;
        Ok(DescriptorSet {
            descriptor_set: descriptor_sets[0],
        })
    }

    pub fn create_cmd_buf(&self) -> Result<CmdBuf, Error> {
        unsafe {
            let device = &self.device.device;
            let command_pool = device.create_command_pool(
                &vk::CommandPoolCreateInfo::builder()
                    .flags(vk::CommandPoolCreateFlags::empty())
                    .queue_family_index(self.qfi),
                None,
            )?;
            let cmd_buf = device.allocate_command_buffers(
                &vk::CommandBufferAllocateInfo::builder()
                    .command_pool(command_pool)
                    .level(vk::CommandBufferLevel::PRIMARY)
                    .command_buffer_count(1),
            )?[0];
            Ok(CmdBuf {
                cmd_buf,
                device: self.device.clone(),
            })
        }
    }

    /// Run the command buffer.
    ///
    /// This version simply blocks until it's complete.
    pub unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> {
        let device = &self.device.device;

        // Run the command buffer.
        let fence = self.device.create_fence(
        let fence = device.create_fence(
            &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
            None,
        )?;
        self.device.queue_submit(
        device.queue_submit(
            self.queue,
            &[vk::SubmitInfo::builder()
                .command_buffers(&[command_buffer])
                .command_buffers(&[cmd_buf.cmd_buf])
                .build()],
            fence,
        )?;
        self.device.wait_for_fences(&[fence], true, 1_000_000)?;
        device.wait_for_fences(&[fence], true, 1_000_000)?;
        device.destroy_fence(fence, None);
        Ok(())
    }



@@ 305,9 296,10 @@ impl Base {
        &self,
        buffer: &Buffer,
        result: &mut Vec<T>,
        size: usize,
    ) -> Result<(), Error> {
        let buf = self.device.map_memory(
        let device = &self.device.device;
        let size = buffer.size as usize;
        let buf = device.map_memory(
            buffer.buffer_memory,
            0,
            size as u64,


@@ 318,7 310,7 @@ impl Base {
        }
        std::ptr::copy_nonoverlapping(buf as *const T, result.as_mut_ptr(), size);
        result.set_len(size);
        self.device.unmap_memory(buffer.buffer_memory);
        device.unmap_memory(buffer.buffer_memory);
        Ok(())
    }



@@ 327,18 319,72 @@ impl Base {
        buffer: &Buffer,
        contents: &[T],
    ) -> Result<(), Error> {
        let buf = self.device.map_memory(
        let device = &self.device.device;
        let buf = device.map_memory(
            buffer.buffer_memory,
            0,
            std::mem::size_of_val(contents) as u64,
            vk::MemoryMapFlags::empty(),
        )?;
        std::ptr::copy_nonoverlapping(contents.as_ptr(), buf as *mut T, contents.len());
        self.device.unmap_memory(buffer.buffer_memory);
        device.unmap_memory(buffer.buffer_memory);
        Ok(())
    }
}

impl CmdBuf {
    /// Start recording into the command buffer.
    ///
    /// Recording begins with the `ONE_TIME_SUBMIT` usage flag.
    ///
    /// # Panics
    ///
    /// Panics if `begin_command_buffer` returns an error.
    ///
    /// # Safety
    ///
    /// Calls into Vulkan directly; the caller is responsible for the command
    /// buffer being in a state where recording may begin.
    pub unsafe fn begin(&mut self) {
        let begin_info = vk::CommandBufferBeginInfo::builder()
            .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);
        let raw = &self.device.device;
        raw.begin_command_buffer(self.cmd_buf, &begin_info).unwrap();
    }

    /// Finish recording the command buffer.
    ///
    /// # Panics
    ///
    /// Panics if `end_command_buffer` returns an error.
    ///
    /// # Safety
    ///
    /// Calls into Vulkan directly; the caller is responsible for having
    /// started recording first.
    pub unsafe fn finish(&mut self) {
        let raw = &self.device.device;
        raw.end_command_buffer(self.cmd_buf).unwrap();
    }

    /// Record a compute dispatch of `pipeline` using `descriptor_set`.
    ///
    /// The pipeline is bound at the compute bind point, the descriptor set
    /// is bound as set 0 with no dynamic offsets, and a fixed 256 x 1 x 1
    /// grid of workgroups is dispatched.
    ///
    /// # Safety
    ///
    /// Calls into Vulkan directly; the caller is responsible for the command
    /// buffer being in the recording state and for the handles being valid.
    pub unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) {
        let raw = &self.device.device;
        let sets = [descriptor_set.descriptor_set];
        raw.cmd_bind_pipeline(
            self.cmd_buf,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline,
        );
        raw.cmd_bind_descriptor_sets(
            self.cmd_buf,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline_layout,
            0,
            &sets,
            &[],
        );
        raw.cmd_dispatch(self.cmd_buf, 256, 1, 1);
    }

    /// Record a global memory barrier covering all commands.
    ///
    /// Both the source and destination stage masks are `ALL_COMMANDS`; the
    /// barrier makes `MEMORY_WRITE` accesses available to `MEMORY_READ`
    /// accesses. No buffer or image barriers are included.
    ///
    /// # Safety
    ///
    /// Calls into Vulkan directly; the caller is responsible for the command
    /// buffer being in the recording state.
    #[allow(unused)]
    pub unsafe fn memory_barrier(&mut self) {
        let barrier = vk::MemoryBarrier::builder()
            .src_access_mask(vk::AccessFlags::MEMORY_WRITE)
            .dst_access_mask(vk::AccessFlags::MEMORY_READ)
            .build();
        let raw = &self.device.device;
        raw.cmd_pipeline_barrier(
            self.cmd_buf,
            vk::PipelineStageFlags::ALL_COMMANDS,
            vk::PipelineStageFlags::ALL_COMMANDS,
            vk::DependencyFlags::empty(),
            &[barrier],
            &[],
            &[],
        );
    }
}

unsafe fn choose_compute_device(
    instance: &Instance,
    devices: &[vk::PhysicalDevice],


@@ 372,9 418,11 @@ fn find_memory_type(
}

/// Reinterpret a byte slice as a vector of little-endian `u32` words.
///
/// Each group of four bytes becomes one word. If `src.len()` is not a
/// multiple of four, the final partial group is zero-padded in its
/// high-order bytes rather than panicking. An empty slice yields an empty
/// vector.
fn convert_u32_vec(src: &[u8]) -> Vec<u32> {
    src.chunks(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            // `chunks` may yield a short final chunk; copy only the bytes
            // that exist so `copy_from_slice` never sees a length mismatch.
            word[..chunk.len()].copy_from_slice(chunk);
            u32::from_le_bytes(word)
        })
        .collect()
}

A src/error.rs => src/error.rs +41 -0
@@ 0,0 1,41 @@
use ash::vk;

/// Error type for the crate's fallible Vulkan operations.
///
/// Each payload-carrying variant wraps an error value from `ash`; the `From`
/// impls below let `?` convert those values automatically.
#[derive(Debug)]
pub enum Error {
    /// Wraps an `ash::LoadingError`.
    // NOTE(review): presumably produced while loading the Vulkan library — confirm at the call site.
    LoadingError(ash::LoadingError),
    /// Wraps an `ash::InstanceError`.
    // NOTE(review): presumably produced by instance creation — confirm at the call site.
    InstanceError(ash::InstanceError),
    /// Wraps a raw Vulkan result code returned by a failed call.
    VkResult(ash::vk::Result),
    /// No payload.
    // NOTE(review): presumably returned when device selection finds nothing usable — confirm.
    NoSuitableDevice,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "vk error")
    }
}

// Marker impl: all `std::error::Error` methods keep their default
// implementations, so `Error` can be used as a boxed/dyn error.
impl std::error::Error for Error {}

impl From<ash::LoadingError> for Error {
    /// Wrap a loading error in [`Error::LoadingError`], enabling `?` on
    /// calls that return `ash::LoadingError`.
    fn from(err: ash::LoadingError) -> Self {
        Self::LoadingError(err)
    }
}

impl From<ash::InstanceError> for Error {
    /// Wrap an instance error in [`Error::InstanceError`], enabling `?` on
    /// calls that return `ash::InstanceError`.
    fn from(err: ash::InstanceError) -> Self {
        Self::InstanceError(err)
    }
}

impl From<vk::Result> for Error {
    /// Wrap a raw Vulkan result code in [`Error::VkResult`], enabling `?` on
    /// calls that return `vk::Result` as their error.
    fn from(code: vk::Result) -> Self {
        Self::VkResult(code)
    }
}

impl From<(Vec<vk::Pipeline>, vk::Result)> for Error {
    fn from(error: (Vec<vk::Pipeline>, vk::Result)) -> Error {
        error.1.into()
    }
}

M src/main.rs => src/main.rs +13 -4
@@ 2,6 2,7 @@
extern crate ash;

mod base;
mod error;

use ash::vk;



@@ 9,13 10,21 @@ fn main() {
    let base = base::Base::new().unwrap();
    let mem_flags = vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
    let src = (0..256).map(|x| x + 1).collect::<Vec<u32>>();
    let buffer = base.create_buffer(std::mem::size_of_val(&src) as u64, mem_flags).unwrap();
    let buffer = base
        .create_buffer(std::mem::size_of_val(&src[..]) as u64, mem_flags)
        .unwrap();
    unsafe {
        base.write_buffer(&buffer, &src).unwrap();
        let pipeline = base.create_simple_compute_pipeline(&buffer).unwrap();
        base.run_compute_pipeline(&pipeline).unwrap();
        let code = include_bytes!("../comp.spv");
        let pipeline = base.create_simple_compute_pipeline(code, 1).unwrap();
        let descriptor_set = base.create_descriptor_set(&pipeline, &[&buffer]).unwrap();
        let mut cmd_buf = base.create_cmd_buf().unwrap();
        cmd_buf.begin();
        cmd_buf.dispatch(&pipeline, &descriptor_set);
        cmd_buf.finish();
        base.run_cmd_buf(&cmd_buf).unwrap();
        let mut dst: Vec<u32> = Default::default();
        base.read_buffer(&buffer, &mut dst, src.len()).unwrap();
        base.read_buffer(&buffer, &mut dst).unwrap();
        for (i, val) in dst.iter().enumerate().take(16) {
            println!("{}: {}", i, val);
        }