Skip to main content

GpuCircuitCache

Struct GpuCircuitCache 

Source
pub struct GpuCircuitCache { /* private fields */ }

Implementations

Source§

impl GpuCircuitCache

Source

pub fn provider(&self) -> &Arc<CudaKernelProvider>

Source

pub fn var_log_weights_mut( &mut self, ) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>)

Source

pub fn grad_true(&self) -> &TrackedCudaSlice<f64>

Source

pub fn grad_false(&self) -> &TrackedCudaSlice<f64>

Source

pub fn values(&self) -> &TrackedCudaSlice<f64>

Source

pub fn meta_num_nodes_device(&self) -> &TrackedCudaSlice<u32>

Source

pub fn meta_num_levels_device(&self) -> &TrackedCudaSlice<u32>

Source

pub fn meta_root_device(&self) -> &TrackedCudaSlice<u32>

Source

pub fn meta_max_var_device(&self) -> &TrackedCudaSlice<u32>

Source

pub fn num_slots(&self) -> u32

Source

pub fn new( provider: &Arc<CudaKernelProvider>, config: GpuCircuitCacheConfig, ) -> Result<Self>

Source

pub fn lookup_or_insert(&mut self, key: u64) -> Result<GpuCacheLookup>

Source

pub fn claim_slot(&mut self, key: u64) -> Result<GpuCircuitCacheHandle>

Source

pub fn store_from_xgcf( &mut self, handle: &mut GpuCircuitCacheHandle, xgcf: &GpuXgcf, ) -> Result<()>

Source

pub fn store_weights( &mut self, handle: &GpuCircuitCacheHandle, weights_true: &TrackedCudaSlice<f64>, weights_false: &TrackedCudaSlice<f64>, ) -> Result<()>

Source

pub fn overwrite_weights( &mut self, handle: &GpuCircuitCacheHandle, weights_true: &TrackedCudaSlice<f64>, weights_false: &TrackedCudaSlice<f64>, ) -> Result<()>

Source

pub fn store_free_var_mask( &mut self, handle: &GpuCircuitCacheHandle, mask: &TrackedCudaSlice<u8>, ) -> Result<()>

Source

pub fn eval_log_wmc_device_inplace( &mut self, handle: &GpuCircuitCacheHandle, out_log_z: &mut TrackedCudaSlice<f64>, ) -> Result<()>

Source

pub fn eval_log_wmc_device_only( &mut self, handle: &GpuCircuitCacheHandle, out_log_z: &mut TrackedCudaSlice<f64>, ) -> Result<()>

Source

pub fn eval_grads_inplace( &mut self, handle: &GpuCircuitCacheHandle, ) -> Result<()>

Source

pub fn eval_grads_inplace_fused( &mut self, handle: &GpuCircuitCacheHandle, ) -> Result<()>

Like [`eval_grads_inplace`](Self::eval_grads_inplace), but replaces the per-level backward loop with a single launch of `xgcf_backward_all_levels_cached` and omits the trailing `device().synchronize()`, so the caller can batch multiple queries before synchronizing.

Auto Trait Implementations

Blanket Implementations

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the `TypeId` of `self`.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,