Refactor decoder so it now uses qoi_decode_header

This commit is contained in:
Ivan Smirnov 2021-12-01 17:01:41 +00:00
parent d9507911f7
commit fc41914a48

View file

@ -31,28 +31,16 @@ impl ReadBuf {
} }
} }
pub fn qoi_decode_impl<const N: usize>(data: &[u8]) -> Result<(Header, Vec<u8>)> pub fn qoi_decode_impl<const N: usize>(data: &[u8], n_pixels: usize) -> Result<Vec<u8>>
where where
Pixel<N>: SupportedChannels, Pixel<N>: SupportedChannels,
{ {
if data.len() < QOI_HEADER_SIZE + QOI_PADDING { if unlikely(data.len() < QOI_HEADER_SIZE + QOI_PADDING) {
return Err(Error::InputBufferTooSmall { return Err(Error::InputBufferTooSmall {
size: data.len(), size: data.len(),
required: QOI_HEADER_SIZE + QOI_PADDING, required: QOI_HEADER_SIZE + QOI_PADDING,
}); });
} }
let header = Header::from_bytes(unsafe {
// Safety: Header is a POD type and we have just checked that the data fits it
*(data.as_ptr() as *const _)
});
let n_pixels = (header.width as usize) * (header.height as usize);
if n_pixels == 0 {
return Err(Error::EmptyImage { width: header.width, height: header.height });
}
if header.magic != QOI_MAGIC {
return Err(Error::InvalidMagic { magic: header.magic });
}
let mut pixels = Vec::<Pixel<N>>::with_capacity(n_pixels); let mut pixels = Vec::<Pixel<N>>::with_capacity(n_pixels);
unsafe { unsafe {
@ -167,14 +155,37 @@ where
Vec::from_raw_parts(ptr as *mut _, n_pixels * N, n_pixels * N) Vec::from_raw_parts(ptr as *mut _, n_pixels * N, n_pixels * N)
}; };
Ok((header, bytes)) Ok(bytes)
} }
pub fn qoi_decode_to_vec(data: impl AsRef<[u8]>, channels: u8) -> Result<(Header, Vec<u8>)> { pub trait MaybeChannels {
fn maybe_channels(&self) -> Option<u8>;
}
impl MaybeChannels for u8 {
#[inline]
fn maybe_channels(&self) -> Option<u8> {
Some(*self)
}
}
impl MaybeChannels for Option<u8> {
#[inline]
fn maybe_channels(&self) -> Option<u8> {
*self
}
}
#[inline]
pub fn qoi_decode_to_vec(
data: impl AsRef<[u8]>, channels: impl MaybeChannels,
) -> Result<(Header, Vec<u8>)> {
let data = data.as_ref(); let data = data.as_ref();
let header = qoi_decode_header(data)?;
let channels = channels.maybe_channels().unwrap_or(header.channels);
match channels { match channels {
3 => qoi_decode_impl::<3>(data), 3 => Ok((header, qoi_decode_impl::<3>(data, header.n_pixels())?)),
4 => qoi_decode_impl::<4>(data), 4 => Ok((header, qoi_decode_impl::<4>(data, header.n_pixels())?)),
_ => Err(Error::InvalidChannels { channels }), _ => Err(Error::InvalidChannels { channels }),
} }
} }
@ -182,15 +193,17 @@ pub fn qoi_decode_to_vec(data: impl AsRef<[u8]>, channels: u8) -> Result<(Header
#[inline] #[inline]
pub fn qoi_decode_header(data: impl AsRef<[u8]>) -> Result<Header> { pub fn qoi_decode_header(data: impl AsRef<[u8]>) -> Result<Header> {
let data = data.as_ref(); let data = data.as_ref();
if data.len() < QOI_HEADER_SIZE { if unlikely(data.len() < QOI_HEADER_SIZE) {
return Err(Error::InputBufferTooSmall { size: data.len(), required: QOI_HEADER_SIZE }); return Err(Error::InputBufferTooSmall { size: data.len(), required: QOI_HEADER_SIZE });
} }
let header = unsafe { let mut bytes = [0_u8; QOI_HEADER_SIZE];
// Safety: we have just checked the length above bytes.copy_from_slice(&data[..QOI_HEADER_SIZE]);
Header::from_bytes(*(data.as_ptr() as *const _)) let header = Header::from_bytes(bytes);
}; if unlikely(header.magic != QOI_MAGIC) {
if header.magic != QOI_MAGIC {
return Err(Error::InvalidMagic { magic: header.magic }); return Err(Error::InvalidMagic { magic: header.magic });
} }
if unlikely(header.height == 0 || header.width == 0) {
return Err(Error::EmptyImage { width: header.width, height: header.height });
}
Ok(header) Ok(header)
} }