Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do a bunch of clean ups #17

Merged
merged 15 commits into from
Feb 1, 2024
Merged
9 changes: 4 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ version = "0.1.0"
authors = ["Tobin C. Harding <me@tobin.cc>"]
license = "CC0-1.0"
repository = "https://github.com/tcharding/rust-psbt/"
description = "Partially signed Bitcoin Transaction, v0 and v1"
description = "Partially Signed Bitcoin Transaction, v0 and v2"
categories = ["cryptography::cryptocurrencies"]
keywords = [ "crypto", "bitcoin" ]
readme = "../README.md"
readme = "README.md"
edition = "2021"
rust-version = "1.56.1"
exclude = ["tests", "contrib"]
Expand All @@ -27,11 +27,10 @@ miniscript-std = ["std", "miniscript/std"]
miniscript-no-std = ["no-std", "miniscript/no-std"]

[dependencies]
bitcoin = { version = "0.31.0", default-features = false, features = [] }
bitcoin = { version = "0.31.0", default-features = false }

# Currenty miniscript only works in with "std" enabled.
# Do not use this feature, use "miniscript-std" or "miniscript-no-std" instead.
miniscript = { version = "11.0.0", default-features = false, optional = true }

# Do NOT use this as a feature! Use the `serde` feature instead.
actual-serde = { package = "serde", version = "1.0.103", default-features = false, features = [ "derive", "alloc" ], optional = true }
# There is no reason to use this dependency directly, it is activated by the "no-std" feature.
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ use std::io;
#[cfg(not(feature = "std"))]
use core2::io;

use crate::version::Version;

#[rustfmt::skip] // Keep pubic re-exports separate
#[doc(inline)]
pub use crate::{
sighash_type::PsbtSighashType,
sighash_type::{PsbtSighashType, InvalidSighashTypeError},
version::Version,
};

/// PSBT version 0 - the original PSBT version.
Expand Down
64 changes: 62 additions & 2 deletions src/sighash_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ impl std::error::Error for SighashTypeParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None }
}

// TODO: Remove this error after issue resolves.
// https://github.com/rust-bitcoin/rust-bitcoin/issues/2423
/// Integer is not a consensus valid sighash type.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum InvalidSighashTypeError {
/// TODO:
/// The real invalid sighash type error.
Bitcoin(sighash::InvalidSighashTypeError),
/// TODO:
/// Hack required because of non_exhaustive on the real error.
Invalid(u32),
}

