Skip to content

Commit f580454

Browse files
committed
Add proper error handling to new InputPair functions
1 parent 6036d4f commit f580454

File tree

5 files changed

+142
-42
lines changed

5 files changed

+142
-42
lines changed

payjoin/src/psbt.rs

+86-18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use std::collections::BTreeMap;
44
use std::fmt;
55

6+
use bitcoin::address::FromScriptError;
67
use bitcoin::blockdata::script::Instruction;
78
use bitcoin::psbt::Psbt;
89
use bitcoin::transaction::InputWeightPrediction;
@@ -174,42 +175,44 @@ impl<'a> InputPair<'a> {
174175
}
175176
}
176177

177-
pub fn address_type(&self) -> AddressType {
178-
let txo = self.previous_txout().expect("PrevTxoutError");
178+
pub fn address_type(&self) -> Result<AddressType, AddressTypeError> {
179+
let txo = self.previous_txout()?;
179180
// HACK: Network doesn't matter for our use case of only getting the address type
180181
// but is required in the `from_script` interface. Hardcoded to mainnet.
181-
Address::from_script(&txo.script_pubkey, Network::Bitcoin)
182-
.expect("Unrecognized script")
182+
Address::from_script(&txo.script_pubkey, Network::Bitcoin)?
183183
.address_type()
184-
.expect("UnknownAddressType")
184+
.ok_or(AddressTypeError::UnknownAddressType)
185185
}
186186

187-
pub fn expected_input_weight(&self) -> Weight {
187+
pub fn expected_input_weight(&self) -> Result<Weight, InputWeightError> {
188188
use bitcoin::AddressType::*;
189189

190190
// Get the input weight prediction corresponding to spending an output of this address type
191-
let iwp = match self.address_type() {
192-
P2pkh => InputWeightPrediction::P2PKH_COMPRESSED_MAX,
191+
let iwp = match self.address_type()? {
192+
P2pkh => Ok(InputWeightPrediction::P2PKH_COMPRESSED_MAX),
193193
P2sh =>
194194
match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref()))
195195
{
196-
Some(script) if script.is_witness_program() && script.is_p2wpkh() =>
196+
// Nested segwit p2wpkh.
197197
// input script: 0x160014{20-byte-key-hash} = 23 bytes
198198
// witness: <signature> <pubkey> = 72, 33 bytes
199199
// https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh
200-
InputWeightPrediction::new(23, &[72, 33]),
201-
Some(_) => unimplemented!(),
202-
None => panic!("Input not finalized!"),
200+
Some(script) if script.is_witness_program() && script.is_p2wpkh() =>
201+
Ok(InputWeightPrediction::new(23, &[72, 33])),
202+
// Other script or witness program.
203+
Some(_) => Err(InputWeightError::NotSupported),
204+
// No redeem script provided. Cannot determine the script type.
205+
None => Err(InputWeightError::NotFinalized),
203206
},
204-
P2wpkh => InputWeightPrediction::P2WPKH_MAX,
205-
P2wsh => unimplemented!(),
206-
P2tr => InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH,
207-
_ => panic!("Unknown address type!"),
208-
};
207+
P2wpkh => Ok(InputWeightPrediction::P2WPKH_MAX),
208+
P2wsh => Err(InputWeightError::NotSupported),
209+
P2tr => Ok(InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH),
210+
_ => Err(AddressTypeError::UnknownAddressType.into()),
211+
}?;
209212

210213
// Lengths of txid, index and sequence: (32, 4, 4).
211214
let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4);
212-
input_weight
215+
Ok(input_weight)
213216
}
214217
}
215218

