diff --git a/src/ebml.rs b/src/ebml.rs index 23d24d1..793ad13 100644 --- a/src/ebml.rs +++ b/src/ebml.rs @@ -1,7 +1,7 @@ use bytes::{BigEndian, ByteOrder, BufMut}; use std::error::Error as ErrorTrait; use std::fmt::{Display, Formatter, Result as FmtResult}; -use std::io::{Cursor, Error as IoError, ErrorKind, Result as IoResult, Write}; +use std::io::{Cursor, Error as IoError, ErrorKind, Result as IoResult, Write, Seek, SeekFrom}; pub const EBML_HEAD_ID: u64 = 0x0A45DFA3; pub const VOID_ID: u64 = 0x6C; @@ -129,7 +129,7 @@ const SMALL_FLAG: u64 = 0x80; const EIGHT_FLAG: u64 = 0x01 << (8*7); const EIGHT_MAX: u64 = EIGHT_FLAG - 2; -/// Tries to write an EBML varint +/// Tries to write an EBML varint using minimal space pub fn encode_varint<T: Write>(varint: Varint, output: &mut T) -> IoResult<usize> { let (size, number) = match varint { Varint::Unknown => (1, 0xFF), @@ -157,6 +157,40 @@ pub fn encode_varint<T: Write>(varint: Varint, output: &mut T) -> IoResult<usize return output.write_all(&buffer.get_ref()[..size]).map(|()| size); } +const FOUR_FLAG: u64 = 0x10 << (8*3); +const FOUR_MAX: u64 = FOUR_FLAG - 2; + +// tries to write a varint with a fixed 4-byte representation +pub fn encode_varint_4<T: Write>(varint: Varint, output: &mut T) -> IoResult<usize> { + let number = match varint { + Varint::Unknown => FOUR_FLAG | (FOUR_FLAG - 1), + Varint::Value(too_big) if too_big > FOUR_MAX => { + return Err(IoError::new(ErrorKind::InvalidInput, WriteError::OutOfRange)) + }, + Varint::Value(value) => FOUR_FLAG | value + }; + + let mut buffer = Cursor::new([0; 4]); + buffer.put_u32::<BigEndian>(number as u32); + + return output.write_all(&buffer.get_ref()[..]).map(|()| 4); +} + +pub fn encode_element<T: Write + Seek, F: Fn(&mut T) -> IoResult<X>, X>(tag: u64, output: &mut T, content: F) -> IoResult<()> { + encode_varint(Varint::Value(tag), output)?; + encode_varint_4(Varint::Unknown, output)?; + + let start = output.seek(SeekFrom::Current(0))?; + content(output)?; + let end = output.seek(SeekFrom::Current(0))?; + + output.seek(SeekFrom::Start(start - 4))?; + encode_varint_4(Varint::Value(end - start), output)?; + output.seek(SeekFrom::Start(end))?; + + Ok(()) +} + pub fn encode_tag_header<T: Write>(tag: u64, size: Varint, output: &mut T) -> IoResult<usize> { let id_size = encode_varint(Varint::Value(tag), output)?; let size_size = encode_varint(size, output)?;