Skip to content

Commit

Permalink
Demonstrate automatic GPU detection
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Mar 8, 2025
1 parent b239c3e commit e4017b7
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

196 changes: 190 additions & 6 deletions crates/uv-client/src/registry_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use std::fmt::Debug;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::LazyLock;
use std::time::Duration;

use async_http_range_reader::AsyncHttpRangeReader;
Expand Down Expand Up @@ -31,15 +32,145 @@ use uv_metadata::{read_metadata_async_seek, read_metadata_async_stream};
use uv_normalize::PackageName;
use uv_pep440::Version;
use uv_pep508::MarkerEnvironment;
use uv_platform_tags::Platform;
use uv_platform_tags::{Accelerator, Platform};
use uv_pypi_types::{ResolutionMetadata, SimpleJson};
use uv_small_str::SmallString;

// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213
//
// TODO(charlie): These differ for Windows and Linux.
static DRIVERS: LazyLock<[(Version, Version, IndexUrl); 23]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
(
Version::new([12, 6]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap(),
),
(
Version::new([12, 5]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap(),
),
(
Version::new([12, 4]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap(),
),
(
Version::new([12, 3]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap(),
),
(
Version::new([12, 2]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap(),
),
(
Version::new([12, 1]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap(),
),
(
Version::new([12, 0]),
Version::new([525, 60, 13]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap(),
),
// Table 2 from
// https://docs.nvidia.com/cuda/archive/11.8.0/cuda-toolkit-release-notes/index.html
(
Version::new([11, 8]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap(),
),
(
Version::new([11, 7]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap(),
),
(
Version::new([11, 6]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap(),
),
(
Version::new([11, 5]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap(),
),
(
Version::new([11, 4]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap(),
),
(
Version::new([11, 3]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap(),
),
(
Version::new([11, 2]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap(),
),
(
Version::new([11, 1]),
Version::new([450, 80, 2]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap(),
),
(
Version::new([11, 0]),
Version::new([450, 36, 6]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap(),
),
// Table 1 from
// https://docs.nvidia.com/cuda/archive/10.2/cuda-toolkit-release-notes/index.html
(
Version::new([10, 2]),
Version::new([440, 33]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap(),
),
(
Version::new([10, 1]),
Version::new([418, 39]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap(),
),
(
Version::new([10, 0]),
Version::new([410, 48]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap(),
),
(
Version::new([9, 2]),
Version::new([396, 26]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap(),
),
(
Version::new([9, 1]),
Version::new([390, 46]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap(),
),
(
Version::new([9, 0]),
Version::new([384, 81]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap(),
),
(
Version::new([8, 0]),
Version::new([375, 26]),
IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap(),
),
]
});

/// A builder for an [`RegistryClient`].
#[derive(Debug, Clone)]
pub struct RegistryClientBuilder<'a> {
index_urls: IndexUrls,
index_strategy: IndexStrategy,
platform: Option<&'a Platform>,
cache: Cache,
base_client_builder: BaseClientBuilder<'a>,
}
Expand All @@ -49,6 +180,7 @@ impl RegistryClientBuilder<'_> {
Self {
index_urls: IndexUrls::default(),
index_strategy: IndexStrategy::default(),
platform: None,
cache,
base_client_builder: BaseClientBuilder::new(),
}
Expand Down Expand Up @@ -126,6 +258,7 @@ impl<'a> RegistryClientBuilder<'a> {

#[must_use]
pub fn platform(mut self, platform: &'a Platform) -> Self {
self.platform = Some(platform);
self.base_client_builder = self.base_client_builder.platform(platform);
self
}
Expand All @@ -145,6 +278,10 @@ impl<'a> RegistryClientBuilder<'a> {
RegistryClient {
index_urls: self.index_urls,
index_strategy: self.index_strategy,
accelerator: self
.platform
.and_then(|platform| platform.accelerator())
.cloned(),
cache: self.cache,
connectivity,
client,
Expand All @@ -166,6 +303,10 @@ impl<'a> RegistryClientBuilder<'a> {
RegistryClient {
index_urls: self.index_urls,
index_strategy: self.index_strategy,
accelerator: self
.platform
.and_then(|platform| platform.accelerator())
.cloned(),
cache: self.cache,
connectivity,
client,
Expand All @@ -181,6 +322,7 @@ impl<'a> TryFrom<BaseClientBuilder<'a>> for RegistryClientBuilder<'a> {
Ok(Self {
index_urls: IndexUrls::default(),
index_strategy: IndexStrategy::default(),
platform: None,
cache: Cache::temp()?,
base_client_builder: value,
})
Expand All @@ -194,6 +336,8 @@ pub struct RegistryClient {
index_urls: IndexUrls,
/// The strategy to use when fetching across multiple indexes.
index_strategy: IndexStrategy,
/// The accelerator for the current platform.
accelerator: Option<Accelerator>,
/// The underlying HTTP client.
client: CachedClient,
/// Used for the remote wheel METADATA cache.
Expand Down Expand Up @@ -230,6 +374,44 @@ impl RegistryClient {
self.timeout
}

pub fn index_urls_for(&self, package_name: &PackageName) -> impl Iterator<Item = &IndexUrl> {
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
// indexes.
//
// See: https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_patch.py#L36-L49
if matches!(
package_name.as_str(),
"torch"
| "torch-model-archiver"
| "torch-tb-profiler"
| "torcharrow"
| "torchaudio"
| "torchcsprng"
| "torchdata"
| "torchdistx"
| "torchserve"
| "torchtext"
| "torchvision"
| "pytorch-triton"
) {
if let Some(accelerator) = self.accelerator.as_ref() {
return match accelerator {
Accelerator::Cuda { driver_version } => {
Either::Left(DRIVERS.iter().filter_map(move |(cuda, driver, url)| {
if driver_version >= driver {
Some(url)
} else {
None
}
}))
}
};
}
}

Either::Right(self.index_urls.indexes().map(Index::url))
}

/// Fetch a package from the `PyPI` simple API.
///
/// "simple" here refers to [PEP 503 – Simple Repository API](https://peps.python.org/pep-0503/)
Expand All @@ -246,13 +428,15 @@ impl RegistryClient {
let indexes = if let Some(index) = index {
Either::Left(std::iter::once(index))
} else {
Either::Right(self.index_urls.indexes().map(Index::url))
Either::Right(self.index_urls_for(package_name))
};

let mut it = indexes.peekable();
if it.peek().is_none() {
return Err(ErrorKind::NoIndex(package_name.to_string()).into());
}
// let mut it = indexes.peekable();
// if it.peek().is_none() {
// return Err(ErrorKind::NoIndex(package_name.to_string()).into());
// }

let it = indexes;

let mut results = Vec::new();

Expand Down
1 change: 1 addition & 0 deletions crates/uv-platform-tags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ doctest = false
workspace = true

[dependencies]
uv-pep440 = { workspace = true }
uv-small-str = { workspace = true }

memchr = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion crates/uv-platform-tags/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub use abi_tag::{AbiTag, ParseAbiTagError};
pub use language_tag::{LanguageTag, ParseLanguageTagError};
pub use platform::{Arch, Os, Platform, PlatformError};
pub use platform::{Accelerator, Arch, Os, Platform, PlatformError};
pub use platform_tag::{ParsePlatformTagError, PlatformTag};
pub use tags::{BinaryFormat, IncompatibleTag, TagCompatibility, TagPriority, Tags, TagsError};

Expand Down
28 changes: 27 additions & 1 deletion crates/uv-platform-tags/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::str::FromStr;
use std::{fmt, io};

use thiserror::Error;
use uv_pep440::Version;

#[derive(Error, Debug)]
pub enum PlatformError {
Expand All @@ -17,12 +18,18 @@ pub enum PlatformError {
pub struct Platform {
os: Os,
arch: Arch,
accelerator: Option<Accelerator>,
}

impl Platform {
/// Create a new platform from the given operating system and architecture.
pub const fn new(os: Os, arch: Arch) -> Self {
Self { os, arch }
Self {
os,
arch,
// Let's track accelerator separately.
accelerator: None,
}
}

/// Return the platform's operating system.
Expand All @@ -34,6 +41,11 @@ impl Platform {
pub fn arch(&self) -> Arch {
self.arch
}

/// Return the platform's accelerator.
pub fn accelerator(&self) -> Option<&Accelerator> {
self.accelerator.as_ref()
}
}

/// All supported operating systems.
Expand Down Expand Up @@ -210,3 +222,17 @@ impl Arch {
.copied()
}
}

#[derive(Debug, Clone, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum Accelerator {
Cuda { driver_version: Version },
}

impl fmt::Display for Accelerator {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
}
}
}
Loading

0 comments on commit e4017b7

Please sign in to comment.