~raph/vk-toy

7dc5fe00edfbcfc61cf4cdcf4f7ae3d54f8687e2 — Raph Levien 3 years ago bca6e6b
A bit of refactoring

Start teasing apart the big blob of state, make methods for the various
stages.
2 files changed, 395 insertions(+), 267 deletions(-)

A src/base.rs
M src/main.rs
A src/base.rs => src/base.rs +380 -0
@@ -0,0 +1,380 @@
//! Common base for running compute workloads.

use std::ffi::CString;

use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
use ash::{vk, Device, Entry, Instance};

#[derive(Debug)]
pub enum Error {
    LoadingError(ash::LoadingError),
    InstanceError(ash::InstanceError),
    VkResult(ash::vk::Result),
    NoSuitableDevice,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "vk error")
    }
}

impl std::error::Error for Error {}

impl From<ash::LoadingError> for Error {
    fn from(error: ash::LoadingError) -> Error {
        Error::LoadingError(error)
    }
}

impl From<ash::InstanceError> for Error {
    fn from(error: ash::InstanceError) -> Error {
        Error::InstanceError(error)
    }
}

impl From<vk::Result> for Error {
    fn from(error: vk::Result) -> Error {
        Error::VkResult(error)
    }
}

// ash's create_*_pipelines calls report errors as a (pipelines, result) pair.
impl From<(Vec<vk::Pipeline>, vk::Result)> for Error {
    fn from(error: (Vec<vk::Pipeline>, vk::Result)) -> Error {
        error.1.into()
    }
}

/// A base for allocating resources and dispatching work.
///
/// This is quite similar to "device" in most GPU APIs, but I didn't want to overload
/// that term further.
pub struct Base {
    /// Retain the dynamic lib.
    #[allow(unused)]
    entry: Entry,

    #[allow(unused)]
    instance: Instance,

    device: Device,
    device_mem_props: vk::PhysicalDeviceMemoryProperties,
    queue: vk::Queue,
    /// Queue family index of the compute queue.
    qfi: u32,
}

/// A handle to a buffer.
///
/// There is no lifetime tracking at this level; the caller is responsible
/// for destroying the buffer at the appropriate time.
pub struct Buffer {
    buffer: vk::Buffer,
    buffer_memory: vk::DeviceMemory,
}
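
// Cleanup sketch (not part of this commit): when a buffer is no longer needed, the
// owning code would eventually release it with the usual ash calls, roughly:
//
//     device.destroy_buffer(buffer.buffer, None);
//     device.free_memory(buffer.buffer_memory, None);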

pub struct Pipeline {
    pipeline: vk::Pipeline,
    pipeline_layout: vk::PipelineLayout,
    descriptor_set: vk::DescriptorSet,
}
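
// How the pieces are intended to fit together; this sketch mirrors the new main.rs
// below, with error handling elided and placeholder names for the data:
//
//     let base = Base::new()?;
//     let buffer = base.create_buffer(size_in_bytes, mem_flags)?;
//     unsafe {
//         base.write_buffer(&buffer, &input)?;
//         let pipeline = base.create_simple_compute_pipeline(&buffer)?;
//         base.run_compute_pipeline(&pipeline)?;
//         base.read_buffer(&buffer, &mut output, n_elements)?;
//     }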

impl Base {
    /// Create a new instance.
    ///
    /// There's more to be done to make this suitable for integration with other
    /// systems, but for now the goal is to make things simple.
    pub fn new() -> Result<Base, Error> {
        unsafe {
            let app_name = CString::new("VkToy").unwrap();
            let entry = Entry::new()?;
            let instance = entry.create_instance(
                &vk::InstanceCreateInfo::builder().application_info(
                    &vk::ApplicationInfo::builder()
                        .application_name(&app_name)
                        .application_version(0)
                        .engine_name(&app_name)
                        .api_version(vk_make_version!(1, 0, 0)),
                ),
                None,
            )?;

            let devices = instance.enumerate_physical_devices()?;
            let (pdevice, qfi) =
                choose_compute_device(&instance, &devices).ok_or(Error::NoSuitableDevice)?;

            let device = instance.create_device(
                pdevice,
                &vk::DeviceCreateInfo::builder().queue_create_infos(&[
                    vk::DeviceQueueCreateInfo::builder()
                        .queue_family_index(qfi)
                        .queue_priorities(&[1.0])
                        .build(),
                ]),
                None,
            )?;

            let device_mem_props = instance.get_physical_device_memory_properties(pdevice);

            let queue_index = 0;
            let queue = device.get_device_queue(qfi, queue_index);

            Ok(Base {
                entry,
                instance,
                device,
                device_mem_props,
                qfi,
                queue,
            })
        }
    }

    /// Create a storage buffer of `size` bytes, backed by memory chosen with `mem_flags`.
    ///
    /// Only STORAGE_BUFFER usage is requested; the toy workload in main.rs passes
    /// HOST_VISIBLE | HOST_COHERENT so the memory can be mapped directly.
    pub fn create_buffer(
        &self,
        size: u64,
        mem_flags: vk::MemoryPropertyFlags,
    ) -> Result<Buffer, Error> {
        unsafe {
            let buffer = self.device.create_buffer(
                &vk::BufferCreateInfo::builder()
                    .size(size)
                    .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
                    .sharing_mode(vk::SharingMode::EXCLUSIVE),
                None,
            )?;
            let mem_requirements = self.device.get_buffer_memory_requirements(buffer);
            let mem_type = find_memory_type(
                mem_requirements.memory_type_bits,
                mem_flags,
                &self.device_mem_props,
            )
            .unwrap(); // TODO: proper error
            let buffer_memory = self.device.allocate_memory(
                &vk::MemoryAllocateInfo::builder()
                    .allocation_size(mem_requirements.size)
                    .memory_type_index(mem_type),
                None,
            )?;
            self.device.bind_buffer_memory(buffer, buffer_memory, 0)?;
            Ok(Buffer {
                buffer,
                buffer_memory,
            })
        }
    }

    /// This creates a pipeline that runs over the buffer.
    ///
    /// This will be split up into finer-grained steps.
    pub unsafe fn create_simple_compute_pipeline(
        &self,
        buffer: &Buffer,
    ) -> Result<Pipeline, Error> {
        let descriptor_set_layout = self.device.create_descriptor_set_layout(
            &vk::DescriptorSetLayoutCreateInfo::builder().bindings(&[
                vk::DescriptorSetLayoutBinding::builder()
                    .binding(0)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .descriptor_count(1)
                    .stage_flags(vk::ShaderStageFlags::COMPUTE)
                    .build(),
            ]),
            None,
        )?;

        let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(1)
            .build()];
        let descriptor_pool = self.device.create_descriptor_pool(
            &vk::DescriptorPoolCreateInfo::builder()
                .pool_sizes(&descriptor_pool_sizes)
                .max_sets(1),
            None,
        )?;
        let descriptor_set_layouts = [descriptor_set_layout];
        let descriptor_sets = self.device.allocate_descriptor_sets(
            &vk::DescriptorSetAllocateInfo::builder()
                .descriptor_pool(descriptor_pool)
                .set_layouts(&descriptor_set_layouts),
        )?;
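        // The descriptor range below is hard-coded to 1024 bytes to match the buffer
        // allocated in main.rs; vk::WHOLE_SIZE would be the more general choice.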
        self.device.update_descriptor_sets(
            &[vk::WriteDescriptorSet::builder()
                .dst_set(descriptor_sets[0])
                .dst_binding(0)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .buffer_info(&[vk::DescriptorBufferInfo::builder()
                    .buffer(buffer.buffer)
                    .offset(0)
                    .range(1024)
                    .build()])
                .build()],
            &[],
        );

        // Create compute pipeline.
        let code = include_bytes!("../comp.spv");
        let code_u32 = convert_u32_vec(code);
        let compute_shader_module = self
            .device
            .create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(&code_u32), None)?;
        let entry_name = CString::new("main").unwrap();
        let pipeline_layout = self.device.create_pipeline_layout(
            &vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts),
            None,
        )?;

        let pipeline = self.device.create_compute_pipelines(
            vk::PipelineCache::null(),
            &[vk::ComputePipelineCreateInfo::builder()
                .stage(
                    vk::PipelineShaderStageCreateInfo::builder()
                        .stage(vk::ShaderStageFlags::COMPUTE)
                        .module(compute_shader_module)
                        .name(&entry_name)
                        .build(),
                )
                .layout(pipeline_layout)
                .build()],
            None,
        )?[0];
        Ok(Pipeline {
            pipeline,
            pipeline_layout,
            descriptor_set: descriptor_sets[0],
        })
    }

    pub unsafe fn run_compute_pipeline(&self, pipeline: &Pipeline) -> Result<(), Error> {
        // Create command buffer.
        let command_pool = self.device.create_command_pool(
            &vk::CommandPoolCreateInfo::builder()
                .flags(vk::CommandPoolCreateFlags::empty())
                .queue_family_index(self.qfi),
            None,
        )?;
        let command_buffer = self.device.allocate_command_buffers(
            &vk::CommandBufferAllocateInfo::builder()
                .command_pool(command_pool)
                .level(vk::CommandBufferLevel::PRIMARY)
                .command_buffer_count(1),
        )?[0];

        // Record commands into command buffer.
        self.device.begin_command_buffer(
            command_buffer,
            &vk::CommandBufferBeginInfo::builder()
                .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
        )?;
        self.device.cmd_bind_pipeline(
            command_buffer,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline,
        );
        self.device.cmd_bind_descriptor_sets(
            command_buffer,
            vk::PipelineBindPoint::COMPUTE,
            pipeline.pipeline_layout,
            0,
            &[pipeline.descriptor_set],
            &[],
        );
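        // Dispatch 256 workgroups along x; the workgroup size itself is declared in the
        // shader (comp.spv).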
        self.device.cmd_dispatch(command_buffer, 256, 1, 1);
        self.device.end_command_buffer(command_buffer)?;

        // Run the command buffer.
        let fence = self.device.create_fence(
            &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
            None,
        )?;
        self.device.queue_submit(
            self.queue,
            &[vk::SubmitInfo::builder()
                .command_buffers(&[command_buffer])
                .build()],
            fence,
        )?;
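        // The wait_for_fences timeout is in nanoseconds, so this waits at most ~1 ms.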
        self.device.wait_for_fences(&[fence], true, 1_000_000)?;
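        // Note: the command pool and fence created above are never destroyed, so they
        // currently leak; cleanup is left out of this toy for now.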
        Ok(())
    }

    pub unsafe fn read_buffer<T: Sized>(
        &self,
        buffer: &Buffer,
        result: &mut Vec<T>,
        size: usize,
    ) -> Result<(), Error> {
        // `size` is a count of `T` elements, so map the corresponding number of bytes.
        let buf = self.device.map_memory(
            buffer.buffer_memory,
            0,
            (size * std::mem::size_of::<T>()) as u64,
            vk::MemoryMapFlags::empty(),
        )?;
        if size > result.len() {
            result.reserve(size - result.len());
        }
        std::ptr::copy_nonoverlapping(buf as *const T, result.as_mut_ptr(), size);
        result.set_len(size);
        self.device.unmap_memory(buffer.buffer_memory);
        Ok(())
    }

    pub unsafe fn write_buffer<T: Sized>(
        &self,
        buffer: &Buffer,
        contents: &[T],
    ) -> Result<(), Error> {
        let buf = self.device.map_memory(
            buffer.buffer_memory,
            0,
            std::mem::size_of_val(contents) as u64,
            vk::MemoryMapFlags::empty(),
        )?;
        std::ptr::copy_nonoverlapping(contents.as_ptr(), buf as *mut T, contents.len());
        self.device.unmap_memory(buffer.buffer_memory);
        Ok(())
    }
}

unsafe fn choose_compute_device(
    instance: &Instance,
    devices: &[vk::PhysicalDevice],
) -> Option<(vk::PhysicalDevice, u32)> {
    for pdevice in devices {
        let props = instance.get_physical_device_queue_family_properties(*pdevice);
        for (ix, info) in props.iter().enumerate() {
            if info.queue_flags.contains(vk::QueueFlags::COMPUTE) {
                return Some((*pdevice, ix as u32));
            }
        }
    }
    None
}

fn find_memory_type(
    memory_type_bits: u32,
    property_flags: vk::MemoryPropertyFlags,
    props: &vk::PhysicalDeviceMemoryProperties,
) -> Option<u32> {
    for i in 0..props.memory_type_count {
        if (memory_type_bits & (1 << i)) != 0
            && props.memory_types[i as usize]
                .property_flags
                .contains(property_flags)
        {
            return Some(i);
        }
    }
    None
}

/// Reinterpret the SPIR-V byte stream as little-endian u32 words. A valid SPIR-V
/// binary is a whole number of words, so the 4-byte chunks never come up short.
fn convert_u32_vec(src: &[u8]) -> Vec<u32> {
    src.chunks(4).map(|chunk| {
        let mut buf = [0; 4];
        buf.copy_from_slice(chunk);
        u32::from_le_bytes(buf)
    }).collect()
}

M src/main.rs => src/main.rs +15 -267
@@ -1,275 +1,23 @@
 #[macro_use]
 extern crate ash;
 
-use std::ffi::CString;
-use std::ops::Deref;
+mod base;
 
-use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
-use ash::{vk, Device, Entry, Instance};
+use ash::vk;
 
-struct ToyBase {
-    pdevice: vk::PhysicalDevice,
-
-    device: Device,
-
-    // It would be better to separate the base from the actual logic, but
-    // we'll follow the usual pattern of examples and have everything in one
-    // big huge glob.
-    queue: vk::Queue,
-    //pipeline: vk::Pipeline,
-}
-
-unsafe fn choose_compute_device(
-    instance: &Instance,
-    devices: &[vk::PhysicalDevice],
-) -> Option<(vk::PhysicalDevice, usize)> {
-    for pdevice in devices {
-        let props = instance.get_physical_device_queue_family_properties(*pdevice);
-        for (ix, info) in props.iter().enumerate() {
-            if info.queue_flags.contains(vk::QueueFlags::COMPUTE) {
-                return Some((*pdevice, ix));
-            }
-        }
-    }
-    None
-}
-
-fn find_memory_type(
-    memory_type_bits: u32,
-    property_flags: vk::MemoryPropertyFlags,
-    props: &vk::PhysicalDeviceMemoryProperties,
-) -> Option<u32> {
-    for i in 0..props.memory_type_count {
-        if (memory_type_bits & (1 << i)) != 0
-            && props.memory_types[i as usize]
-                .property_flags
-                .contains(property_flags)
-        {
-            return Some(i);
-        }
-    }
-    None
-}
-
-impl ToyBase {
-    fn new() -> ToyBase {
-        unsafe {
-            let app_name = CString::new("VkToy").unwrap();
-            let appinfo = vk::ApplicationInfo::builder()
-                .application_name(&app_name)
-                .application_version(0)
-                .engine_name(&app_name)
-                .api_version(vk_make_version!(1, 0, 0));
-            let create_info = vk::InstanceCreateInfo::builder().application_info(&appinfo);
-            let entry = Entry::new().unwrap();
-            let instance = entry.create_instance(&create_info, None).unwrap();
-            /*
-            for pdevice in instance.enumerate_physical_devices().unwrap() {
-                let props = instance.get_physical_device_properties(pdevice);
-                println!("props: {:#?}", props);
-            }
-            */
-            let devices = instance.enumerate_physical_devices().unwrap();
-            let (pdevice, qfi) = choose_compute_device(&instance, &devices).unwrap();
-
-            let queue_info = [vk::DeviceQueueCreateInfo::builder()
-                .queue_family_index(qfi as u32)
-                .queue_priorities(&[1.0])
-                .build()];
-            let device_create_info =
-                vk::DeviceCreateInfo::builder().queue_create_infos(&queue_info);
-            let device = instance
-                .create_device(pdevice, &device_create_info, None)
-                .unwrap();
-            let queue_index = 0;
-            let queue = device.get_device_queue(qfi as u32, queue_index);
-            let device_mem_props = instance.get_physical_device_memory_properties(pdevice);
-            for i in 0..device_mem_props.memory_type_count {
-                println!(
-                    "memory type {}: {:?}",
-                    i, device_mem_props.memory_types[i as usize]
-                );
-            }
-            for i in 0..device_mem_props.memory_heap_count {
-                println!(
-                    "memory heap {}: {:?}",
-                    i, device_mem_props.memory_heaps[i as usize]
-                );
-            }
-            let buffer_create_info = vk::BufferCreateInfo::builder()
-                .size(1024)
-                .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
-                .sharing_mode(vk::SharingMode::EXCLUSIVE);
-            let buffer = device.create_buffer(&buffer_create_info, None).unwrap();
-            let mem_requirements = device.get_buffer_memory_requirements(buffer);
-            println!(
-                "memory requirements: {:?} (0x{:x})",
-                mem_requirements, mem_requirements.memory_type_bits
-            );
-            let mem_flags =
-                vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
-            let mem_type = find_memory_type(
-                mem_requirements.memory_type_bits,
-                mem_flags,
-                &device_mem_props,
-            )
-            .unwrap();
-            println!("chose mem type {}", mem_type);
-            let alloc_info = vk::MemoryAllocateInfo::builder()
-                .allocation_size(mem_requirements.size)
-                .memory_type_index(mem_type);
-            let buffer_memory = device.allocate_memory(&alloc_info, None).unwrap();
-            device.bind_buffer_memory(buffer, buffer_memory, 0).unwrap();
-            println!("allocated {:?}", buffer_memory);
-
-            // Create descriptor set.
-            let descriptor_set_layout_bindings = [vk::DescriptorSetLayoutBinding::builder()
-                .binding(0)
-                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
-                .descriptor_count(1)
-                .stage_flags(vk::ShaderStageFlags::COMPUTE)
-                .build()];
-            let descriptor_set_layout_create_info = vk::DescriptorSetLayoutCreateInfo::builder()
-                .bindings(&descriptor_set_layout_bindings);
-            let descriptor_set_layout = device
-                .create_descriptor_set_layout(&descriptor_set_layout_create_info, None)
-                .unwrap();
-
-            let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
-                .descriptor_count(1)
-                .build()];
-            let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::builder()
-                .pool_sizes(&descriptor_pool_sizes)
-                .max_sets(1);
-            let descriptor_pool = device
-                .create_descriptor_pool(&descriptor_pool_create_info, None)
-                .unwrap();
-            let descriptor_set_layouts = [descriptor_set_layout];
-            let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::builder()
-                .descriptor_pool(descriptor_pool)
-                .set_layouts(&descriptor_set_layouts);
-            let descriptor_sets = device
-                .allocate_descriptor_sets(&descriptor_set_allocate_info)
-                .unwrap();
-            let descriptor_buffer_info = [vk::DescriptorBufferInfo::builder()
-                .buffer(buffer)
-                .offset(0)
-                .range(1024)
-                .build()];
-            let write_descriptor_set = vk::WriteDescriptorSet::builder()
-                .dst_set(descriptor_sets[0])
-                .dst_binding(0)
-                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
-                .buffer_info(&descriptor_buffer_info)
-                .build();
-            device.update_descriptor_sets(&[write_descriptor_set], &[]);
-
-            // Create compute pipeline.
-            let code = include_bytes!("../comp.spv");
-            let code_u32 = std::slice::from_raw_parts(code.as_ptr() as *const u32, code.len() / 4);
-            let shader_module_create_info = vk::ShaderModuleCreateInfo::builder().code(code_u32);
-            let compute_shader_module = device
-                .create_shader_module(&shader_module_create_info, None)
-                .unwrap();
-            let entry_name = CString::new("main").unwrap();
-            let shader_stage_create_info = vk::PipelineShaderStageCreateInfo::builder()
-                .stage(vk::ShaderStageFlags::COMPUTE)
-                .module(compute_shader_module)
-                .name(&entry_name);
-            let pipeline_layout_create_info =
-                vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts);
-            let pipeline_layout = device
-                .create_pipeline_layout(&pipeline_layout_create_info, None)
-                .unwrap();
-
-            let pipeline_create_infos = [vk::ComputePipelineCreateInfo::builder()
-                .stage(shader_stage_create_info.build())
-                .layout(pipeline_layout)
-                .build()];
-            let pipeline = device
-                .create_compute_pipelines(vk::PipelineCache::null(), &pipeline_create_infos, None)
-                .unwrap()[0];
-
-            // Create command buffer.
-            let command_pool = device
-                .create_command_pool(
-                    &vk::CommandPoolCreateInfo::builder()
-                        .flags(vk::CommandPoolCreateFlags::empty())
-                        .queue_family_index(qfi as u32),
-                    None,
-                )
-                .unwrap();
-            let command_buffer = device
-                .allocate_command_buffers(
-                    &vk::CommandBufferAllocateInfo::builder()
-                        .command_pool(command_pool)
-                        .level(vk::CommandBufferLevel::PRIMARY)
-                        .command_buffer_count(1),
-                )
-                .unwrap()[0];
-
-            // Record commands into command buffer.
-            device
-                .begin_command_buffer(
-                    command_buffer,
-                    &vk::CommandBufferBeginInfo::builder()
-                        .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
-                )
-                .unwrap();
-            device.cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, pipeline);
-            device.cmd_bind_descriptor_sets(
-                command_buffer,
-                vk::PipelineBindPoint::COMPUTE,
-                pipeline_layout,
-                0,
-                &descriptor_sets,
-                &[],
-            );
-            device.cmd_dispatch(command_buffer, 256, 1, 1);
-            device.end_command_buffer(command_buffer).unwrap();
-
-            // Initialize the buffer.
-            let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
-            for i in 0..256 {
-                buf.add(i).write((i + 1) as u32);
-            }
-            device.unmap_memory(buffer_memory);
-
-            // Run the command buffer.
-            let fence = device
-                .create_fence(
-                    &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
-                    None,
-                )
-                .unwrap();
-            device
-                .queue_submit(
-                    queue,
-                    &[vk::SubmitInfo::builder()
-                        .command_buffers(&[command_buffer])
-                        .build()],
-                    fence,
-                )
-                .unwrap();
-            device.wait_for_fences(&[fence], true, 1_000_000).unwrap();
-
-            // Reap the results
-            let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
-            for i in 0..16 {
-                println!("{}: {}", i, buf.add(i).read());
-            }
-            device.unmap_memory(buffer_memory);
-
-            ToyBase {
-                pdevice,
-                device,
-                queue,
-            }
+fn main() {
+    let base = base::Base::new().unwrap();
+    let mem_flags = vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
+    let src = (0..256).map(|x| x + 1).collect::<Vec<u32>>();
+    let buffer = base.create_buffer(std::mem::size_of_val(src.as_slice()) as u64, mem_flags).unwrap();
+    unsafe {
+        base.write_buffer(&buffer, &src).unwrap();
+        let pipeline = base.create_simple_compute_pipeline(&buffer).unwrap();
+        base.run_compute_pipeline(&pipeline).unwrap();
+        let mut dst: Vec<u32> = Default::default();
+        base.read_buffer(&buffer, &mut dst, src.len()).unwrap();
+        for (i, val) in dst.iter().enumerate().take(16) {
+            println!("{}: {}", i, val);
         }
     }
 }
-
-fn main() {
-    let base = ToyBase::new();
-    println!("Hello, world!");
-}