# Source code for ClearMap.ImageProcessing.MachineLearning.Torch

# -*- coding: utf-8 -*-
"""
Torch
=====

Utility functions for PyTorch in ClearMap.
"""
__author__    = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__   = 'GPLv3 - GNU General Pulic License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__   = 'http://idisco.info'
__download__  = 'http://www.github.com/ChristophKirst/ClearMap2'


import torch

def to(t, dtype=float):
  """Convert a torch object to a specified floating point data type.

  Arguments
  ---------
  t : torch object
    The object to convert to a certain data type. It must provide the
    ``double()``, ``float()`` and ``half()`` methods.
  dtype : 'float', 'double', 'float64', 'float32', 'float16', 'half', float or torch.dtype
    The data type to use for the torch object. 'float', 'double', 'float64',
    the builtin ``float`` and ``torch.float64`` select double precision,
    'float32' and ``torch.float32`` select single precision, and 'float16',
    'half' and ``torch.float16`` select half precision.

  Returns
  -------
  t : torch object
    The torch object in the requested data type.

  Raises
  ------
  ValueError
    If the requested data type is not one of the supported types.
  """
  # Map supported torch dtypes onto the string aliases handled below; any
  # other torch dtype is left as is and rejected by the final raise.
  if isinstance(dtype, torch.dtype):
    dtype = {torch.float64: 'float64',
             torch.float32: 'float32',
             torch.float16: 'float16'}.get(dtype, dtype)

  if dtype in ('float', 'double', 'float64', float):
    return t.double()
  if dtype in ('float32',):
    return t.float()
  if dtype in ('float16', 'half'):
    return t.half()

  # Wrap dtype in a tuple so that a tuple-valued dtype is reported through
  # the ValueError instead of breaking the %-formatting with a TypeError.
  raise ValueError('Data type %r not supported !' % (dtype,))
def gpu_info():
  """Return a short text summary of the torch compute device.

  Returns
  -------
  info : str
    A newline terminated description. The first line names the device
    ('cuda' if torch reports cuda as available, 'cpu' otherwise). For a cuda
    device the name of device 0 and its allocated and cached (reserved)
    memory, in units of 1024**3 bytes with one decimal, are appended.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  lines = ['Device: %r' % device]

  if device.type == 'cuda':
    gb = 1024 ** 3

    # memory_cached is deprecated in favour of memory_reserved; fall back to
    # the old name only for torch versions that do not provide the new one.
    memory_reserved = getattr(torch.cuda, 'memory_reserved', None)
    if memory_reserved is None:
      memory_reserved = torch.cuda.memory_cached

    # Use %.1f rather than %d so the sub-GB part is not truncated away.
    lines += [
      torch.cuda.get_device_name(0),
      'Memory Usage:',
      'Allocated: %.1fGB' % (torch.cuda.memory_allocated(0) / gb),
      'Cached: %.1fGB' % (memory_reserved(0) / gb),
    ]

  return '\n'.join(lines) + '\n'