diff --git a/src/decode.rs b/src/decode.rs index 5cdcaad..ba02565 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,6 +1,6 @@ use std::mem; -use crate::consts::{QOI_HEADER_SIZE, QOI_INDEX, QOI_MAGIC, QOI_PADDING}; +use crate::consts::{QOI_HEADER_SIZE, QOI_INDEX, QOI_PADDING}; use crate::error::{Error, Result}; use crate::header::Header; use crate::pixel::{Pixel, SupportedChannels}; @@ -182,6 +182,7 @@ pub fn qoi_decode_to_vec( ) -> Result<(Header, Vec)> { let data = data.as_ref(); let header = qoi_decode_header(data)?; + header.validate()?; let channels = channels.maybe_channels().unwrap_or(header.channels); match channels { 3 => Ok((header, qoi_decode_impl::<3>(data, header.n_pixels())?)), @@ -198,12 +199,5 @@ pub fn qoi_decode_header(data: impl AsRef<[u8]>) -> Result
{ } let mut bytes = [0_u8; QOI_HEADER_SIZE]; bytes.copy_from_slice(&data[..QOI_HEADER_SIZE]); - let header = Header::from_bytes(bytes); - if unlikely(header.magic != QOI_MAGIC) { - return Err(Error::InvalidMagic { magic: header.magic }); - } - if unlikely(header.height == 0 || header.width == 0) { - return Err(Error::EmptyImage { width: header.width, height: header.height }); - } - Ok(header) + Ok(Header::from_bytes(bytes)) } diff --git a/src/header.rs b/src/header.rs index c1cd526..6e464ca 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,5 +1,7 @@ use crate::colorspace::ColorSpace; use crate::consts::{QOI_HEADER_SIZE, QOI_MAGIC}; +use crate::error::{Error, Result}; +use crate::utils::unlikely; #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct Header { @@ -61,4 +63,16 @@ impl Header { pub const fn n_pixels(&self) -> usize { (self.width as usize).saturating_mul(self.height as usize) } + + #[inline] + pub fn validate(&self) -> Result<()> { + if unlikely(self.magic != QOI_MAGIC) { + return Err(Error::InvalidMagic { magic: self.magic }); + } else if unlikely(self.height == 0 || self.width == 0) { + return Err(Error::EmptyImage { width: self.width, height: self.height }); + } else if unlikely(self.channels < 3 || self.channels > 4) { + return Err(Error::InvalidChannels { channels: self.channels }); + } + Ok(()) + } }