Expand Down Expand Up @@ -149,3 +151,61 @@ impl std::error::Error for InvalidSighashTypeError {
impl From<sighash::InvalidSighashTypeError> for InvalidSighashTypeError {
fn from(e: sighash::InvalidSighashTypeError) -> Self { Self::Bitcoin(e) }
}

#[cfg(test)]
mod tests {
use core::str::FromStr;

use super::*;
use crate::sighash_type::InvalidSighashTypeError;

#[test]
fn psbt_sighash_type_ecdsa() {
for ecdsa in &[
EcdsaSighashType::All,
EcdsaSighashType::None,
EcdsaSighashType::Single,
EcdsaSighashType::AllPlusAnyoneCanPay,
EcdsaSighashType::NonePlusAnyoneCanPay,
EcdsaSighashType::SinglePlusAnyoneCanPay,
] {
let sighash = PsbtSighashType::from(*ecdsa);
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();
assert_eq!(back, sighash);
assert_eq!(back.ecdsa_hash_ty().unwrap(), *ecdsa);
}
}

#[test]
fn psbt_sighash_type_taproot() {
for tap in &[
TapSighashType::Default,
TapSighashType::All,
TapSighashType::None,
TapSighashType::Single,
TapSighashType::AllPlusAnyoneCanPay,
TapSighashType::NonePlusAnyoneCanPay,
TapSighashType::SinglePlusAnyoneCanPay,
] {
let sighash = PsbtSighashType::from(*tap);
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();
assert_eq!(back, sighash);
assert_eq!(back.taproot_hash_ty().unwrap(), *tap);
}
}

#[test]
fn psbt_sighash_type_notstd() {
let nonstd = 0xdddddddd;
let sighash = PsbtSighashType { inner: nonstd };
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();

assert_eq!(back, sighash);
// TODO: Add this assertion once we remove InvalidSighashTypeError
// assert_eq!(back.ecdsa_hash_ty(), Err(NonStandardSighashTypeError(nonstd)));
assert_eq!(back.taproot_hash_ty(), Err(InvalidSighashTypeError::Invalid(nonstd)));
}
}
15 changes: 7 additions & 8 deletions src/v0/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ use crate::v0::Psbt;
///
/// This error is returned when deserializing a complete PSBT, not for deserializing parts
/// of it or individual data types.
// TODO: This can change to `serialize::Error` if we rename `serialize::Error` to `serialize::Error`.
#[derive(Debug)]
#[non_exhaustive]
pub enum DeserializePsbtError {
pub enum DeserializeError {
/// Invalid magic bytes, expected the ASCII for "psbt" serialized in most significant byte order.
// TODO: Consider adding the invalid bytes.
InvalidMagic,
Expand All @@ -36,28 +35,28 @@ pub enum DeserializePsbtError {
UnsignedTxChecks(UnsignedTxChecksError),
}

impl fmt::Display for DeserializePsbtError {
impl fmt::Display for DeserializeError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { todo!() }
}

#[cfg(feature = "std")]
impl std::error::Error for DeserializePsbtError {
impl std::error::Error for DeserializeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { todo!() }
}

impl From<global::DecodeError> for DeserializePsbtError {
impl From<global::DecodeError> for DeserializeError {
fn from(e: global::DecodeError) -> Self { Self::DecodeGlobal(e) }
}

impl From<input::DecodeError> for DeserializePsbtError {
impl From<input::DecodeError> for DeserializeError {
fn from(e: input::DecodeError) -> Self { Self::DecodeInput(e) }
}

impl From<output::DecodeError> for DeserializePsbtError {
impl From<output::DecodeError> for DeserializeError {
fn from(e: output::DecodeError) -> Self { Self::DecodeOutput(e) }
}

impl From<UnsignedTxChecksError> for DeserializePsbtError {
impl From<UnsignedTxChecksError> for DeserializeError {
fn from(e: UnsignedTxChecksError) -> Self { Self::UnsignedTxChecks(e) }
}

Expand Down
1 change: 1 addition & 0 deletions src/v0/miniscript/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::fmt;

use crate::bitcoin::{sighash, ScriptBuf};
use crate::miniscript::{self, descriptor, interpreter};
use crate::prelude::*;
#[cfg(doc)]
use crate::v0::Psbt;

Expand Down
15 changes: 7 additions & 8 deletions src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::v0::map::{global, Map};

#[rustfmt::skip] // Keep pubic re-exports separate
pub use self::{
error::{IndexOutOfBoundsError, SignerChecksError, SignError, UnsignedTxChecksError, DeserializePsbtError},
error::{IndexOutOfBoundsError, SignerChecksError, SignError, UnsignedTxChecksError, DeserializeError},
map::{Input, Output, Global},
};

Expand Down Expand Up @@ -84,20 +84,19 @@ impl Psbt {
buf
}

// TODO: Change this to use DeserializePsbtError (although that name is shit) same as v2.
/// Deserialize a value from raw binary data.
pub fn deserialize(bytes: &[u8]) -> Result<Self, DeserializePsbtError> {
pub fn deserialize(bytes: &[u8]) -> Result<Self, DeserializeError> {
const MAGIC_BYTES: &[u8] = b"psbt";
if bytes.get(0..MAGIC_BYTES.len()) != Some(MAGIC_BYTES) {
return Err(DeserializePsbtError::InvalidMagic);
return Err(DeserializeError::InvalidMagic);
}

const PSBT_SERPARATOR: u8 = 0xff_u8;
if bytes.get(MAGIC_BYTES.len()) != Some(&PSBT_SERPARATOR) {
return Err(DeserializePsbtError::InvalidSeparator);
return Err(DeserializeError::InvalidSeparator);
}

let mut d = bytes.get(5..).ok_or(DeserializePsbtError::NoMorePairs)?;
let mut d = bytes.get(5..).ok_or(DeserializeError::NoMorePairs)?;

let global = Global::decode(&mut d)?;
global.unsigned_tx_checks()?;
Expand Down Expand Up @@ -565,7 +564,7 @@ mod display_from_str {
#[non_exhaustive]
pub enum PsbtParseError {
/// Error in internal PSBT data structure.
PsbtEncoding(DeserializePsbtError),
PsbtEncoding(DeserializeError),
/// Error in PSBT Base64 encoding.
Base64Encoding(bitcoin::base64::DecodeError),
}
Expand Down Expand Up @@ -808,7 +807,7 @@ mod tests {
use crate::{io, raw, V0};

#[track_caller]
pub fn hex_psbt(s: &str) -> Result<Psbt, DeserializePsbtError> {
pub fn hex_psbt(s: &str) -> Result<Psbt, DeserializeError> {
let r: Result<Vec<u8>, bitcoin::hex::HexToBytesError> = Vec::from_hex(s);
match r {
Err(_e) => panic!("unable to parse hex string {}", s),
Expand Down
13 changes: 6 additions & 7 deletions src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ use crate::v2::map::{global, input, output};
///
/// This error is returned when deserializing a complete PSBT, not for deserializing parts
/// of it or individual data types.
// TODO: This can change to `serialize::Error` if we rename `serialize::Error` to `serialize::Error`.
#[derive(Debug)]
#[non_exhaustive]
pub enum DeserializePsbtError {
pub enum DeserializeError {
/// Invalid magic bytes, expected the ASCII for "psbt" serialized in most significant byte order.
// TODO: Consider adding the invalid bytes.
InvalidMagic,
Expand All @@ -34,24 +33,24 @@ pub enum DeserializePsbtError {
DecodeOutput(output::DecodeError),
}

impl fmt::Display for DeserializePsbtError {
impl fmt::Display for DeserializeError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { todo!() }
}

#[cfg(feature = "std")]
impl std::error::Error for DeserializePsbtError {
impl std::error::Error for DeserializeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { todo!() }
}

impl From<global::DecodeError> for DeserializePsbtError {
impl From<global::DecodeError> for DeserializeError {
fn from(e: global::DecodeError) -> Self { Self::DecodeGlobal(e) }
}

impl From<input::DecodeError> for DeserializePsbtError {
impl From<input::DecodeError> for DeserializeError {
fn from(e: input::DecodeError) -> Self { Self::DecodeInput(e) }
}

impl From<output::DecodeError> for DeserializePsbtError {
impl From<output::DecodeError> for DeserializeError {
fn from(e: output::DecodeError) -> Self { Self::DecodeOutput(e) }
}

Expand Down
51 changes: 40 additions & 11 deletions src/v2/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
//!
//! It is only possible to extract a transaction from a PSBT _after_ it has been finalized. However
//! the Extractor role may be fulfilled by a separate entity to the Finalizer hence this is a
//! separate module and does not require `rust-miniscript`.
//! separate module and does not require the "miniscript" feature be enabled.
//!
//! [BIP-174]: <https://github.com/bitcoin/bips/blob/master/bip-0174.mediawiki>

use core::fmt;

use bitcoin::{FeeRate, Transaction};
use bitcoin::{FeeRate, Transaction, Txid};

use crate::error::{write_err, FeeError};
use crate::v2::{DetermineLockTimeError, Psbt};
Expand All @@ -30,13 +30,19 @@ impl Extractor {
/// Creates an `Extractor`.
///
/// An extractor can only accept a PSBT that has been finalized.
pub fn new(psbt: Psbt) -> Result<Self, PsbtNotFinalizedError> {
pub fn new(psbt: Psbt) -> Result<Self, Error> {
if psbt.inputs.iter().any(|input| !input.is_finalized()) {
return Err(PsbtNotFinalizedError);
return Err(Error::PsbtNotFinalized);
}
let _ = psbt.determine_lock_time()?;

Ok(Self(psbt))
}

/// Returns this PSBT's unique identification.
pub fn id(&self) -> Txid {
self.0.id().expect("Extractor guarantees lock time can be determined")
}
}

impl Extractor {
Expand Down Expand Up @@ -118,19 +124,42 @@ impl Extractor {
}
}

/// Attempted to extract tx from an unfinalized PSBT.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct PsbtNotFinalizedError;
/// Error constructing a [`Finalizer`].
#[derive(Debug)]
pub enum Error {
/// Attempted to extract tx from an unfinalized PSBT.
PsbtNotFinalized,
/// Finalizer must be able to determine the lock time.
DetermineLockTime(DetermineLockTimeError),
}

impl fmt::Display for PsbtNotFinalizedError {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "attempted to extract tx from an unfinalized PSBT")
use Error::*;

match *self {
PsbtNotFinalized => write!(f, "attempted to extract tx from an unfinalized PSBT"),
DetermineLockTime(ref e) =>
write_err!(f, "extractor must be able to determine the lock time"; e),
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for PsbtNotFinalizedError {}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use Error::*;

match *self {
DetermineLockTime(ref e) => Some(e),
PsbtNotFinalized => None,
}
}
}

impl From<DetermineLockTimeError> for Error {
fn from(e: DetermineLockTimeError) -> Self { Self::DetermineLockTime(e) }
}

/// Error caused by fee calculation when extracting a [`Transaction`] from a PSBT.
#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down
10 changes: 4 additions & 6 deletions src/v2/map/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use core::convert::TryFrom;
use core::fmt;

use bitcoin::bip32::{ChildNumber, DerivationPath, Fingerprint, KeySource, Xpub};
use bitcoin::consensus::encode::MAX_VEC_SIZE;
use bitcoin::consensus::{encode as consensus, Decodable};
use bitcoin::locktime::absolute;
use bitcoin::{bip32, transaction, Transaction, VarInt};
Expand Down Expand Up @@ -127,16 +126,15 @@ impl Global {
self.tx_modifiable_flags & OUTPUTS_MODIFIABLE > 0
}

// TODO: Use this function?
// TODO: Investigate if we should be using this function?
#[allow(dead_code)]
pub(crate) fn has_sighash_single(&self) -> bool {
self.tx_modifiable_flags & SIGHASH_SINGLE > 0
}

pub(crate) fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, DecodeError> {
// TODO(tobin): Work out why do we do this take, its not done in input or output modules.
let mut r = r.take(MAX_VEC_SIZE as u64);

// TODO: Consider adding protection against memory exhaustion here by defining a maximum
// PBST size and using `take` as we do in rust-bitcoin consensus decoding.
let mut version: Option<Version> = None;
let mut tx_version: Option<transaction::Version> = None;
let mut fallback_lock_time: Option<absolute::LockTime> = None;
Expand Down Expand Up @@ -312,7 +310,7 @@ impl Global {
};

loop {
match raw::Pair::decode(&mut r) {
match raw::Pair::decode(r) {
Ok(pair) => insert_pair(pair)?,
Err(serialize::Error::NoMorePairs) => break,
Err(e) => return Err(DecodeError::DeserPair(e)),
Expand Down
Loading