diff --git a/Cargo.toml b/Cargo.toml index 53ab13d1..fecf9a26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -176,7 +176,6 @@ hyper-util = { version = "0.1.10", features = [ "server-auto", "tokio", ] } -env_logger = "0.11.6" serde = { version = "1.0", features = ["derive"] } libflate = "2.0.0" zstd = { version = "0.13" } @@ -190,6 +189,10 @@ tower = { version = "0.5.2", default-features = false, features = ["limit"] } num_cpus = "1.0" libc = "0.2" +env_logger = "0.11.6" +tracing = "0.1" +tracing-subscriber = "0.3.19" + [lib] doctest = false @@ -309,7 +312,7 @@ required-features = ["native-roots", "webpki-roots"] [[example]] name = "websocket" path = "examples/websocket.rs" -required-features = ["websocket", "futures-util/std"] +required-features = ["websocket", "http2-tracing", "futures-util/std"] [[example]] name = "client_chain" @@ -358,3 +361,8 @@ required-features = ["full"] name = "request_with_interface" path = "examples/request_with_interface.rs" required-features = ["full"] + +[[example]] +name = "http2_websocket" +path = "examples/http2_websocket.rs" +required-features = ["websocket", "http2-tracing", "futures-util/std"] diff --git a/examples/base_url.rs b/examples/base_url.rs index bd0ec25e..22861027 100644 --- a/examples/base_url.rs +++ b/examples/base_url.rs @@ -2,7 +2,9 @@ use rquest::{Client, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("debug")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Edge131 let mut client = Client::builder() diff --git a/examples/http2_websocket.rs b/examples/http2_websocket.rs new file mode 100644 index 00000000..87ca603a --- /dev/null +++ b/examples/http2_websocket.rs @@ -0,0 +1,53 @@ +use futures_util::{SinkExt, StreamExt, TryStreamExt}; +use http::header; +use rquest::{Client, Impersonate, Message, RequestBuilder}; +use std::time::Duration; + +#[tokio::main] +async fn main() -> Result<(), rquest::Error> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); + + // Build a client to impersonate Firefox133 + let client = Client::builder() + .impersonate(Impersonate::Firefox133) + .danger_accept_invalid_certs(true) + .build()?; + + // Use the API you're already familiar with + let websocket = client + .websocket("wss://127.0.0.1:3000/ws") + .configure_request(configure_request) + .http2_only() + .send() + .await?; + + assert_eq!(websocket.version(), http::Version::HTTP_2); + + let (mut tx, mut rx) = websocket.into_websocket().await?.split(); + + tokio::spawn(async move { + for i in 1..11 { + tx.send(Message::Text(format!("Hello, World! #{i}"))) + .await + .unwrap(); + } + }); + + while let Some(message) = rx.try_next().await? { + match message { + Message::Text(text) => println!("received: {text}"), + _ => {} + } + } + + Ok(()) +} + +/// We can also set HTTP options here +fn configure_request(builder: RequestBuilder) -> RequestBuilder { + builder + .header(header::USER_AGENT, env!("CARGO_PKG_NAME")) + .timeout(Duration::from_secs(10)) +} diff --git a/examples/impersonate_builder.rs b/examples/impersonate_builder.rs index c72d4196..e29a0845 100644 --- a/examples/impersonate_builder.rs +++ b/examples/impersonate_builder.rs @@ -2,7 +2,9 @@ use rquest::{Client, Impersonate, ImpersonateOS}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Firefox128 let impersonate = Impersonate::builder() diff --git a/examples/impersonate_psk.rs b/examples/impersonate_psk.rs index c29a880c..602e0b28 100644 --- a/examples/impersonate_psk.rs +++ b/examples/impersonate_psk.rs @@ -2,7 +2,9 @@ use rquest::Impersonate; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Firefox133 let client = rquest::Client::builder() diff --git a/examples/impersonate_settings.rs b/examples/impersonate_settings.rs index dcc457e5..932e34eb 100644 --- a/examples/impersonate_settings.rs +++ b/examples/impersonate_settings.rs @@ -112,7 +112,9 @@ const HEADER_ORDER: &[HeaderName] = &[ #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // TLS settings let tls = TlsSettings::builder() diff --git a/examples/request_with_cookie_store.rs b/examples/request_with_cookie_store.rs index dafdf741..1fc0cdca 100644 --- a/examples/request_with_cookie_store.rs +++ b/examples/request_with_cookie_store.rs @@ -9,7 +9,9 @@ use url::Url; #[tokio::main] async fn main() -> Result<(), Box> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); let url = Url::parse("https://google.com/")?; diff --git a/examples/request_with_interface.rs b/examples/request_with_interface.rs index e9049827..32d67212 100644 --- a/examples/request_with_interface.rs +++ b/examples/request_with_interface.rs @@ -2,7 +2,9 @@ use rquest::{Client, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Firefox128 let client = Client::builder() diff --git a/examples/request_with_local_address.rs b/examples/request_with_local_address.rs index 15cb2ebf..f7cac09b 100644 --- a/examples/request_with_local_address.rs +++ b/examples/request_with_local_address.rs @@ -3,7 +3,9 @@ use std::net::IpAddr; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Safari18 let client = rquest::Client::builder() diff --git a/examples/request_with_proxy.rs b/examples/request_with_proxy.rs index 0d6426e3..3338cd44 100644 --- a/examples/request_with_proxy.rs +++ b/examples/request_with_proxy.rs @@ -2,7 +2,9 @@ use rquest::Impersonate; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Firefox133 let client = rquest::Client::builder() diff --git a/examples/request_with_redirect.rs b/examples/request_with_redirect.rs index 330cfd05..ff8b94e6 100644 --- a/examples/request_with_redirect.rs +++ b/examples/request_with_redirect.rs @@ -2,7 +2,9 @@ use rquest::{redirect::Policy, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Safari18 let client = rquest::Client::builder() diff --git a/examples/request_with_version.rs b/examples/request_with_version.rs index b6d5da32..5af3e5f2 100644 --- a/examples/request_with_version.rs +++ b/examples/request_with_version.rs @@ -3,7 +3,9 @@ use rquest::{redirect::Policy, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Safari18 let client = rquest::Client::builder() diff --git a/examples/set_proxies.rs b/examples/set_proxies.rs index 4e4882fc..ba214165 100644 --- a/examples/set_proxies.rs +++ b/examples/set_proxies.rs @@ -2,7 +2,9 @@ use rquest::{Client, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Chrome130 let mut client = Client::builder() diff --git a/examples/set_redirect.rs b/examples/set_redirect.rs index 84ea9675..20f60bbe 100644 --- a/examples/set_redirect.rs +++ b/examples/set_redirect.rs @@ -2,7 +2,9 @@ use rquest::{redirect::Policy, Impersonate}; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); // Build a client to impersonate Safari18 let mut client = rquest::Client::builder() diff --git a/examples/set_root_cert_store.rs b/examples/set_root_cert_store.rs index 4a1f3538..0c74cc87 100644 --- a/examples/set_root_cert_store.rs +++ b/examples/set_root_cert_store.rs @@ -4,7 +4,9 @@ use std::sync::LazyLock; #[tokio::main] async fn main() -> Result<(), rquest::Error> { - env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); use_static_root_certs().await?; use_dynamic_root_certs().await?; Ok(()) diff --git a/examples/websocket.rs b/examples/websocket.rs index 0cba9e77..91deb4e3 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -1,26 +1,30 @@ -use std::time::Duration; - use futures_util::{SinkExt, StreamExt, TryStreamExt}; use http::header; use rquest::{Client, Impersonate, Message, RequestBuilder}; +use std::time::Duration; #[tokio::main] async fn main() -> Result<(), rquest::Error> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); + // Build a client to impersonate Firefox133 let client = Client::builder() .impersonate(Impersonate::Firefox133) + .danger_accept_invalid_certs(true) .build()?; // Use the API you're already familiar with let websocket = client - .websocket("wss://echo.websocket.org") + .websocket("wss://127.0.0.1:3000/ws") .configure_request(configure_request) .send() - .await? - .into_websocket() .await?; - let (mut tx, mut rx) = websocket.split(); + assert_eq!(websocket.version(), http::Version::HTTP_11); + + let (mut tx, mut rx) = websocket.into_websocket().await?.split(); tokio::spawn(async move { for i in 1..11 { @@ -43,7 +47,6 @@ async fn main() -> Result<(), rquest::Error> { /// We can also set HTTP options here fn configure_request(builder: RequestBuilder) -> RequestBuilder { builder - .proxy("http://127.0.0.1:6152") .header(header::USER_AGENT, env!("CARGO_PKG_NAME")) .timeout(Duration::from_secs(10)) } diff --git a/src/client/http.rs b/src/client/http.rs index 36b26f5d..d1612dfa 100644 --- a/src/client/http.rs +++ b/src/client/http.rs @@ -1298,6 +1298,7 @@ impl Client { redirect, _cookie_store, network_scheme, + protocal, ) = req.pieces(); if url.scheme() != "http" && url.scheme() != "https" { @@ -1357,12 +1358,13 @@ impl Client { let in_flight = { let res = InnerRequest::builder() - .network_scheme(network_scheme.clone()) .uri(uri) .method(method.clone()) .version(version) .headers(headers.clone()) .headers_order(self.inner.headers_order.as_deref()) + .network_scheme(network_scheme.clone()) + .extension(protocal) .body(body); match res { @@ -1944,12 +1946,12 @@ impl PendingRequest { *self.as_mut().in_flight().get_mut() = { let res = InnerRequest::builder() - .network_scheme(self.network_scheme.clone()) .uri(uri) .method(self.method.clone()) .version(self.version) .headers(self.headers.clone()) .headers_order(self.client.headers_order.as_deref()) + .network_scheme(self.network_scheme.clone()) .body(body); if let Ok(req) = res { @@ -2203,12 +2205,12 @@ impl Future for PendingRequest { *self.as_mut().in_flight().get_mut() = { let req = InnerRequest::builder() - .network_scheme(self.network_scheme.clone()) .uri(uri) .method(self.method.clone()) .version(self.version) .headers(headers.clone()) .headers_order(self.client.headers_order.as_deref()) + .network_scheme(self.network_scheme.clone()) .body(body)?; std::mem::swap(self.as_mut().headers(), &mut headers); diff --git a/src/client/request.rs b/src/client/request.rs index c9a14d49..f7f560fe 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -32,6 +32,7 @@ type PiecesWithCookieStore = ( Option, (), NetworkScheme, + Option, ); #[cfg(feature = "cookies")] @@ -46,6 +47,7 @@ type PiecesWithCookieStore = ( Option, Option>, NetworkScheme, + Option, ); /// A request which can be executed with `Client::execute()`. @@ -61,6 +63,7 @@ pub struct Request { #[cfg(feature = "cookies")] cookie_store: Option>, network_scheme: NetworkSchemeBuilder, + protocol: Option, } /// A builder to construct the properties of a `Request`. @@ -88,6 +91,7 @@ impl Request { #[cfg(feature = "cookies")] cookie_store: None, network_scheme: NetworkScheme::builder(), + protocol: None, } } @@ -194,6 +198,12 @@ impl Request { &mut self.version } + /// Set the mutable reference to the protocol. + #[inline] + pub fn protocol_mut(&mut self) -> &mut Option { + &mut self.protocol + } + /// Attempt to clone the request. /// /// `None` is returned if the request can not be cloned, i.e. if the body is a stream. @@ -232,6 +242,7 @@ impl Request { #[cfg(not(feature = "cookies"))] (), self.network_scheme.build(), + self.protocol, ) } } @@ -824,6 +835,7 @@ where #[cfg(feature = "cookies")] cookie_store: None, network_scheme: NetworkScheme::builder(), + protocol: None, }) } } diff --git a/src/client/websocket/mod.rs b/src/client/websocket/mod.rs index de574696..9e803a3b 100644 --- a/src/client/websocket/mod.rs +++ b/src/client/websocket/mod.rs @@ -16,7 +16,7 @@ use crate::{ use crate::{Error, Response}; use async_tungstenite::tungstenite::{self, protocol}; use futures_util::{Sink, SinkExt, Stream, StreamExt}; -use http::{header, uri::Scheme, HeaderMap, HeaderName, HeaderValue, StatusCode, Version}; +use http::{header, uri::Scheme, HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Version}; pub use message::{CloseCode, Message}; use tokio_util::compat::TokioAsyncReadCompatExt; use tungstenite::protocol::WebSocketConfig; @@ -24,6 +24,14 @@ use tungstenite::protocol::WebSocketConfig; pub type WebSocketStream = async_tungstenite::WebSocketStream>; +/// A marker to identify what version a connection is. +#[derive(Debug, Default)] +pub enum Ver { + #[default] + Http1, + Http2, +} + /// Wrapper for [`RequestBuilder`] that performs the /// websocket handshake when sent. #[derive(Debug)] @@ -32,6 +40,7 @@ pub struct WebSocketRequestBuilder { nonce: Option>, protocols: Option>>, config: WebSocketConfig, + ver: Ver, } impl WebSocketRequestBuilder { @@ -41,6 +50,7 @@ impl WebSocketRequestBuilder { nonce: None, protocols: None, config: WebSocketConfig::default(), + ver: Ver::Http1, } } @@ -157,14 +167,17 @@ impl WebSocketRequestBuilder { self } + /// Sets the HTTP version to HTTP/2 for the WebSocket connection. + pub fn http2_only(mut self) -> Self { + self.ver = Ver::Http2; + self + } + /// Sends the request and returns and [`WebSocketResponse`]. pub async fn send(self) -> Result { let (client, request) = self.inner.build_split(); let mut request = request?; - // Ensure the request is HTTP 1.1 - *request.version_mut() = Some(Version::HTTP_11); - // Ensure the scheme is http or https let url = request.url_mut(); match url.scheme() { @@ -183,21 +196,36 @@ impl WebSocketRequestBuilder { } } - // Generate a nonce if one wasn't provided - let nonce = self - .nonce - .unwrap_or_else(|| Cow::Owned(tungstenite::handshake::client::generate_key())); - - // HTTP 1 requires us to set some headers. let headers = request.headers_mut(); - headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade")); - headers.insert(header::UPGRADE, HeaderValue::from_static("websocket")); - headers.insert(header::SEC_WEBSOCKET_KEY, HeaderValue::from_str(&nonce)?); headers.insert( header::SEC_WEBSOCKET_VERSION, HeaderValue::from_static("13"), ); + // Ensure the request is HTTP 1.1/HTTP 2 + let nonce = match self.ver { + Ver::Http1 => { + // Generate a nonce if one wasn't provided + let nonce = self + .nonce + .unwrap_or_else(|| Cow::Owned(tungstenite::handshake::client::generate_key())); + + headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade")); + headers.insert(header::UPGRADE, HeaderValue::from_static("websocket")); + headers.insert(header::SEC_WEBSOCKET_KEY, HeaderValue::from_str(&nonce)?); + + *request.method_mut() = Method::GET; + *request.version_mut() = Some(Version::HTTP_11); + Some(nonce) + } + Ver::Http2 => { + *request.method_mut() = Method::CONNECT; + *request.version_mut() = Some(Version::HTTP_2); + *request.protocol_mut() = Some(hyper2::ext::Protocol::from_static("websocket")); + None + } + }; + // Set websocket subprotocols if let Some(ref protocols) = self.protocols { // Sets subprotocols @@ -222,6 +250,7 @@ impl WebSocketRequestBuilder { nonce, protocols: self.protocols, config: self.config, + var: self.ver, }) } } @@ -233,9 +262,10 @@ impl WebSocketRequestBuilder { #[derive(Debug)] pub struct WebSocketResponse { inner: Response, - nonce: Cow<'static, str>, + nonce: Option>, protocols: Option>>, config: WebSocketConfig, + var: Ver, } impl Deref for WebSocketResponse { @@ -260,48 +290,65 @@ impl WebSocketResponse { let status = self.inner.status(); let headers = self.inner.headers(); - if !matches!(self.inner.version(), Version::HTTP_10 | Version::HTTP_11) { + if !matches!( + self.inner.version(), + Version::HTTP_10 | Version::HTTP_11 | Version::HTTP_2 + ) { return Err(Error::new( Kind::Upgrade, Some(format!("unexpected version: {:?}", self.inner.version())), )); } - if status != StatusCode::SWITCHING_PROTOCOLS { - return Err(Error::new( - Kind::Upgrade, - Some(format!("unexpected status code: {}", status)), - )); - } + match self.var { + Ver::Http1 => { + if status != StatusCode::SWITCHING_PROTOCOLS { + let body = self.inner.text().await?; + return Err(Error::new( + Kind::Upgrade, + Some(format!("unexpected status code: {}", body)), + )); + } - if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") { - log::debug!("missing Connection header"); - return Err(Error::new(Kind::Upgrade, Some("missing connection header"))); - } + if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") { + log::debug!("missing Connection header"); + return Err(Error::new(Kind::Upgrade, Some("missing connection header"))); + } - if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") { - log::debug!("server responded with invalid Upgrade header"); - return Err(Error::new(Kind::Upgrade, Some("invalid upgrade header"))); - } + if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") { + log::debug!("server responded with invalid Upgrade header"); + return Err(Error::new(Kind::Upgrade, Some("invalid upgrade header"))); + } - match headers.get(header::SEC_WEBSOCKET_ACCEPT) { - Some(header) => { - if !header.to_str().is_ok_and(|s| { - s == tungstenite::handshake::derive_accept_key(self.nonce.as_bytes()) - }) { - log::debug!( + match self.nonce.zip(headers.get(header::SEC_WEBSOCKET_ACCEPT)) { + Some((nonce, header)) => { + if !header.to_str().is_ok_and(|s| { + s == tungstenite::handshake::derive_accept_key(nonce.as_bytes()) + }) { + log::debug!( "server responded with invalid Sec-Websocket-Accept header: {header:?}" ); + return Err(Error::new( + Kind::Upgrade, + Some(format!("invalid accept key: {:?}", header)), + )); + } + } + None => { + log::debug!("missing Sec-Websocket-Accept header"); + return Err(Error::new(Kind::Upgrade, Some("missing accept key"))); + } + } + } + Ver::Http2 => { + if status != StatusCode::OK { + let body = self.inner.text().await?; return Err(Error::new( Kind::Upgrade, - Some(format!("invalid accept key: {:?}", header)), + Some(format!("unexpected status code: {}", body)), )); } } - None => { - log::debug!("missing Sec-Websocket-Accept header"); - return Err(Error::new(Kind::Upgrade, Some("missing accept key"))); - } } let protocol = headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); @@ -317,7 +364,7 @@ impl WebSocketResponse { (false, None) => { // server didn't reply with a protocol return Err(Error::new( - Kind::Status(self.res.status()), + Kind::Status(self.inner.status()), Some("missing protocol"), )); } @@ -339,7 +386,7 @@ impl WebSocketResponse { (true, Some(_)) => { // we didn't request any protocols but got one anyway return Err(Error::new( - Kind::Status(self.res.status()), + Kind::Status(self.inner.status()), Some("invalid protocol"), )); } diff --git a/src/util/client/mod.rs b/src/util/client/mod.rs index 3c54663c..51cf2504 100644 --- a/src/util/client/mod.rs +++ b/src/util/client/mod.rs @@ -433,7 +433,7 @@ where } else { origin_form(req.uri_mut()); } - } else if req.method() == Method::CONNECT { + } else if req.method() == Method::CONNECT && !pooled.is_http2() { authority_form(req.uri_mut()); } diff --git a/src/util/client/request.rs b/src/util/client/request.rs index 3608590c..834e518d 100644 --- a/src/util/client/request.rs +++ b/src/util/client/request.rs @@ -7,7 +7,7 @@ use http::{ Request, Uri, Version, }; use http_body::Body; -use std::marker::PhantomData; +use std::{any::Any, marker::PhantomData}; pub struct InnerRequest where @@ -101,6 +101,18 @@ where self } + /// Set the extension for the request. + #[inline] + pub fn extension(mut self, extension: Option) -> Self + where + T: Clone + Any + Send + Sync + 'static, + { + if let Some(extension) = extension { + self.builder = self.builder.extension(extension); + } + self + } + /// Set network scheme for the request. #[inline] pub fn network_scheme(mut self, network_scheme: NetworkScheme) -> Self {