@@ -279,3 +282,68 @@ impl fmt::Display for PsbtInputsError {
279282
impl std::error::Error for PsbtInputsError {
280283
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) }
281284
}
285+
286+
#[derive(Debug)]
287+
pub(crate) enum AddressTypeError {
288+
PrevTxOut(PrevTxOutError),
289+
InvalidScript(FromScriptError),
290+
UnknownAddressType,
291+
}
292+
293+
impl fmt::Display for AddressTypeError {
294+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
295+
match self {
296+
AddressTypeError::PrevTxOut(_) => write!(f, "invalid previous transaction output"),
297+
AddressTypeError::InvalidScript(_) => write!(f, "invalid script"),
298+
AddressTypeError::UnknownAddressType => write!(f, "unknown address type"),
299+
}
300+
}
301+
}
302+
303+
impl std::error::Error for AddressTypeError {
304+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
305+
match self {
306+
AddressTypeError::PrevTxOut(error) => Some(error),
307+
AddressTypeError::InvalidScript(error) => Some(error),
308+
AddressTypeError::UnknownAddressType => None,
309+
}
310+
}
311+
}
312+
313+
impl From<PrevTxOutError> for AddressTypeError {
314+
fn from(value: PrevTxOutError) -> Self { AddressTypeError::PrevTxOut(value) }
315+
}
316+
317+
impl From<FromScriptError> for AddressTypeError {
318+
fn from(value: FromScriptError) -> Self { AddressTypeError::InvalidScript(value) }
319+
}
320+
321+
#[derive(Debug)]
322+
pub(crate) enum InputWeightError {
323+
AddressType(AddressTypeError),
324+
NotFinalized,
325+
NotSupported,
326+
}
327+
328+
impl fmt::Display for InputWeightError {
329+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
330+
match self {
331+
InputWeightError::AddressType(_) => write!(f, "invalid address type"),
332+
InputWeightError::NotFinalized => write!(f, "input not finalized"),
333+
InputWeightError::NotSupported => write!(f, "weight prediction not supported"),
334+
}
335+
}
336+
}
337+
338+
impl std::error::Error for InputWeightError {
339+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
340+
match self {
341+
InputWeightError::AddressType(error) => Some(error),
342+
InputWeightError::NotFinalized => None,
343+
InputWeightError::NotSupported => None,
344+
}
345+
}
346+
}
347+
impl From<AddressTypeError> for InputWeightError {
348+
fn from(value: AddressTypeError) -> Self { InputWeightError::AddressType(value) }
349+
}

payjoin/src/receive/error.rs

