~raph/vk-toy

bca6e6b2e4de078935e383373a730d97115a380b — Raph Levien 1 year, 11 months ago
Minimal example to run compute shader

The .spv code is collatz from gfx-hal examples. This code is not well
organized and is missing metadata, but it can be a start.
5 files changed, 328 insertions(+), 0 deletions(-)

A .gitignore
A Cargo.lock
A Cargo.toml
A comp.spv
A src/main.rs
A  => .gitignore +2 -0
@@ 1,2 @@
/target
**/*.rs.bk

A  => Cargo.lock +41 -0
@@ 1,41 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "ash"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "shared_library 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "libc"
version = "0.2.66"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "shared_library"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "vk-toy"
version = "0.1.0"
dependencies = [
 "ash 0.29.0 (registry+https://github.com/rust-lang/crates.io-index)",
]

[metadata]
"checksum ash 0.29.0 (registry+https://github.com/rust-lang/crates.io-index)" = "003d1fb2eb12eb06d4a03dbe02eea67a9fac910fa97932ab9e3a75b96a1ea5e5"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)" = "d515b1f41455adea1313a4a2ac8a8a477634fbae63cc6100e3aebb207ce61558"
"checksum shared_library 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"

A  => Cargo.toml +10 -0
@@ 1,10 @@
[package]
name = "vk-toy"
version = "0.1.0"
authors = ["Raph Levien <raph.levien@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ash = "0.29.0"

A  => comp.spv +0 -0
A  => src/main.rs +275 -0
@@ 1,275 @@
#[macro_use]
extern crate ash;

use std::ffi::CString;
use std::ops::Deref;

use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
use ash::{vk, Device, Entry, Instance};

/// Vulkan objects that `ToyBase::new` hands back to the caller.
///
/// Everything else created while running the compute shader is local to
/// `ToyBase::new` and is not stored here.
struct ToyBase {
    /// Physical device selected by `choose_compute_device`.
    pdevice: vk::PhysicalDevice,

    /// Logical device created on `pdevice` with a single queue.
    device: Device,

    // It would be better to separate the base from the actual logic, but
    // we'll follow the usual pattern of examples and have everything in one
    // big huge glob.
    /// Queue 0 of the queue family that was requested when `device` was created.
    queue: vk::Queue,
    //pipeline: vk::Pipeline,
}

/// Picks the first physical device exposing a compute-capable queue family.
///
/// Devices are examined in slice order and, within a device, queue families
/// are examined in index order. The first family whose `queue_flags` contain
/// `vk::QueueFlags::COMPUTE` wins.
///
/// Returns the device together with the index of that queue family, or
/// `None` when no device offers a compute queue.
///
/// # Safety
///
/// Every handle in `devices` is passed to
/// `get_physical_device_queue_family_properties` on `instance`, so the
/// handles must be valid for that instance.
unsafe fn choose_compute_device(
    instance: &Instance,
    devices: &[vk::PhysicalDevice],
) -> Option<(vk::PhysicalDevice, usize)> {
    devices.iter().find_map(|&candidate| {
        instance
            .get_physical_device_queue_family_properties(candidate)
            .iter()
            .position(|family| family.queue_flags.contains(vk::QueueFlags::COMPUTE))
            .map(|family_index| (candidate, family_index))
    })
}

/// Finds the lowest-indexed memory type that is both permitted and suitable.
///
/// A memory type at index `i` qualifies when bit `i` of `memory_type_bits`
/// is set and its `property_flags` contain every flag in `property_flags`.
/// Only the first `props.memory_type_count` entries are considered.
///
/// Returns the index of the first qualifying type, or `None` if there is none.
fn find_memory_type(
    memory_type_bits: u32,
    property_flags: vk::MemoryPropertyFlags,
    props: &vk::PhysicalDeviceMemoryProperties,
) -> Option<u32> {
    (0..props.memory_type_count).find(|&index| {
        // Check the bit mask first so the table is only consulted for
        // types the caller is actually allowed to use.
        let permitted = (memory_type_bits & (1 << index)) != 0;
        permitted
            && props.memory_types[index as usize]
                .property_flags
                .contains(property_flags)
    })
}

impl ToyBase {
    fn new() -> ToyBase {
        unsafe {
            let app_name = CString::new("VkToy").unwrap();
            let appinfo = vk::ApplicationInfo::builder()
                .application_name(&app_name)
                .application_version(0)
                .engine_name(&app_name)
                .api_version(vk_make_version!(1, 0, 0));
            let create_info = vk::InstanceCreateInfo::builder().application_info(&appinfo);
            let entry = Entry::new().unwrap();
            let instance = entry.create_instance(&create_info, None).unwrap();
            /*
            for pdevice in instance.enumerate_physical_devices().unwrap() {
                let props = instance.get_physical_device_properties(pdevice);
                println!("props: {:#?}", props);
            }
            */
            let devices = instance.enumerate_physical_devices().unwrap();
            let (pdevice, qfi) = choose_compute_device(&instance, &devices).unwrap();

            let queue_info = [vk::DeviceQueueCreateInfo::builder()
                .queue_family_index(qfi as u32)
                .queue_priorities(&[1.0])
                .build()];
            let device_create_info =
                vk::DeviceCreateInfo::builder().queue_create_infos(&queue_info);
            let device = instance
                .create_device(pdevice, &device_create_info, None)
                .unwrap();
            let queue_index = 0;
            let queue = device.get_device_queue(qfi as u32, queue_index);
            let device_mem_props = instance.get_physical_device_memory_properties(pdevice);
            for i in 0..device_mem_props.memory_type_count {
                println!(
                    "memory type {}: {:?}",
                    i, device_mem_props.memory_types[i as usize]
                );
            }
            for i in 0..device_mem_props.memory_heap_count {
                println!(
                    "memory heap {}: {:?}",
                    i, device_mem_props.memory_heaps[i as usize]
                );
            }
            let buffer_create_info = vk::BufferCreateInfo::builder()
                .size(1024)
                .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
                .sharing_mode(vk::SharingMode::EXCLUSIVE);
            let buffer = device.create_buffer(&buffer_create_info, None).unwrap();
            let mem_requirements = device.get_buffer_memory_requirements(buffer);
            println!(
                "memory requirements: {:?} (0x{:x})",
                mem_requirements, mem_requirements.memory_type_bits
            );
            let mem_flags =
                vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
            let mem_type = find_memory_type(
                mem_requirements.memory_type_bits,
                mem_flags,
                &device_mem_props,
            )
            .unwrap();
            println!("chose mem type {}", mem_type);
            let alloc_info = vk::MemoryAllocateInfo::builder()
                .allocation_size(mem_requirements.size)
                .memory_type_index(mem_type);
            let buffer_memory = device.allocate_memory(&alloc_info, None).unwrap();
            device.bind_buffer_memory(buffer, buffer_memory, 0).unwrap();
            println!("allocated {:?}", buffer_memory);

            // Create descriptor set.
            let descriptor_set_layout_bindings = [vk::DescriptorSetLayoutBinding::builder()
                .binding(0)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .descriptor_count(1)
                .stage_flags(vk::ShaderStageFlags::COMPUTE)
                .build()];
            let descriptor_set_layout_create_info = vk::DescriptorSetLayoutCreateInfo::builder()
                .bindings(&descriptor_set_layout_bindings);
            let descriptor_set_layout = device
                .create_descriptor_set_layout(&descriptor_set_layout_create_info, None)
                .unwrap();

            let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
                .descriptor_count(1)
                .build()];
            let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::builder()
                .pool_sizes(&descriptor_pool_sizes)
                .max_sets(1);
            let descriptor_pool = device
                .create_descriptor_pool(&descriptor_pool_create_info, None)
                .unwrap();
            let descriptor_set_layouts = [descriptor_set_layout];
            let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::builder()
                .descriptor_pool(descriptor_pool)
                .set_layouts(&descriptor_set_layouts);
            let descriptor_sets = device
                .allocate_descriptor_sets(&descriptor_set_allocate_info)
                .unwrap();
            let descriptor_buffer_info = [vk::DescriptorBufferInfo::builder()
                .buffer(buffer)
                .offset(0)
                .range(1024)
                .build()];
            let write_descriptor_set = vk::WriteDescriptorSet::builder()
                .dst_set(descriptor_sets[0])
                .dst_binding(0)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .buffer_info(&descriptor_buffer_info)
                .build();
            device.update_descriptor_sets(&[write_descriptor_set], &[]);

            // Create compute pipeline.
            let code = include_bytes!("../comp.spv");
            let code_u32 = std::slice::from_raw_parts(code.as_ptr() as *const u32, code.len() / 4);
            let shader_module_create_info = vk::ShaderModuleCreateInfo::builder().code(code_u32);
            let compute_shader_module = device
                .create_shader_module(&shader_module_create_info, None)
                .unwrap();
            let entry_name = CString::new("main").unwrap();
            let shader_stage_create_info = vk::PipelineShaderStageCreateInfo::builder()
                .stage(vk::ShaderStageFlags::COMPUTE)
                .module(compute_shader_module)
                .name(&entry_name);
            let pipeline_layout_create_info =
                vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts);
            let pipeline_layout = device
                .create_pipeline_layout(&pipeline_layout_create_info, None)
                .unwrap();

            let pipeline_create_infos = [vk::ComputePipelineCreateInfo::builder()
                .stage(shader_stage_create_info.build())
                .layout(pipeline_layout)
                .build()];
            let pipeline = device
                .create_compute_pipelines(vk::PipelineCache::null(), &pipeline_create_infos, None)
                .unwrap()[0];

            // Create command buffer.
            let command_pool = device
                .create_command_pool(
                    &vk::CommandPoolCreateInfo::builder()
                        .flags(vk::CommandPoolCreateFlags::empty())
                        .queue_family_index(qfi as u32),
                    None,
                )
                .unwrap();
            let command_buffer = device
                .allocate_command_buffers(
                    &vk::CommandBufferAllocateInfo::builder()
                        .command_pool(command_pool)
                        .level(vk::CommandBufferLevel::PRIMARY)
                        .command_buffer_count(1),
                )
                .unwrap()[0];

            // Record commands into command buffer.
            device
                .begin_command_buffer(
                    command_buffer,
                    &vk::CommandBufferBeginInfo::builder()
                        .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
                )
                .unwrap();
            device.cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, pipeline);
            device.cmd_bind_descriptor_sets(
                command_buffer,
                vk::PipelineBindPoint::COMPUTE,
                pipeline_layout,
                0,
                &descriptor_sets,
                &[],
            );
            device.cmd_dispatch(command_buffer, 256, 1, 1);
            device.end_command_buffer(command_buffer).unwrap();

            // Initialize the buffer.
            let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
            for i in 0..256 {
                buf.add(i).write((i + 1) as u32);
            }
            device.unmap_memory(buffer_memory);

            // Run the command buffer.
            let fence = device
                .create_fence(
                    &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
                    None,
                )
                .unwrap();
            device
                .queue_submit(
                    queue,
                    &[vk::SubmitInfo::builder()
                        .command_buffers(&[command_buffer])
                        .build()],
                    fence,
                )
                .unwrap();
            device.wait_for_fences(&[fence], true, 1_000_000).unwrap();

            // Reap the results
            let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
            for i in 0..16 {
                println!("{}: {}", i, buf.add(i).read());
            }
            device.unmap_memory(buffer_memory);
            
            ToyBase {
                pdevice,
                device,
                queue,
            }
        }
    }
}

/// Program entry point: runs the whole Vulkan compute demo via
/// `ToyBase::new`, then prints a greeting.
fn main() {
    // The result is not used further; the underscore-prefixed binding keeps
    // it alive until the end of `main`.
    let _toy_base = ToyBase::new();
    println!("Hello, world!");
}