1use crate::harness::{CategoryResult, TestContext, TestResult};
11use std::collections::HashSet;
12use std::fs;
13use std::time::Instant;
14use xlog_cuda::{join_kernels, provider::kernel_paths::KernelArtifactLocator, JOIN_MODULE};
15
16fn kernel_locator() -> KernelArtifactLocator {
17 KernelArtifactLocator::from_env()
18}
19
/// Returns the first operand of the first PTX line that starts with `directive`.
///
/// Each line is trimmed before matching. The directive must be followed by
/// whitespace, so `.target` does not match a hypothetical `.targets` line. A
/// matching line without any operand is skipped and the search continues.
///
/// Operands are delimited by whitespace **or commas**. This matters for
/// `.target sm_70, debug` (emitted for debug builds), which yields `sm_70`
/// rather than `sm_70,`.
///
/// Returns `None` when no matching line carries an operand.
fn extract_ptx_directive(ptx: &str, directive: &str) -> Option<String> {
    ptx.lines()
        .filter_map(|line| line.trim().strip_prefix(directive))
        // Require a separator right after the directive name (word boundary).
        .filter(|rest| rest.starts_with(char::is_whitespace))
        .find_map(|rest| {
            rest.split(|c: char| c.is_whitespace() || c == ',')
                .find(|token| !token.is_empty())
                .map(str::to_string)
        })
}
33
/// Collects kernel entry-point names declared in `ptx`, in source order.
///
/// A trimmed line counts as an entry declaration when it begins with
/// `.visible .entry ` or `.entry `. The name is the trimmed text before the
/// first `(` on that line. Lines with no `(`, or whose name is empty, are
/// ignored. Duplicate names are kept so that callers can detect them.
fn extract_entry_names(ptx: &str) -> Vec<String> {
    const ENTRY_PREFIXES: [&str; 2] = [".visible .entry ", ".entry "];

    ptx.lines()
        .map(str::trim)
        .filter_map(|line| {
            ENTRY_PREFIXES
                .iter()
                .find_map(|&prefix| line.strip_prefix(prefix))
        })
        .filter_map(|decl| decl.split_once('('))
        .map(|(name, _)| name.trim())
        .filter(|name| !name.is_empty())
        .map(String::from)
        .collect()
}
55
/// Parses the numeric SM version from a PTX `.target` operand.
///
/// Accepts `sm_<digits>`, optionally followed by an alphabetic feature suffix:
/// `sm_70` → 70, `sm_90a` → 90, `sm_100f` → 100. Surrounding whitespace is
/// ignored.
///
/// Returns `None` in these cases:
/// - the `sm_` prefix is missing;
/// - no digits follow the prefix;
/// - the suffix contains non-alphabetic characters;
/// - the number does not fit in a `u32`.
fn parse_sm_target(target: &str) -> Option<u32> {
    let sm = target.trim().strip_prefix("sm_")?;
    // Split e.g. `90a` into the version digits `90` and the feature suffix `a`.
    let digits_end = sm
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(sm.len());
    let (digits, suffix) = sm.split_at(digits_end);
    if !suffix.chars().all(|c| c.is_ascii_alphabetic()) {
        return None;
    }
    digits.parse::<u32>().ok()
}
61
62pub fn run_all(ctx: &TestContext) -> CategoryResult {
64 let mut results = CategoryResult::new("c01_toolchain");
65 let start = Instant::now();
66
67 results.add_result(test_ptx_loads_successfully(ctx));
68 results.add_result(test_compute_capability_check(ctx));
69 results.add_result(test_kernel_function_resolution(ctx));
70 results.add_result(test_ptx_module_attributes(ctx));
71 results.add_result(test_repeated_jit_compilation(ctx));
72
73 results.set_duration(start.elapsed());
74 results
75}
76
77fn test_ptx_loads_successfully(ctx: &TestContext) -> TestResult {
83 let start = Instant::now();
84
85 if let Err(e) = ctx.sync_and_check() {
89 return TestResult::error(
90 "test_ptx_loads_successfully",
91 start.elapsed(),
92 format!("Device sync after PTX load failed: {}", e),
93 );
94 }
95
96 let device_ordinal = ctx.device.ordinal();
98 if device_ordinal > 100 {
99 return TestResult::error(
101 "test_ptx_loads_successfully",
102 start.elapsed(),
103 format!("Invalid device ordinal: {}", device_ordinal),
104 );
105 }
106
107 TestResult::passed("test_ptx_loads_successfully", start.elapsed())
108}
109
110fn test_compute_capability_check(ctx: &TestContext) -> TestResult {
115 let start = Instant::now();
116
117 let (major, minor) = match ctx.compute_capability() {
118 Ok(v) => v,
119 Err(e) => {
120 return TestResult::error(
121 "test_compute_capability_check",
122 start.elapsed(),
123 format!("Failed to query compute capability: {}", e),
124 );
125 }
126 };
127
128 let meets_minimum = major >= 7;
131
132 if !meets_minimum {
133 return TestResult::error(
134 "test_compute_capability_check",
135 start.elapsed(),
136 format!(
137 "Compute capability {}.{} does not meet minimum sm_70 requirement",
138 major, minor
139 ),
140 );
141 }
142
143 if let Err(e) = ctx.sync_and_check() {
144 return TestResult::error(
145 "test_compute_capability_check",
146 start.elapsed(),
147 format!("Sync failed: {}", e),
148 );
149 }
150
151 TestResult::passed("test_compute_capability_check", start.elapsed())
152}
153
154fn test_kernel_function_resolution(ctx: &TestContext) -> TestResult {
160 let start = Instant::now();
161
162 let device = ctx.device.inner();
163 let locator = kernel_locator();
164
165 let mut total_functions = 0;
166 let mut resolved_functions = 0;
167
168 for spec in xlog_cuda::kernel_manifest_data::KERNEL_MODULES {
169 let (path, is_cubin) = match locator.resolve_module_path(spec.cu_name, 999) {
170 Some(v) => v,
171 None => {
172 return TestResult::error(
173 "test_kernel_function_resolution",
174 start.elapsed(),
175 format!(
176 "{}: no portable PTX found in XLOG_CUBIN_DIR, package kernels/, or OUT_DIR",
177 spec.cu_name
178 ),
179 );
180 }
181 };
182 if is_cubin {
183 return TestResult::error(
184 "test_kernel_function_resolution",
185 start.elapsed(),
186 format!(
187 "{}: expected portable PTX fallback when resolving module artifacts",
188 spec.cu_name
189 ),
190 );
191 }
192
193 let filename = path
194 .file_name()
195 .and_then(|s| s.to_str())
196 .unwrap_or("<unknown>");
197 let module_name = spec.module_name;
198
199 let ptx = match fs::read_to_string(&path) {
200 Ok(s) => s,
201 Err(e) => {
202 return TestResult::error(
203 "test_kernel_function_resolution",
204 start.elapsed(),
205 format!("Failed to read {}: {}", path.display(), e),
206 );
207 }
208 };
209
210 let address_size = extract_ptx_directive(&ptx, ".address_size").unwrap_or_default();
211 if address_size != "64" {
212 return TestResult::error(
213 "test_kernel_function_resolution",
214 start.elapsed(),
215 format!(
216 "{}: expected .address_size 64, got '{}'",
217 filename, address_size
218 ),
219 );
220 }
221
222 let target = extract_ptx_directive(&ptx, ".target").unwrap_or_default();
223 let sm = parse_sm_target(&target).unwrap_or(0);
224 if sm < 70 {
225 return TestResult::error(
226 "test_kernel_function_resolution",
227 start.elapsed(),
228 format!(
229 "{}: expected .target sm_70 or later, got '{}'",
230 filename, target
231 ),
232 );
233 }
234
235 let entries = extract_entry_names(&ptx);
236 if entries.is_empty() {
237 return TestResult::error(
238 "test_kernel_function_resolution",
239 start.elapsed(),
240 format!("{}: no .entry kernels found", filename),
241 );
242 }
243
244 let mut seen = HashSet::new();
245 for entry in &entries {
246 if !seen.insert(entry.as_str()) {
247 return TestResult::error(
248 "test_kernel_function_resolution",
249 start.elapsed(),
250 format!("{}: duplicate .entry name {}", filename, entry),
251 );
252 }
253 }
254
255 for entry in &entries {
256 total_functions += 1;
257 if device.get_func(&module_name, entry).is_some() {
258 resolved_functions += 1;
259 } else {
260 return TestResult::error(
261 "test_kernel_function_resolution",
262 start.elapsed(),
263 format!(
264 "{}: failed to resolve kernel function '{}' from module '{}'",
265 filename, entry, module_name
266 ),
267 );
268 }
269 }
270 }
271
272 if let Err(e) = ctx.sync_and_check() {
273 return TestResult::error(
274 "test_kernel_function_resolution",
275 start.elapsed(),
276 format!("Sync failed after function resolution: {}", e),
277 );
278 }
279
280 if resolved_functions != total_functions {
281 return TestResult::error(
282 "test_kernel_function_resolution",
283 start.elapsed(),
284 format!(
285 "Not all functions resolved: {}/{}",
286 resolved_functions, total_functions
287 ),
288 );
289 }
290
291 TestResult::passed("test_kernel_function_resolution", start.elapsed())
292}
293
294fn test_ptx_module_attributes(ctx: &TestContext) -> TestResult {
300 let start = Instant::now();
301
302 let num_elements = 16;
304 let test_values: Vec<u64> = vec![u64::MAX; num_elements];
305
306 let mut gpu_buffer = match ctx.memory.alloc::<u64>(num_elements) {
308 Ok(buf) => buf,
309 Err(e) => {
310 return TestResult::error(
311 "test_ptx_module_attributes",
312 start.elapsed(),
313 format!("Failed to allocate GPU buffer: {}", e),
314 );
315 }
316 };
317
318 if let Err(e) = ctx
320 .device
321 .inner()
322 .htod_sync_copy_into(&test_values, &mut gpu_buffer)
323 {
324 return TestResult::error(
325 "test_ptx_module_attributes",
326 start.elapsed(),
327 format!("Failed to upload data to GPU: {}", e),
328 );
329 }
330
331 if let Err(e) = ctx.sync_and_check() {
333 return TestResult::error(
334 "test_ptx_module_attributes",
335 start.elapsed(),
336 format!("Sync failed after upload: {}", e),
337 );
338 }
339
340 let downloaded: Vec<u64> = match ctx.device.inner().dtoh_sync_copy(&gpu_buffer) {
342 Ok(data) => data,
343 Err(e) => {
344 return TestResult::error(
345 "test_ptx_module_attributes",
346 start.elapsed(),
347 format!("Failed to download data from GPU: {}", e),
348 );
349 }
350 };
351
352 for (i, &val) in downloaded.iter().enumerate() {
354 if val != u64::MAX {
355 return TestResult::error(
356 "test_ptx_module_attributes",
357 start.elapsed(),
358 format!(
359 "64-bit value mismatch at index {}: expected {}, got {}",
360 i,
361 u64::MAX,
362 val
363 ),
364 );
365 }
366 }
367
368 let patterns: Vec<u64> = vec![
370 0x0000_0000_0000_0001,
371 0x0000_0000_FFFF_FFFF,
372 0xFFFF_FFFF_0000_0000,
373 0x8000_0000_0000_0000,
374 0xDEAD_BEEF_CAFE_BABE,
375 0x0123_4567_89AB_CDEF,
376 ];
377
378 let mut pattern_buffer = match ctx.memory.alloc::<u64>(patterns.len()) {
379 Ok(buf) => buf,
380 Err(e) => {
381 return TestResult::error(
382 "test_ptx_module_attributes",
383 start.elapsed(),
384 format!("Failed to allocate pattern buffer: {}", e),
385 );
386 }
387 };
388
389 if let Err(e) = ctx
390 .device
391 .inner()
392 .htod_sync_copy_into(&patterns, &mut pattern_buffer)
393 {
394 return TestResult::error(
395 "test_ptx_module_attributes",
396 start.elapsed(),
397 format!("Failed to upload patterns: {}", e),
398 );
399 }
400
401 let downloaded_patterns: Vec<u64> = match ctx.device.inner().dtoh_sync_copy(&pattern_buffer) {
402 Ok(data) => data,
403 Err(e) => {
404 return TestResult::error(
405 "test_ptx_module_attributes",
406 start.elapsed(),
407 format!("Failed to download patterns: {}", e),
408 );
409 }
410 };
411
412 for (i, (&expected, &actual)) in patterns.iter().zip(downloaded_patterns.iter()).enumerate() {
413 if expected != actual {
414 return TestResult::error(
415 "test_ptx_module_attributes",
416 start.elapsed(),
417 format!(
418 "64-bit pattern mismatch at index {}: expected 0x{:016X}, got 0x{:016X}",
419 i, expected, actual
420 ),
421 );
422 }
423 }
424
425 if let Err(e) = ctx.sync_and_check() {
426 return TestResult::error(
427 "test_ptx_module_attributes",
428 start.elapsed(),
429 format!("Final sync failed: {}", e),
430 );
431 }
432
433 TestResult::passed("test_ptx_module_attributes", start.elapsed())
434}
435
436fn test_repeated_jit_compilation(ctx: &TestContext) -> TestResult {
442 let start = Instant::now();
443
444 const NUM_ITERATIONS: usize = 15;
445 const BUFFER_SIZE: usize = 1024;
446
447 for iteration in 0..NUM_ITERATIONS {
448 let test_data: Vec<u32> = (0..BUFFER_SIZE)
450 .map(|i| (i + iteration * 1000) as u32)
451 .collect();
452
453 let mut gpu_buffer = match ctx.memory.alloc::<u32>(BUFFER_SIZE) {
455 Ok(buf) => buf,
456 Err(e) => {
457 return TestResult::error(
458 "test_repeated_jit_compilation",
459 start.elapsed(),
460 format!("Allocation failed at iteration {}: {}", iteration, e),
461 );
462 }
463 };
464
465 if let Err(e) = ctx
467 .device
468 .inner()
469 .htod_sync_copy_into(&test_data, &mut gpu_buffer)
470 {
471 return TestResult::error(
472 "test_repeated_jit_compilation",
473 start.elapsed(),
474 format!("Upload failed at iteration {}: {}", iteration, e),
475 );
476 }
477
478 let downloaded: Vec<u32> = match ctx.device.inner().dtoh_sync_copy(&gpu_buffer) {
480 Ok(data) => data,
481 Err(e) => {
482 return TestResult::error(
483 "test_repeated_jit_compilation",
484 start.elapsed(),
485 format!("Download failed at iteration {}: {}", iteration, e),
486 );
487 }
488 };
489
490 for (i, (&expected, &actual)) in test_data.iter().zip(downloaded.iter()).enumerate() {
492 if expected != actual {
493 return TestResult::error(
494 "test_repeated_jit_compilation",
495 start.elapsed(),
496 format!(
497 "Data mismatch at iteration {}, index {}: expected {}, got {}",
498 iteration, i, expected, actual
499 ),
500 );
501 }
502 }
503
504 if let Err(e) = ctx.sync_and_check() {
506 return TestResult::error(
507 "test_repeated_jit_compilation",
508 start.elapsed(),
509 format!("Sync failed at iteration {}: {}", iteration, e),
510 );
511 }
512
513 }
515
516 let device = ctx.device.inner();
518 if device
519 .get_func(JOIN_MODULE, join_kernels::HASH_JOIN_BUILD)
520 .is_none()
521 {
522 return TestResult::error(
523 "test_repeated_jit_compilation",
524 start.elapsed(),
525 "Failed to resolve kernel after repeated iterations".to_string(),
526 );
527 }
528
529 if let Err(e) = ctx.sync_and_check() {
530 return TestResult::error(
531 "test_repeated_jit_compilation",
532 start.elapsed(),
533 format!("Final sync failed: {}", e),
534 );
535 }
536
537 TestResult::passed("test_repeated_jit_compilation", start.elapsed())
538}