+10
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ pub(crate) enum InternalRequestError {
7474
InputOwned(bitcoin::ScriptBuf),
7575
/// The original psbt has mixed input address types that could harm privacy
7676
MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType),
77+
/// The address type could not be determined
78+
AddressType(crate::psbt::AddressTypeError),
79+
/// The expected input weight cannot be determined
80+
InputWeight(crate::psbt::InputWeightError),
7781
/// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec
7882
/// look out for these to prevent probing attacks.
7983
InputSeen(bitcoin::OutPoint),
@@ -153,6 +157,10 @@ impl fmt::Display for RequestError {
153157
"original-psbt-rejected",
154158
&format!("Mixed input scripts: {}; {}.", type_a, type_b),
155159
),
160+
InternalRequestError::AddressType(e) =>
161+
write_error(f, "original-psbt-rejected", &format!("AddressType Error: {}", e)),
162+
InternalRequestError::InputWeight(e) =>
163+
write_error(f, "original-psbt-rejected", &format!("InputWeight Error: {}", e)),
156164
InternalRequestError::InputSeen(_) =>
157165
write_error(f, "original-psbt-rejected", "The receiver rejected the original PSBT."),
158166
#[cfg(feature = "v2")]
@@ -192,6 +200,8 @@ impl std::error::Error for RequestError {
192200
InternalRequestError::SenderParams(e) => Some(e),
193201
InternalRequestError::InconsistentPsbt(e) => Some(e),
194202
InternalRequestError::PrevTxOut(e) => Some(e),
203+
InternalRequestError::AddressType(e) => Some(e),
204+
InternalRequestError::InputWeight(e) => Some(e),
195205
#[cfg(feature = "v2")]
196206
InternalRequestError::ParsePsbt(e) => Some(e),
197207
#[cfg(feature = "v2")]

payjoin/src/receive/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,10 @@ impl MaybeMixedInputScripts {
228228
let input_scripts = self
229229
.psbt
230230
.input_pairs()
231-
.scan(&mut err, |err, input| match Ok(input.address_type()) {
231+
.scan(&mut err, |err, input| match input.address_type() {
232232
Ok(address_type) => Some(address_type),
233233
Err(e) => {
234-
**err = Err(RequestError::from(InternalRequestError::PrevTxOut(e)));
234+
**err = Err(RequestError::from(InternalRequestError::AddressType(e)));
235235
None
236236
}
237237
})
@@ -750,7 +750,8 @@ impl ProvisionalProposal {
750750
// Calculate the additional weight contribution
751751
let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len();
752752
log::trace!("input_count : {}", input_count);
753-
let weight_per_input = input_pair.expected_input_weight();
753+
let weight_per_input =
754+
input_pair.expected_input_weight().map_err(InternalRequestError::InputWeight)?;
754755
log::trace!("weight_per_input : {}", weight_per_input);
755756
let contribution_weight = weight_per_input * input_count as u64;
756757
log::trace!("contribution_weight: {}", contribution_weight);

payjoin/src/send/error.rs

+15-6
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub struct ValidationError {
1717
pub(crate) enum InternalValidationError {
1818
Parse,
1919
Io(std::io::Error),
20-
InvalidProposedInput(crate::psbt::PrevTxOutError),
20+
InvalidAddressType(crate::psbt::AddressTypeError),
2121
VersionsDontMatch {
2222
proposed: Version,
2323
original: Version,
@@ -66,14 +66,20 @@ impl From<InternalValidationError> for ValidationError {
6666
fn from(value: InternalValidationError) -> Self { ValidationError { internal: value } }
6767
}
6868

69+
impl From<crate::psbt::AddressTypeError> for InternalValidationError {
70+
fn from(value: crate::psbt::AddressTypeError) -> Self {
71+
InternalValidationError::InvalidAddressType(value)
72+
}
73+
}
74+
6975
impl fmt::Display for ValidationError {
7076
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
7177
use InternalValidationError::*;
7278

7379
match &self.internal {
7480
Parse => write!(f, "couldn't decode as PSBT or JSON",),
7581
Io(e) => write!(f, "couldn't read PSBT: {}", e),
76-
InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e),
82+
InvalidAddressType(e) => write!(f, "invalid input address type: {}", e),
7783
VersionsDontMatch { proposed, original, } => write!(f, "proposed transaction version {} doesn't match the original {}", proposed, original),
7884
LockTimesDontMatch { proposed, original, } => write!(f, "proposed transaction lock time {} doesn't match the original {}", proposed, original),
7985
SenderTxinSequenceChanged { proposed, original, } => write!(f, "proposed transaction sequence number {} doesn't match the original {}", proposed, original),
@@ -115,7 +121,7 @@ impl std::error::Error for ValidationError {
115121
match &self.internal {
116122
Parse => None,
117123
Io(error) => Some(error),
118-
InvalidProposedInput(error) => Some(error),
124+
InvalidAddressType(error) => Some(error),
119125
VersionsDontMatch { proposed: _, original: _ } => None,
120126
LockTimesDontMatch { proposed: _, original: _ } => None,
121127
SenderTxinSequenceChanged { proposed: _, original: _ } => None,
@@ -172,7 +178,8 @@ pub(crate) enum InternalCreateRequestError {
172178
ChangeIndexOutOfBounds,
173179
ChangeIndexPointsAtPayee,
174180
Url(url::ParseError),
175-
PrevTxOut(crate::psbt::PrevTxOutError),
181+
AddressType(crate::psbt::AddressTypeError),
182+
InputWeight(crate::psbt::InputWeightError),
176183
#[cfg(feature = "v2")]
177184
Hpke(crate::v2::HpkeError),
178185
#[cfg(feature = "v2")]
@@ -202,7 +209,8 @@ impl fmt::Display for CreateRequestError {
202209
ChangeIndexOutOfBounds => write!(f, "fee output index is points out of bounds"),
203210
ChangeIndexPointsAtPayee => write!(f, "fee output index is points at output belonging to the payee"),
204211
Url(e) => write!(f, "cannot parse url: {:#?}", e),
205-
PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e),
212+
AddressType(e) => write!(f, "can not determine input address type: {}", e),
213+
InputWeight(e) => write!(f, "can not determine expected input weight: {}", e),
206214
#[cfg(feature = "v2")]
207215
Hpke(e) => write!(f, "v2 error: {}", e),
208216
#[cfg(feature = "v2")]
@@ -234,7 +242,8 @@ impl std::error::Error for CreateRequestError {
234242
ChangeIndexOutOfBounds => None,
235243
ChangeIndexPointsAtPayee => None,
236244
Url(error) => Some(error),
237-
PrevTxOut(error) => Some(error),
245+
AddressType(error) => Some(error),
246+
InputWeight(error) => Some(error),
238247
#[cfg(feature = "v2")]
239248
Hpke(error) => Some(error),
240249
#[cfg(feature = "v2")]

payjoin/src/send/mod.rs

+27-15
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,22 @@ impl<'a> RequestBuilder<'a> {
133133

134134
let first_input_pair =
135135
input_pairs.first().ok_or(InternalCreateRequestError::NoInputs)?;
136-
// use cheapest default if mixed input types
137-
let mut input_weight =
138-
bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight()
139-
// Lengths of txid, index and sequence: (32, 4, 4).
140-
+ Weight::from_non_witness_data_size(32 + 4 + 4);
141-
// Check if all inputs are the same type
142-
if input_pairs
136+
let input_weight = if input_pairs
143137
.iter()
144-
.all(|input_pair| input_pair.address_type() == first_input_pair.address_type())
138+
.try_fold(true, |_, input_pair| -> Result<bool, crate::psbt::AddressTypeError> {
139+
Ok(input_pair.address_type()? == first_input_pair.address_type()?)
140+
})
141+
.map_err(InternalCreateRequestError::AddressType)?
145142
{
146-
input_weight = first_input_pair.expected_input_weight();
147-
}
143+
first_input_pair
144+
.expected_input_weight()
145+
.map_err(InternalCreateRequestError::InputWeight)?
146+
} else {
147+
// use cheapest default if mixed input types
148+
bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight()
149+
// Lengths of txid, index and sequence: (32, 4, 4).
150+
+ Weight::from_non_witness_data_size(32 + 4 + 4)
151+
};
148152

149153
let recommended_additional_fee = min_fee_rate * input_weight;
150154
if fee_available < recommended_additional_fee {
@@ -230,7 +234,10 @@ impl<'a> RequestBuilder<'a> {
230234
let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?;
231235

232236
let sequence = zeroth_input.txin.sequence;
233-
let input_type = zeroth_input.address_type().to_string();
237+
let input_type = zeroth_input
238+
.address_type()
239+
.map_err(InternalCreateRequestError::AddressType)?
240+
.to_string();
234241

235242
#[cfg(feature = "v2")]
236243
let e = {
@@ -620,12 +627,17 @@ impl ContextV1 {
620627
ensure!(contributed_fee <= proposed_fee - original_fee, PayeeTookContributedFee);
621628
let original_weight = self.original_psbt.clone().extract_tx_unchecked_fee_rate().weight();
622629
let original_fee_rate = original_fee / original_weight;
623-
// TODO: Refactor this to be support mixed input types, preferably share method with
624-
// `ProvisionalProposal::additional_input_weight()`
630+
// TODO: This should support mixed input types
625631
ensure!(
626632
contributed_fee
627633
<= original_fee_rate
628-
* self.original_psbt.input_pairs().next().unwrap().expected_input_weight()
634+
* self
635+
.original_psbt
636+
.input_pairs()
637+
.next()
638+
.expect("This shouldn't happen. Failed to get an original input.")
639+
.expected_input_weight()
640+
.expect("This shouldn't happen. Weight should have been calculated successfully before.")
629641
* (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64,
630642
FeeContributionPaysOutputSizeIncrease
631643
);
@@ -697,7 +709,7 @@ impl ContextV1 {
697709
ReceiverTxinMissingUtxoInfo
698710
);
699711
ensure!(proposed.txin.sequence == self.sequence, MixedSequence);
700-
check_eq!(proposed.address_type(), self.input_type, MixedInputTypes);
712+
check_eq!(proposed.address_type()?, self.input_type, MixedInputTypes);
701713
}
702714
}
703715
}

0 commit comments

Comments
 (0)