pub struct GpuCircuitCache { /* private fields */ }
Implementations§
Source§ impl GpuCircuitCache
impl GpuCircuitCache
pub fn provider(&self) -> &Arc<CudaKernelProvider>
pub fn var_log_weights_mut( &mut self, ) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>)
pub fn grad_true(&self) -> &TrackedCudaSlice<f64>
pub fn grad_false(&self) -> &TrackedCudaSlice<f64>
pub fn values(&self) -> &TrackedCudaSlice<f64>
pub fn meta_num_nodes_device(&self) -> &TrackedCudaSlice<u32>
pub fn meta_num_levels_device(&self) -> &TrackedCudaSlice<u32>
pub fn meta_root_device(&self) -> &TrackedCudaSlice<u32>
pub fn meta_max_var_device(&self) -> &TrackedCudaSlice<u32>
pub fn num_slots(&self) -> u32
pub fn new( provider: &Arc<CudaKernelProvider>, config: GpuCircuitCacheConfig, ) -> Result<Self>
pub fn lookup_or_insert(&mut self, key: u64) -> Result<GpuCacheLookup>
pub fn claim_slot(&mut self, key: u64) -> Result<GpuCircuitCacheHandle>
pub fn store_from_xgcf( &mut self, handle: &mut GpuCircuitCacheHandle, xgcf: &GpuXgcf, ) -> Result<()>
pub fn store_weights( &mut self, handle: &GpuCircuitCacheHandle, weights_true: &TrackedCudaSlice<f64>, weights_false: &TrackedCudaSlice<f64>, ) -> Result<()>
pub fn overwrite_weights( &mut self, handle: &GpuCircuitCacheHandle, weights_true: &TrackedCudaSlice<f64>, weights_false: &TrackedCudaSlice<f64>, ) -> Result<()>
pub fn store_free_var_mask( &mut self, handle: &GpuCircuitCacheHandle, mask: &TrackedCudaSlice<u8>, ) -> Result<()>
pub fn eval_log_wmc_device_inplace( &mut self, handle: &GpuCircuitCacheHandle, out_log_z: &mut TrackedCudaSlice<f64>, ) -> Result<()>
pub fn eval_log_wmc_device_only( &mut self, handle: &GpuCircuitCacheHandle, out_log_z: &mut TrackedCudaSlice<f64>, ) -> Result<()>
pub fn eval_grads_inplace( &mut self, handle: &GpuCircuitCacheHandle, ) -> Result<()>
Source pub fn eval_grads_inplace_fused(
&mut self,
handle: &GpuCircuitCacheHandle,
) -> Result<()>
pub fn eval_grads_inplace_fused( &mut self, handle: &GpuCircuitCacheHandle, ) -> Result<()>
Like [`eval_grads_inplace`], but replaces the per-level backward loop
with a single launch of `xgcf_backward_all_levels_cached` and omits the
trailing `device().synchronize()`, so that the caller can batch multiple
queries before syncing.
Auto Trait Implementations§
impl Freeze for GpuCircuitCache
impl RefUnwindSafe for GpuCircuitCache
impl Send for GpuCircuitCache
impl Sync for GpuCircuitCache
impl Unpin for GpuCircuitCache
impl UnsafeUnpin for GpuCircuitCache
impl UnwindSafe for GpuCircuitCache
Blanket Implementations§
Source§ impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§ fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more