@@ 1,49 1,12 @@
//! Common base for running compute workloads.
use std::ffi::CString;
+use std::sync::Arc;
use ash::version::{DeviceV1_0, EntryV1_0, InstanceV1_0};
use ash::{vk, Device, Entry, Instance};
-#[derive(Debug)]
-pub enum Error {
- LoadingError(ash::LoadingError),
- InstanceError(ash::InstanceError),
- VkResult(ash::vk::Result),
- NoSuitableDevice,
-}
-
-impl std::fmt::Display for Error {
- fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
- write!(f, "vk error")
- }
-}
-
-impl std::error::Error for Error {}
-
-impl From<ash::LoadingError> for Error {
- fn from(error: ash::LoadingError) -> Error {
- Error::LoadingError(error)
- }
-}
-
-impl From<ash::InstanceError> for Error {
- fn from(error: ash::InstanceError) -> Error {
- Error::InstanceError(error)
- }
-}
-
-impl From<vk::Result> for Error {
- fn from(error: vk::Result) -> Error {
- Error::VkResult(error)
- }
-}
-
-impl From<(Vec<vk::Pipeline>, vk::Result)> for Error {
- fn from(error: (Vec<vk::Pipeline>, vk::Result)) -> Error {
- error.1.into()
- }
-}
+use crate::error::Error;
/// A base for allocating resources and dispatching work.
///
@@ 57,12 20,16 @@ pub struct Base {
#[allow(unused)]
instance: Instance,
- device: Device,
+ device: Arc<RawDevice>,
device_mem_props: vk::PhysicalDeviceMemoryProperties,
queue: vk::Queue,
qfi: u32,
}
+struct RawDevice {
+ device: Device,
+}
+
/// A handle to a buffer.
///
/// There is no lifetime tracking at this level; the caller is responsible
@@ 70,14 37,24 @@ pub struct Base {
pub struct Buffer {
buffer: vk::Buffer,
buffer_memory: vk::DeviceMemory,
+ size: u64,
}
pub struct Pipeline {
pipeline: vk::Pipeline,
+ descriptor_set_layout: vk::DescriptorSetLayout,
pipeline_layout: vk::PipelineLayout,
+}
+
+pub struct DescriptorSet {
descriptor_set: vk::DescriptorSet,
}
+pub struct CmdBuf {
+ cmd_buf: vk::CommandBuffer,
+ device: Arc<RawDevice>,
+}
+
impl Base {
/// Create a new instance.
///
@@ 118,6 95,8 @@ impl Base {
let queue_index = 0;
let queue = device.get_device_queue(qfi, queue_index);
+ let device = Arc::new(RawDevice { device });
+
Ok(Base {
entry,
instance,
@@ 135,99 114,71 @@ impl Base {
mem_flags: vk::MemoryPropertyFlags,
) -> Result<Buffer, Error> {
unsafe {
- let buffer = self.device.create_buffer(
+ let device = &self.device.device;
+ let buffer = device.create_buffer(
&vk::BufferCreateInfo::builder()
.size(size)
.usage(vk::BufferUsageFlags::STORAGE_BUFFER)
.sharing_mode(vk::SharingMode::EXCLUSIVE),
None,
)?;
- let mem_requirements = self.device.get_buffer_memory_requirements(buffer);
+ let mem_requirements = device.get_buffer_memory_requirements(buffer);
let mem_type = find_memory_type(
mem_requirements.memory_type_bits,
mem_flags,
&self.device_mem_props,
)
.unwrap(); // TODO: proper error
- let buffer_memory = self.device.allocate_memory(
+ let buffer_memory = device.allocate_memory(
&vk::MemoryAllocateInfo::builder()
.allocation_size(mem_requirements.size)
.memory_type_index(mem_type),
None,
)?;
- self.device.bind_buffer_memory(buffer, buffer_memory, 0)?;
+ device.bind_buffer_memory(buffer, buffer_memory, 0)?;
Ok(Buffer {
buffer,
buffer_memory,
+ size,
})
}
}
/// This creates a pipeline that runs over the buffer.
///
- /// This will be split up into finer grains.
+ /// The `code` argument is the SPIR-V shader binary, and the descriptor set layout is just some
+ /// number of buffers.
pub unsafe fn create_simple_compute_pipeline(
&self,
- buffer: &Buffer,
+ code: &[u8],
+ n_buffers: u32,
) -> Result<Pipeline, Error> {
- let descriptor_set_layout = self.device.create_descriptor_set_layout(
+ let device = &self.device.device;
+ let descriptor_set_layout = device.create_descriptor_set_layout(
&vk::DescriptorSetLayoutCreateInfo::builder().bindings(&[
vk::DescriptorSetLayoutBinding::builder()
.binding(0)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
- .descriptor_count(1)
+ .descriptor_count(n_buffers)
.stage_flags(vk::ShaderStageFlags::COMPUTE)
.build(),
]),
None,
)?;
- let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
- .ty(vk::DescriptorType::STORAGE_BUFFER)
- .descriptor_count(1)
- .build()];
- let descriptor_pool = self.device.create_descriptor_pool(
- &vk::DescriptorPoolCreateInfo::builder()
- .pool_sizes(&descriptor_pool_sizes)
- .max_sets(1),
- None,
- )?;
let descriptor_set_layouts = [descriptor_set_layout];
- let descriptor_sets = self
- .device
- .allocate_descriptor_sets(
- &vk::DescriptorSetAllocateInfo::builder()
- .descriptor_pool(descriptor_pool)
- .set_layouts(&descriptor_set_layouts),
- )
- .unwrap();
- self.device.update_descriptor_sets(
- &[vk::WriteDescriptorSet::builder()
- .dst_set(descriptor_sets[0])
- .dst_binding(0)
- .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
- .buffer_info(&[vk::DescriptorBufferInfo::builder()
- .buffer(buffer.buffer)
- .offset(0)
- .range(1024)
- .build()])
- .build()],
- &[],
- );
// Create compute pipeline.
- let code = include_bytes!("../comp.spv");
let code_u32 = convert_u32_vec(code);
- let compute_shader_module = self
- .device
+ let compute_shader_module = device
.create_shader_module(&vk::ShaderModuleCreateInfo::builder().code(&code_u32), None)?;
let entry_name = CString::new("main").unwrap();
- let pipeline_layout = self.device.create_pipeline_layout(
+ let pipeline_layout = device.create_pipeline_layout(
&vk::PipelineLayoutCreateInfo::builder().set_layouts(&descriptor_set_layouts),
None,
)?;
- let pipeline = self.device.create_compute_pipelines(
+ let pipeline = device.create_compute_pipelines(
vk::PipelineCache::null(),
&[vk::ComputePipelineCreateInfo::builder()
.stage(
@@ 244,60 195,100 @@ impl Base {
Ok(Pipeline {
pipeline,
pipeline_layout,
- descriptor_set: descriptor_sets[0],
+ descriptor_set_layout,
})
}
- pub unsafe fn run_compute_pipeline(&self, pipeline: &Pipeline) -> Result<(), Error> {
- // Create command buffer.
- let command_pool = self.device.create_command_pool(
- &vk::CommandPoolCreateInfo::builder()
- .flags(vk::CommandPoolCreateFlags::empty())
- .queue_family_index(self.qfi),
+ pub unsafe fn create_descriptor_set(
+ &self,
+ pipeline: &Pipeline,
+ bufs: &[&Buffer],
+ ) -> Result<DescriptorSet, Error> {
+ let device = &self.device.device;
+ let descriptor_pool_sizes = [vk::DescriptorPoolSize::builder()
+ .ty(vk::DescriptorType::STORAGE_BUFFER)
+ .descriptor_count(bufs.len() as u32)
+ .build()];
+ let descriptor_pool = device.create_descriptor_pool(
+ &vk::DescriptorPoolCreateInfo::builder()
+ .pool_sizes(&descriptor_pool_sizes)
+ .max_sets(1),
None,
)?;
- let command_buffer = self.device.allocate_command_buffers(
- &vk::CommandBufferAllocateInfo::builder()
- .command_pool(command_pool)
- .level(vk::CommandBufferLevel::PRIMARY)
- .command_buffer_count(1),
- )?[0];
-
- // Record commands into command buffer.
- self.device.begin_command_buffer(
- command_buffer,
- &vk::CommandBufferBeginInfo::builder()
- .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
- )?;
- self.device.cmd_bind_pipeline(
- command_buffer,
- vk::PipelineBindPoint::COMPUTE,
- pipeline.pipeline,
- );
- self.device.cmd_bind_descriptor_sets(
- command_buffer,
- vk::PipelineBindPoint::COMPUTE,
- pipeline.pipeline_layout,
- 0,
- &[pipeline.descriptor_set],
+ let descriptor_set_layouts = [pipeline.descriptor_set_layout];
+ let descriptor_sets = device
+ .allocate_descriptor_sets(
+ &vk::DescriptorSetAllocateInfo::builder()
+ .descriptor_pool(descriptor_pool)
+ .set_layouts(&descriptor_set_layouts),
+ )
+ .unwrap();
+ let buf_infos = bufs
+ .iter()
+ .map(|buf| {
+ vk::DescriptorBufferInfo::builder()
+ .buffer(buf.buffer)
+ .offset(0)
+ .range(vk::WHOLE_SIZE)
+ .build()
+ })
+ .collect::<Vec<_>>();
+ device.update_descriptor_sets(
+ &[vk::WriteDescriptorSet::builder()
+ .dst_set(descriptor_sets[0])
+ .dst_binding(0)
+ .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
+ .buffer_info(&buf_infos)
+ .build()],
&[],
);
- self.device.cmd_dispatch(command_buffer, 256, 1, 1);
- self.device.end_command_buffer(command_buffer)?;
+ Ok(DescriptorSet {
+ descriptor_set: descriptor_sets[0],
+ })
+ }
+
+ pub fn create_cmd_buf(&self) -> Result<CmdBuf, Error> {
+ unsafe {
+ let device = &self.device.device;
+ let command_pool = device.create_command_pool(
+ &vk::CommandPoolCreateInfo::builder()
+ .flags(vk::CommandPoolCreateFlags::empty())
+ .queue_family_index(self.qfi),
+ None,
+ )?;
+ let cmd_buf = device.allocate_command_buffers(
+ &vk::CommandBufferAllocateInfo::builder()
+ .command_pool(command_pool)
+ .level(vk::CommandBufferLevel::PRIMARY)
+ .command_buffer_count(1),
+ )?[0];
+ Ok(CmdBuf {
+ cmd_buf,
+ device: self.device.clone(),
+ })
+ }
+ }
+
+ /// Run the command buffer.
+ ///
+ /// This version simply blocks until it's complete.
+ pub unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> {
+ let device = &self.device.device;
// Run the command buffer.
- let fence = self.device.create_fence(
+ let fence = device.create_fence(
&vk::FenceCreateInfo::builder().flags(vk::FenceCreateFlags::empty()),
None,
)?;
- self.device.queue_submit(
+ device.queue_submit(
self.queue,
&[vk::SubmitInfo::builder()
- .command_buffers(&[command_buffer])
+ .command_buffers(&[cmd_buf.cmd_buf])
.build()],
fence,
)?;
- self.device.wait_for_fences(&[fence], true, 1_000_000)?;
+ device.wait_for_fences(&[fence], true, 1_000_000)?;
+ device.destroy_fence(fence, None);
Ok(())
}
@@ 305,9 296,10 @@ impl Base {
&self,
buffer: &Buffer,
result: &mut Vec<T>,
- size: usize,
) -> Result<(), Error> {
- let buf = self.device.map_memory(
+ let device = &self.device.device;
+ let size = buffer.size as usize;
+ let buf = device.map_memory(
buffer.buffer_memory,
0,
size as u64,
@@ 318,7 310,7 @@ impl Base {
}
std::ptr::copy_nonoverlapping(buf as *const T, result.as_mut_ptr(), size);
result.set_len(size);
- self.device.unmap_memory(buffer.buffer_memory);
+ device.unmap_memory(buffer.buffer_memory);
Ok(())
}
@@ 327,18 319,72 @@ impl Base {
buffer: &Buffer,
contents: &[T],
) -> Result<(), Error> {
- let buf = self.device.map_memory(
+ let device = &self.device.device;
+ let buf = device.map_memory(
buffer.buffer_memory,
0,
std::mem::size_of_val(contents) as u64,
vk::MemoryMapFlags::empty(),
)?;
std::ptr::copy_nonoverlapping(contents.as_ptr(), buf as *mut T, contents.len());
- self.device.unmap_memory(buffer.buffer_memory);
+ device.unmap_memory(buffer.buffer_memory);
Ok(())
}
}
+impl CmdBuf {
+ pub unsafe fn begin(&mut self) {
+ self.device
+ .device
+ .begin_command_buffer(
+ self.cmd_buf,
+ &vk::CommandBufferBeginInfo::builder()
+ .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
+ )
+ .unwrap();
+ }
+
+ pub unsafe fn finish(&mut self) {
+ self.device.device.end_command_buffer(self.cmd_buf).unwrap();
+ }
+
+ pub unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) {
+ let device = &self.device.device;
+ device.cmd_bind_pipeline(
+ self.cmd_buf,
+ vk::PipelineBindPoint::COMPUTE,
+ pipeline.pipeline,
+ );
+ device.cmd_bind_descriptor_sets(
+ self.cmd_buf,
+ vk::PipelineBindPoint::COMPUTE,
+ pipeline.pipeline_layout,
+ 0,
+ &[descriptor_set.descriptor_set],
+ &[],
+ );
+ device.cmd_dispatch(self.cmd_buf, 256, 1, 1);
+ }
+
+ /// Insert a pipeline barrier for all memory accesses.
+ #[allow(unused)]
+ pub unsafe fn memory_barrier(&mut self) {
+ let device = &self.device.device;
+ device.cmd_pipeline_barrier(
+ self.cmd_buf,
+ vk::PipelineStageFlags::ALL_COMMANDS,
+ vk::PipelineStageFlags::ALL_COMMANDS,
+ vk::DependencyFlags::empty(),
+ &[vk::MemoryBarrier::builder()
+ .src_access_mask(vk::AccessFlags::MEMORY_WRITE)
+ .dst_access_mask(vk::AccessFlags::MEMORY_READ)
+ .build()],
+ &[],
+ &[],
+ );
+ }
+}
+
unsafe fn choose_compute_device(
instance: &Instance,
devices: &[vk::PhysicalDevice],
@@ 372,9 418,11 @@ fn find_memory_type(
}
fn convert_u32_vec(src: &[u8]) -> Vec<u32> {
- src.chunks(4).map(|chunk| {
- let mut buf = [0; 4];
- buf.copy_from_slice(chunk);
- u32::from_le_bytes(buf)
- }).collect()
+ src.chunks(4)
+ .map(|chunk| {
+ let mut buf = [0; 4];
+ buf.copy_from_slice(chunk);
+ u32::from_le_bytes(buf)
+ })
+ .collect()
}