Module TeachMyAgent.teachers.utils.alpha_functions
Expand source code
# Taken from https://github.com/psclklnk/spdl
# Copy of the license at TeachMyAgent/teachers/LICENSES/SPDL
from abc import ABC, abstractmethod
import torch
class AlphaFunction(ABC):
@abstractmethod
def __call__(self, iteration, average_reward, kl_divergence):
pass
class PercentageAlphaFunction(AlphaFunction):
def __init__(self, offset, percentage):
'''
Calculate an automatically adjusted alpha parameter to maintain constant proportion.
:param percentage: proportion to maintain
:param offset: How many times alpha should be set to 0
'''
self.offset = offset
self.percentage = percentage
def __call__(self, iteration, average_reward, kl_divergence):
if iteration < self.offset:
alpha = 0.
else:
kl_divergence = torch.clamp(kl_divergence, min=1e-10)
average_reward = 0. if average_reward < 0. else average_reward
alpha = torch.clamp(self.percentage * average_reward / kl_divergence, max=1e5)
return alpha
Classes
class AlphaFunction
-
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code
class AlphaFunction(ABC): @abstractmethod def __call__(self, iteration, average_reward, kl_divergence): pass
Ancestors
- abc.ABC
Subclasses
class PercentageAlphaFunction (offset, percentage)
-
Helper class that provides a standard way to create an ABC using inheritance.
Calculate an automatically adjusted alpha parameter to maintain constant proportion.
:param percentage: proportion to maintain :param offset: How many times alpha should be set to 0
Expand source code
class PercentageAlphaFunction(AlphaFunction): def __init__(self, offset, percentage): ''' Calculate an automatically adjusted alpha parameter to maintain constant proportion. :param percentage: proportion to maintain :param offset: How many times alpha should be set to 0 ''' self.offset = offset self.percentage = percentage def __call__(self, iteration, average_reward, kl_divergence): if iteration < self.offset: alpha = 0. else: kl_divergence = torch.clamp(kl_divergence, min=1e-10) average_reward = 0. if average_reward < 0. else average_reward alpha = torch.clamp(self.percentage * average_reward / kl_divergence, max=1e5) return alpha
Ancestors
- AlphaFunction
- abc.ABC