xlog_cuda/provider/
kernel_paths.rs1use std::env;
2use std::path::{Path, PathBuf};
3
/// Finds compiled CUDA kernel artifacts (`.cubin` / `.portable.ptx` files) on disk.
///
/// Up to three optional search directories are held. Resolution consults them in
/// declaration order — `cubin_dir`, then `package_kernels_dir`, then `out_dir` — and
/// skips any that are `None`.
#[derive(Clone, Debug, Default)]
pub struct KernelArtifactLocator {
    /// Highest-precedence directory; `from_env` fills it from `XLOG_CUBIN_DIR`.
    cubin_dir: Option<PathBuf>,
    /// Second directory; `from_env` uses `<directory of current executable>/kernels`.
    package_kernels_dir: Option<PathBuf>,
    /// Lowest-precedence directory; `from_env` uses the compile-time `OUT_DIR`.
    out_dir: Option<PathBuf>,
}
11
12impl KernelArtifactLocator {
13 pub fn new(
14 cubin_dir: Option<PathBuf>,
15 package_kernels_dir: Option<PathBuf>,
16 out_dir: Option<PathBuf>,
17 ) -> Self {
18 Self {
19 cubin_dir,
20 package_kernels_dir,
21 out_dir,
22 }
23 }
24
25 pub fn from_env() -> Self {
32 let cubin_dir = env::var_os("XLOG_CUBIN_DIR").map(PathBuf::from);
33 let package_kernels_dir = env::current_exe()
34 .ok()
35 .and_then(|exe| exe.parent().map(|dir| dir.join("kernels")));
36 let out_dir = option_env!("OUT_DIR").map(PathBuf::from);
37 Self::new(cubin_dir, package_kernels_dir, out_dir)
38 }
39
40 pub fn resolve_module_path(&self, name: &str, cc: u32) -> Option<(PathBuf, bool)> {
42 self.resolve_module_paths(name, cc).into_iter().next()
43 }
44
45 pub fn resolve_module_paths(&self, name: &str, cc: u32) -> Vec<(PathBuf, bool)> {
50 let cubin_name = format!("{name}.sm_{cc}.cubin");
51 let ptx_name = format!("{name}.portable.ptx");
52 let mut found_paths = Vec::new();
53
54 for dir in [
55 self.cubin_dir.as_ref(),
56 self.package_kernels_dir.as_ref(),
57 self.out_dir.as_ref(),
58 ] {
59 let dir = match dir {
60 Some(dir) => dir,
61 None => continue,
62 };
63
64 if let Some(found) = Self::resolve_in_dir(dir, &cubin_name, true) {
65 found_paths.push(found);
66 }
67 if let Some(found) = Self::resolve_in_dir(dir, &ptx_name, false) {
68 found_paths.push(found);
69 }
70 }
71
72 found_paths
73 }
74
75 fn resolve_in_dir(dir: &Path, file_name: &str, is_cubin: bool) -> Option<(PathBuf, bool)> {
76 let path = dir.join(file_name);
77 if path.exists() {
78 Some((path, is_cubin))
79 } else {
80 None
81 }
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::KernelArtifactLocator;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Scratch directory under the system temp dir that is removed on drop, so a
    /// failing assertion does not leave files behind.
    struct TempRoot(PathBuf);

    impl TempRoot {
        /// Creates a fresh directory whose name is unique per label, process and time.
        fn new(label: &str) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("system clock before UNIX_EPOCH")
                .as_nanos();
            let root = std::env::temp_dir().join(format!(
                "xlog-kernel-paths-{}-{}-{}",
                label,
                std::process::id(),
                nanos
            ));
            fs::create_dir_all(&root).expect("create temp root");
            Self(root)
        }

        /// Creates (if needed) and returns the root joined with each of `parts` in turn.
        fn subdir(&self, parts: &[&str]) -> PathBuf {
            let dir = parts.iter().fold(self.0.clone(), |acc, part| acc.join(part));
            fs::create_dir_all(&dir).expect("create subdirectory");
            dir
        }
    }

    impl Drop for TempRoot {
        fn drop(&mut self) {
            // Best-effort cleanup: failing to delete scratch files must not fail a test.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Writes `contents` to `path`, panicking with context on failure.
    fn write_artifact(path: &Path, contents: &[u8]) {
        fs::write(path, contents).expect("write artifact");
    }

    /// File name the locator looks for when resolving a cubin.
    fn cubin_file(name: &str, cc: u32) -> String {
        format!("{name}.sm_{cc}.cubin")
    }

    /// File name the locator looks for when resolving portable PTX.
    fn ptx_file(name: &str) -> String {
        format!("{name}.portable.ptx")
    }

    #[test]
    fn resolves_in_precedence_order() {
        let root = TempRoot::new("precedence");
        let cubin_dir = root.subdir(&["cubin"]);
        let package_dir = root.subdir(&["bin", "kernels"]);
        let out_dir = root.subdir(&["out"]);

        let name = "xlog_join";
        let cc = 75;
        let cubin_path = cubin_dir.join(cubin_file(name, cc));
        let package_path = package_dir.join(cubin_file(name, cc));
        let out_path = out_dir.join(cubin_file(name, cc));
        write_artifact(&cubin_path, b"cubin");
        write_artifact(&package_path, b"package");
        write_artifact(&out_path, b"out");

        let locator = KernelArtifactLocator::new(
            Some(cubin_dir.clone()),
            Some(package_dir.clone()),
            Some(out_dir.clone()),
        );

        assert_eq!(
            locator.resolve_module_paths(name, cc),
            vec![
                (cubin_path.clone(), true),
                (package_path.clone(), true),
                (out_path.clone(), true),
            ]
        );
        assert_eq!(
            locator.resolve_module_path(name, cc),
            Some((cubin_path.clone(), true))
        );

        fs::remove_file(&cubin_path).expect("remove cubin file");
        assert_eq!(
            locator.resolve_module_path(name, cc),
            Some((package_path.clone(), true))
        );

        fs::remove_file(&package_path).expect("remove package file");
        assert_eq!(
            locator.resolve_module_path(name, cc),
            Some((out_path.clone(), true))
        );
    }

    #[test]
    fn earlier_directory_wins_even_with_ptx() {
        let root = TempRoot::new("dir-precedence");
        let first = root.subdir(&["first"]);
        let second = root.subdir(&["second"]);

        let name = "xlog_scan";
        let cc = 86;
        let ptx_path = first.join(ptx_file(name));
        let cubin_path = second.join(cubin_file(name, cc));
        write_artifact(&ptx_path, b"ptx");
        write_artifact(&cubin_path, b"cubin");

        let locator = KernelArtifactLocator::new(Some(first), None, Some(second));
        assert_eq!(
            locator.resolve_module_path(name, cc),
            Some((ptx_path.clone(), false))
        );
        assert_eq!(
            locator.resolve_module_paths(name, cc),
            vec![(ptx_path, false), (cubin_path, true)]
        );
    }

    #[test]
    fn cubin_preferred_over_ptx_in_same_directory() {
        let root = TempRoot::new("same-dir");
        let dir = root.subdir(&["kernels"]);

        let name = "xlog_sort";
        let cubin_path = dir.join(cubin_file(name, 75));
        let ptx_path = dir.join(ptx_file(name));
        write_artifact(&cubin_path, b"cubin");
        write_artifact(&ptx_path, b"ptx");

        let locator = KernelArtifactLocator::new(None, Some(dir), None);
        assert_eq!(
            locator.resolve_module_paths(name, 75),
            vec![(cubin_path, true), (ptx_path.clone(), false)]
        );
        // A cubin built for a different architecture does not match; PTX still does.
        assert_eq!(
            locator.resolve_module_path(name, 80),
            Some((ptx_path, false))
        );
    }

    #[test]
    fn returns_nothing_when_no_artifact_exists() {
        let root = TempRoot::new("missing");
        let dir = root.subdir(&["empty"]);

        let locator = KernelArtifactLocator::new(Some(dir), None, None);
        assert_eq!(locator.resolve_module_path("absent", 80), None);
        assert!(locator.resolve_module_paths("absent", 80).is_empty());
        assert!(KernelArtifactLocator::default()
            .resolve_module_paths("absent", 80)
            .is_empty());
    }
}