@@ 0,0 1,380 @@
+//! Common base for running compute workloads.
+
+use std::ffi::CString;
+
+use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
+use ash::{vk, Device, Entry, Instance};
+
+#[derive(Debug)]
+pub enum Error {
+ LoadingError(ash::LoadingError),
+ InstanceError(ash::InstanceError),
+ VkResult(ash::vk::Result),
+ NoSuitableDevice,
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
+ write!(f, "vk error")
+ }
+}
+
+impl std::error::Error for Error {}
+
+impl From<ash::LoadingError> for Error {
+ fn from(error: ash::LoadingError) -> Error {
+ Error::LoadingError(error)
+ }
+}
+
+impl From<ash::InstanceError> for Error {
+ fn from(error: ash::InstanceError) -> Error {
+ Error::InstanceError(error)
+ }
+}
+
+impl From<vk::Result> for Error {
+ fn from(error: vk::Result) -> Error {
+ Error::VkResult(error)
+ }
+}
+
+impl From<(Vec<vk::Pipeline>, vk::Result)> for Error {
+ fn from(error: (Vec<vk::Pipeline>, vk::Result)) -> Error {
+ error.1.into()
+ }
+}
+
+/// A base for allocating resources and dispatching work.
+///
+/// This is quite similar to "device" in most GPU API's, but I didn't want to overload
+/// that term further.
+pub struct Base {
+ /// Retain the dynamic lib.
+ #[allow(unused)]
+ entry: Entry,
+
+ #[allow(unused)]
+ instance: Instance,
+
+ device: Device,
+ device_mem_props: vk::PhysicalDeviceMemoryProperties,
+ queue: vk::Queue,
+ qfi: u32,
+}
+
+/// A handle to a buffer.
+///
+/// There is no lifetime tracking at this level; the caller is responsible
+/// for destroying the buffer at the appropriate time.
+pub struct Buffer {
+ buffer: vk::Buffer,
+ buffer_memory: vk::DeviceMemory,
+}
+
+pub struct Pipeline {
+ pipeline: vk::Pipeline,
+ pipeline_layout: vk::PipelineLayout,
+ descriptor_set: vk::DescriptorSet,
+}
+
+impl Base {
+ /// Create a new instance.
+ ///
+ /// There's more to be done to make this suitable for integration with other
+ /// systems, but for now the goal is to make things simple.
+ pub fn new() -> Result<Base, Error> {
+ unsafe {
+ let app_name = CString::new("VkToy").unwrap();
+ let entry = Entry::new()?;
+ let instance = entry.create_instance(
+ &vk::InstanceCreateInfo::builder().application_info(
+ &vk::ApplicationInfo::builder()
+ .application_name(&app_name)
+ .application_version(0)
+ .engine_name(&app_name)
+ .api_version(vk_make_version!(1, 0, 0)),
+ ),
+ None,
+ )?;
+
+ let devices = instance.enumerate_physical_devices()?;
+ let (pdevice, qfi) =
+ choose_compute_device(&instance, &devices).ok_or(Error::NoSuitableDevice)?;
+
+ let device = instance.create_device(
+ pdevice,
+ &vk::DeviceCreateInfo::builder().queue_create_infos(&[
+ vk::DeviceQueueCreateInfo::builder()
+ .queue_family_index(qfi)
+ .queue_priorities(&[1.0])
+ .build(),
+ ]),
+ None,
+ )?;
+
+ let device_mem_props = instance.get_physical_device_memory_properties(pdevice);
+
+ let queue_index = 0;
+ let queue = device.get_device_queue(qfi, queue_index);
+
+ Ok(Base {
+ entry,
+ instance,
+ device,
+ device_mem_props,
+ qfi,
+ queue,
+ })
+ }
+ }
+
+ pub fn create_buffer(
+ &self,
+ size: u64,
+ mem_flags: vk::MemoryPropertyFlags,
+ ) -> Result<Buffer, Error> {
+ unsafe {
+ let buffer = self.device.create_buffer(
+ &vk::BufferCreateInfo::builder()
+ .size(size)
+ .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
+ .sharing_mode(vk::SharingMode::EXCLUSIVE),
+ None,
+ )?;
+ let mem_requirements = self.device.get_buffer_memory_requirements(buffer);
+ let mem_type = find_memory_type(
+ mem_requirements.memory_type_bits,
+ mem_flags,
+ &self.device_mem_props,
+ )
+ .unwrap(); // TODO: proper error
+ let buffer_memory = self.device.allocate_memory(
+ &vk::MemoryAllocateInfo::builder()
+ .allocation_size(mem_requirements.size)
+ .memory_type_index(mem_type),
+ None,
+ )?;
+ self.device.bind_buffer_memory(buffer, buffer_memory, 0)?;
+ Ok(Buffer {
+ buffer,
+ buffer_memory,
+ })
+ }
+ }
+
+ /// This creates a pipeline that runs over the buffer.
+ ///
+ /// This will be split up into finer grains.
+ pub unsafe fn create_simple_compute_pipeline(
+ &self,
+ buffer: &Buffer,
+ ) -> Result<Pipeline, Error> {
+ let descriptor_set_layout = self.device.create_descriptor_set_layout(
+ &vk::DescriptorSetLayoutCreateInfo::builder().bindings(&[
+ vk::DescriptorSetLayoutBinding::builder()
+ .binding(0)
+ .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
+ .descriptor_count(1)
+ .stage_flags(vk::ShaderStageFlags::COMPUTE)
+ .build(),
+ ]),
+ None,
+ )?;
+
+ let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
+ .ty(vk::DescriptorType::STORAGE_BUFFER)
+ .descriptor_count(1)
+ .build()];
+ let descriptor_pool = self.device.create_descriptor_pool(
+ &vk::DescriptorPoolCreateInfo::builder()
+ .pool_sizes(&descriptor_pool_sizes)
+ .max_sets(1),
+ None,
+ )?;
+ let descriptor_set_layouts = [descriptor_set_layout];
+ let descriptor_sets = self
+ .device
+ .allocate_descriptor_sets(
+ &vk::DescriptorSetAllocateInfo::builder()
+ .descriptor_pool(descriptor_pool)
+ .set_layouts(&descriptor_set_layouts),
+ )
+ .unwrap();
+ self.device.update_descriptor_sets(
+ &[vk::WriteDescriptorSet::builder()
+ .dst_set(descriptor_sets[0])
+ .dst_binding(0)
+ .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
+ .buffer_info(&[vk::DescriptorBufferInfo::builder()
+ .buffer(buffer.buffer)
+ .offset(0)
+ .range(1024)
+ .build()])
+ .build()],
+ &[],
+ );
+
+ // Create compute pipeline.
+ let code = include_bytes!("../comp.spv");
+ let code_u32 = convert_u32_vec(code);
+ let compute_shader_module = self
+ .device
+ .create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(&code_u32), None)?;
+ let entry_name = CString::new("main").unwrap();
+ let pipeline_layout = self.device.create_pipeline_layout(
+ &vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts),
+ None,
+ )?;
+
+ let pipeline = self.device.create_compute_pipelines(
+ vk::PipelineCache::null(),
+ &[vk::ComputePipelineCreateInfo::builder()
+ .stage(
+ vk::PipelineShaderStageCreateInfo::builder()
+ .stage(vk::ShaderStageFlags::COMPUTE)
+ .module(compute_shader_module)
+ .name(&entry_name)
+ .build(),
+ )
+ .layout(pipeline_layout)
+ .build()],
+ None,
+ )?[0];
+ Ok(Pipeline {
+ pipeline,
+ pipeline_layout,
+ descriptor_set: descriptor_sets[0],
+ })
+ }
+
+ pub unsafe fn run_compute_pipeline(&self, pipeline: &Pipeline) -> Result<(), Error> {
+ // Create command buffer.
+ let command_pool = self.device.create_command_pool(
+ &vk::CommandPoolCreateInfo::builder()
+ .flags(vk::CommandPoolCreateFlags::empty())
+ .queue_family_index(self.qfi),
+ None,
+ )?;
+ let command_buffer = self.device.allocate_command_buffers(
+ &vk::CommandBufferAllocateInfo::builder()
+ .command_pool(command_pool)
+ .level(vk::CommandBufferLevel::PRIMARY)
+ .command_buffer_count(1),
+ )?[0];
+
+ // Record commands into command buffer.
+ self.device.begin_command_buffer(
+ command_buffer,
+ &vk::CommandBufferBeginInfo::builder()
+ .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
+ )?;
+ self.device.cmd_bind_pipeline(
+ command_buffer,
+ vk::PipelineBindPoint::COMPUTE,
+ pipeline.pipeline,
+ );
+ self.device.cmd_bind_descriptor_sets(
+ command_buffer,
+ vk::PipelineBindPoint::COMPUTE,
+ pipeline.pipeline_layout,
+ 0,
+ &[pipeline.descriptor_set],
+ &[],
+ );
+ self.device.cmd_dispatch(command_buffer, 256, 1, 1);
+ self.device.end_command_buffer(command_buffer)?;
+
+ // Run the command buffer.
+ let fence = self.device.create_fence(
+ &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
+ None,
+ )?;
+ self.device.queue_submit(
+ self.queue,
+ &[vk::SubmitInfo::builder()
+ .command_buffers(&[command_buffer])
+ .build()],
+ fence,
+ )?;
+ self.device.wait_for_fences(&[fence], true, 1_000_000)?;
+ Ok(())
+ }
+
+ pub unsafe fn read_buffer<T: Sized>(
+ &self,
+ buffer: &Buffer,
+ result: &mut Vec<T>,
+ size: usize,
+ ) -> Result<(), Error> {
+ let buf = self.device.map_memory(
+ buffer.buffer_memory,
+ 0,
+ size as u64,
+ vk::MemoryMapFlags::empty(),
+ )?;
+ if size > result.len() {
+ result.reserve(size - result.len());
+ }
+ std::ptr::copy_nonoverlapping(buf as *const T, result.as_mut_ptr(), size);
+ result.set_len(size);
+ self.device.unmap_memory(buffer.buffer_memory);
+ Ok(())
+ }
+
+ pub unsafe fn write_buffer<T: Sized>(
+ &self,
+ buffer: &Buffer,
+ contents: &[T],
+ ) -> Result<(), Error> {
+ let buf = self.device.map_memory(
+ buffer.buffer_memory,
+ 0,
+ std::mem::size_of_val(contents) as u64,
+ vk::MemoryMapFlags::empty(),
+ )?;
+ std::ptr::copy_nonoverlapping(contents.as_ptr(), buf as *mut T, contents.len());
+ self.device.unmap_memory(buffer.buffer_memory);
+ Ok(())
+ }
+}
+
+unsafe fn choose_compute_device(
+ instance: &Instance,
+ devices: &[vk::PhysicalDevice],
+) -> Option<(vk::PhysicalDevice, u32)> {
+ for pdevice in devices {
+ let props = instance.get_physical_device_queue_family_properties(*pdevice);
+ for (ix, info) in props.iter().enumerate() {
+ if info.queue_flags.contains(vk::QueueFlags::COMPUTE) {
+ return Some((*pdevice, ix as u32));
+ }
+ }
+ }
+ None
+}
+
+fn find_memory_type(
+ memory_type_bits: u32,
+ property_flags: vk::MemoryPropertyFlags,
+ props: &vk::PhysicalDeviceMemoryProperties,
+) -> Option<u32> {
+ for i in 0..props.memory_type_count {
+ if (memory_type_bits & (1 << i)) != 0
+ && props.memory_types[i as usize]
+ .property_flags
+ .contains(property_flags)
+ {
+ return Some(i);
+ }
+ }
+ None
+}
+
+fn convert_u32_vec(src: &[u8]) -> Vec<u32> {
+ src.chunks(4).map(|chunk| {
+ let mut buf = [0; 4];
+ buf.copy_from_slice(chunk);
+ u32::from_le_bytes(buf)
+ }).collect()
+}
@@ 1,275 1,23 @@
#[macro_use]
extern crate ash;
-use std::ffi::CString;
-use std::ops::Deref;
+mod base;
-use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
-use ash::{vk, Device, Entry, Instance};
+use ash::vk;
-struct ToyBase {
- pdevice: vk::PhysicalDevice,
-
- device: Device,
-
- // It would be better to separate the base from the actual logic, but
- // we'll follow the usual pattern of examples and have everything in one
- // big huge glob.
- queue: vk::Queue,
- //pipeline: vk::Pipeline,
-}
-
-unsafe fn choose_compute_device(
- instance: &Instance,
- devices: &[vk::PhysicalDevice],
-) -> Option<(vk::PhysicalDevice, usize)> {
- for pdevice in devices {
- let props = instance.get_physical_device_queue_family_properties(*pdevice);
- for (ix, info) in props.iter().enumerate() {
- if info.queue_flags.contains(vk::QueueFlags::COMPUTE) {
- return Some((*pdevice, ix));
- }
- }
- }
- None
-}
-
-fn find_memory_type(
- memory_type_bits: u32,
- property_flags: vk::MemoryPropertyFlags,
- props: &vk::PhysicalDeviceMemoryProperties,
-) -> Option<u32> {
- for i in 0..props.memory_type_count {
- if (memory_type_bits & (1 << i)) != 0
- && props.memory_types[i as usize]
- .property_flags
- .contains(property_flags)
- {
- return Some(i);
- }
- }
- None
-}
-
-impl ToyBase {
- fn new() -> ToyBase {
- unsafe {
- let app_name = CString::new("VkToy").unwrap();
- let appinfo = vk::ApplicationInfo::builder()
- .application_name(&app_name)
- .application_version(0)
- .engine_name(&app_name)
- .api_version(vk_make_version!(1, 0, 0));
- let create_info = vk::InstanceCreateInfo::builder().application_info(&appinfo);
- let entry = Entry::new().unwrap();
- let instance = entry.create_instance(&create_info, None).unwrap();
- /*
- for pdevice in instance.enumerate_physical_devices().unwrap() {
- let props = instance.get_physical_device_properties(pdevice);
- println!("props: {:#?}", props);
- }
- */
- let devices = instance.enumerate_physical_devices().unwrap();
- let (pdevice, qfi) = choose_compute_device(&instance, &devices).unwrap();
-
- let queue_info = [vk::DeviceQueueCreateInfo::builder()
- .queue_family_index(qfi as u32)
- .queue_priorities(&[1.0])
- .build()];
- let device_create_info =
- vk::DeviceCreateInfo::builder().queue_create_infos(&queue_info);
- let device = instance
- .create_device(pdevice, &device_create_info, None)
- .unwrap();
- let queue_index = 0;
- let queue = device.get_device_queue(qfi as u32, queue_index);
- let device_mem_props = instance.get_physical_device_memory_properties(pdevice);
- for i in 0..device_mem_props.memory_type_count {
- println!(
- "memory type {}: {:?}",
- i, device_mem_props.memory_types[i as usize]
- );
- }
- for i in 0..device_mem_props.memory_heap_count {
- println!(
- "memory heap {}: {:?}",
- i, device_mem_props.memory_heaps[i as usize]
- );
- }
- let buffer_create_info = vk::BufferCreateInfo::builder()
- .size(1024)
- .usage(vk::BufferUsageFlags::STORAGE_BUFFER)
- .sharing_mode(vk::SharingMode::EXCLUSIVE);
- let buffer = device.create_buffer(&buffer_create_info, None).unwrap();
- let mem_requirements = device.get_buffer_memory_requirements(buffer);
- println!(
- "memory requirements: {:?} (0x{:x})",
- mem_requirements, mem_requirements.memory_type_bits
- );
- let mem_flags =
- vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
- let mem_type = find_memory_type(
- mem_requirements.memory_type_bits,
- mem_flags,
- &device_mem_props,
- )
- .unwrap();
- println!("chose mem type {}", mem_type);
- let alloc_info = vk::MemoryAllocateInfo::builder()
- .allocation_size(mem_requirements.size)
- .memory_type_index(mem_type);
- let buffer_memory = device.allocate_memory(&alloc_info, None).unwrap();
- device.bind_buffer_memory(buffer, buffer_memory, 0).unwrap();
- println!("allocated {:?}", buffer_memory);
-
- // Create descriptor set.
- let descriptor_set_layout_bindings = [vk::DescriptorSetLayoutBinding::builder()
- .binding(0)
- .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
- .descriptor_count(1)
- .stage_flags(vk::ShaderStageFlags::COMPUTE)
- .build()];
- let descriptor_set_layout_create_info = vk::DescriptorSetLayoutCreateInfo::builder()
- .bindings(&descriptor_set_layout_bindings);
- let descriptor_set_layout = device
- .create_descriptor_set_layout(&descriptor_set_layout_create_info, None)
- .unwrap();
-
- let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
- .descriptor_count(1)
- .build()];
- let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::builder()
- .pool_sizes(&descriptor_pool_sizes)
- .max_sets(1);
- let descriptor_pool = device
- .create_descriptor_pool(&descriptor_pool_create_info, None)
- .unwrap();
- let descriptor_set_layouts = [descriptor_set_layout];
- let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::builder()
- .descriptor_pool(descriptor_pool)
- .set_layouts(&descriptor_set_layouts);
- let descriptor_sets = device
- .allocate_descriptor_sets(&descriptor_set_allocate_info)
- .unwrap();
- let descriptor_buffer_info = [vk::DescriptorBufferInfo::builder()
- .buffer(buffer)
- .offset(0)
- .range(1024)
- .build()];
- let write_descriptor_set = vk::WriteDescriptorSet::builder()
- .dst_set(descriptor_sets[0])
- .dst_binding(0)
- .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
- .buffer_info(&descriptor_buffer_info)
- .build();
- device.update_descriptor_sets(&[write_descriptor_set], &[]);
-
- // Create compute pipeline.
- let code = include_bytes!("../comp.spv");
- let code_u32 = std::slice::from_raw_parts(code.as_ptr() as *const u32, code.len() / 4);
- let shader_module_create_info = vk::ShaderModuleCreateInfo::builder().code(code_u32);
- let compute_shader_module = device
- .create_shader_module(&shader_module_create_info, None)
- .unwrap();
- let entry_name = CString::new("main").unwrap();
- let shader_stage_create_info = vk::PipelineShaderStageCreateInfo::builder()
- .stage(vk::ShaderStageFlags::COMPUTE)
- .module(compute_shader_module)
- .name(&entry_name);
- let pipeline_layout_create_info =
- vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts);
- let pipeline_layout = device
- .create_pipeline_layout(&pipeline_layout_create_info, None)
- .unwrap();
-
- let pipeline_create_infos = [vk::ComputePipelineCreateInfo::builder()
- .stage(shader_stage_create_info.build())
- .layout(pipeline_layout)
- .build()];
- let pipeline = device
- .create_compute_pipelines(vk::PipelineCache::null(), &pipeline_create_infos, None)
- .unwrap()[0];
-
- // Create command buffer.
- let command_pool = device
- .create_command_pool(
- &vk::CommandPoolCreateInfo::builder()
- .flags(vk::CommandPoolCreateFlags::empty())
- .queue_family_index(qfi as u32),
- None,
- )
- .unwrap();
- let command_buffer = device
- .allocate_command_buffers(
- &vk::CommandBufferAllocateInfo::builder()
- .command_pool(command_pool)
- .level(vk::CommandBufferLevel::PRIMARY)
- .command_buffer_count(1),
- )
- .unwrap()[0];
-
- // Record commands into command buffer.
- device
- .begin_command_buffer(
- command_buffer,
- &vk::CommandBufferBeginInfo::builder()
- .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
- )
- .unwrap();
- device.cmd_bind_pipeline(command_buffer, vk::PipelineBindPoint::COMPUTE, pipeline);
- device.cmd_bind_descriptor_sets(
- command_buffer,
- vk::PipelineBindPoint::COMPUTE,
- pipeline_layout,
- 0,
- &descriptor_sets,
- &[],
- );
- device.cmd_dispatch(command_buffer, 256, 1, 1);
- device.end_command_buffer(command_buffer).unwrap();
-
- // Initialize the buffer.
- let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
- for i in 0..256 {
- buf.add(i).write((i + 1) as u32);
- }
- device.unmap_memory(buffer_memory);
-
- // Run the command buffer.
- let fence = device
- .create_fence(
- &vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
- None,
- )
- .unwrap();
- device
- .queue_submit(
- queue,
- &[vk::SubmitInfo::builder()
- .command_buffers(&[command_buffer])
- .build()],
- fence,
- )
- .unwrap();
- device.wait_for_fences(&[fence], true, 1_000_000).unwrap();
-
- // Reap the results
- let buf = device.map_memory(buffer_memory, 0, 1024, vk::MemoryMapFlags::empty()).unwrap() as *mut u32;
- for i in 0..16 {
- println!("{}: {}", i, buf.add(i).read());
- }
- device.unmap_memory(buffer_memory);
-
- ToyBase {
- pdevice,
- device,
- queue,
- }
+fn main() {
+ let base = base::Base::new().unwrap();
+ let mem_flags = vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
+ let src = (0..256).map(|x| x + 1).collect::<Vec<u32>>();
+ let buffer = base.create_buffer(std::mem::size_of_val(&src) as u64, mem_flags).unwrap();
+ unsafe {
+ base.write_buffer(&buffer, &src).unwrap();
+ let pipeline = base.create_simple_compute_pipeline(&buffer).unwrap();
+ base.run_compute_pipeline(&pipeline).unwrap();
+ let mut dst: Vec<u32> = Default::default();
+ base.read_buffer(&buffer, &mut dst, src.len()).unwrap();
+ for (i, val) in dst.iter().enumerate().take(16) {
+ println!("{}: {}", i, val);
}
}
}
-
-fn main() {
- let base = ToyBase::new();
- println!("Hello, world!");
-}