Source code for rate.utils

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

# Contains several general-purpose utility functions

import os
import torch
import argparse
from typing import Union

def set_gpu(gpu: str, frac: float) -> torch.device:
    """
    Function to specify which GPU to use.

    Sets ``CUDA_DEVICE_ORDER`` to ``PCI_BUS_ID`` and ``CUDA_VISIBLE_DEVICES``
    to ``gpu``, then, if CUDA is available, caps this process's memory on the
    selected device and returns it. Falls back to the CPU otherwise.

    Args:
        gpu (str): String label for gpu (i.e. '0').
        frac (float): GPU memory fraction (i.e. 0.3 for 30% of the total memory).

    Returns:
        torch.device: PyTorch device object for the specified GPU or CPU.

    Raises:
        ValueError: If CUDA is available and ``gpu`` is not an integer string.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not torch.cuda.is_available():
        return torch.device('cpu')

    # Once CUDA_VISIBLE_DEVICES takes effect, the selected GPU is re-indexed
    # and becomes ordinal 0 inside this process, so addressing it by its
    # physical label (e.g. 'cuda:1') would point at a non-existent device.
    # If CUDA was already initialised before this call, the mask is ignored
    # and the physical ordinal is still the right one. Use the requested
    # ordinal only when it is actually addressable.
    index = int(gpu)
    if index >= torch.cuda.device_count():
        index = 0

    device = torch.device(f'cuda:{index}')
    # Set memory fraction for PyTorch
    torch.cuda.set_per_process_memory_fraction(frac, device=index)
    return device
def restricted_float(x: Union[str, float]) -> float:
    """
    Helper function for restricting input arg to range from 0 to 1.

    Args:
        x (Union[str, float]): String or float representing a number.

    Returns:
        float: The validated float value.

    Raises:
        argparse.ArgumentTypeError: If the value is not in range [0.0, 1.0]
            (this includes NaN).
        ValueError: If ``x`` cannot be converted to a float.
    """
    x = float(x)
    # Written as a positive range check so that NaN is rejected: every
    # comparison with NaN is False, so `x < 0.0 or x > 1.0` would let it
    # slip through as "valid".
    if not (0.0 <= x <= 1.0):
        raise argparse.ArgumentTypeError(f"{x!r} not in range [0.0, 1.0]")
    return x
def str2bool(v: Union[str, bool]) -> bool:
    """
    Helper function to parse boolean input args.

    Matching of string values is case-insensitive. Accepted true values are
    'yes', 'true', 't', 'y', '1'; accepted false values are 'no', 'false',
    'f', 'n', '0'.

    Args:
        v (Union[str, bool]): String or boolean representing true or false.

    Returns:
        bool: The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be parsed as a boolean.
    """
    # A real bool (e.g. an argparse default) has no .lower(); pass it through
    # unchanged instead of failing with AttributeError.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')