Skip to content

Example losses¤

SingleOutput_SimulateAndMMD ¤

Example implementation of a loss that simulates from the model and computes the MMD between the model output and observed data y. (This treats the entries in y and in the simulator output as exchangeable.)

Arguments:

  • y: torch.Tensor containing a single univariate time series.
  • model: An instance of a Model.
  • gradient_horizon: An integer or None. Sets the horizon over which gradients are retained. If None, an infinite horizon is used.
Source code in blackbirds/losses.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class SingleOutput_SimulateAndMMD:
    """
    Example loss that runs the simulator at `theta` and scores the first simulated
    output against the observed series `y` with an MMD. (Entries of `y` and of the
    simulator output are treated as exchangeable.)

    The observed data is fixed at construction time: it is handed to a
    `UnivariateMMDLoss`, which is then reused on every call.

    **Arguments:**

    - `y`: torch.Tensor containing a single univariate time series.
    - `model`: An instance of a Model.
    - `gradient_horizon`: An integer or None. Sets horizon over which gradients are retained. If None, infinite horizon used.
    """

    def __init__(
        self, y: torch.Tensor, model: Model, gradient_horizon: Union[int, None] = None
    ):
        self.model = model
        self.gradient_horizon = gradient_horizon
        self.mmd_loss = UnivariateMMDLoss(y)

    def __call__(self, theta: torch.Tensor, y: torch.Tensor):
        """
        Simulate at `theta` and return the MMD against the data given at construction.

        The `y` argument is accepted but not used here; only the series passed to
        `__init__` enters the loss.
        """
        outputs = simulate_and_observe_model(
            self.model, theta, self.gradient_horizon
        )
        # Only the first observed output is compared against the data.
        simulated = outputs[0]
        return self.mmd_loss(simulated)

SingleOutput_SimulateAndMSELoss ¤

Computes MSE between observed data y and simulated data at theta (to be passed during __call__).

Arguments:

  • model: An instance of a Model. The model that you'd like to "fit".
  • gradient_horizon: Specifies the gradient horizon to use. None implies infinite horizon.
Source code in blackbirds/losses.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class SingleOutput_SimulateAndMSELoss:
    """
    Mean-squared-error loss between observed data y and the data simulated at theta
    (both supplied when the instance is called).

    **Arguments:**

    - `model`: An instance of a Model. The model that you'd like to "fit".
    - `gradient_horizon`: Specifies the gradient horizon to use. None implies infinite horizon.
    """

    def __init__(self, model: Model, gradient_horizon: Union[int, None] = None):
        self.model = model
        self.gradient_horizon = gradient_horizon
        # torch.nn.MSELoss with its default settings.
        self.loss = torch.nn.MSELoss()

    def __call__(self, theta: torch.Tensor, y: torch.Tensor):
        """
        Simulate the model at `theta` and return the MSE between the first
        simulated output and `y`.
        """
        outputs = simulate_and_observe_model(
            self.model, theta, self.gradient_horizon
        )
        # Only the first observed output is compared against y.
        simulated = outputs[0]
        return self.loss(simulated, y)

UnivariateMMDLoss ¤

Source code in blackbirds/losses.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class UnivariateMMDLoss:
    """
    MMD between a fixed univariate series y and a simulated series x.

    The kernel is exp(-d^2 / sigma), where d is the pairwise distance and sigma is
    the median of all squared pairwise distances within y (the zero diagonal is
    included in that median). The y-only term is computed once at construction;
    each call adds the x-only and cross terms.
    """

    def __init__(self, y: torch.Tensor):
        """
        Computes MMD between data y and simulated output x (to be passed during __call__).

        **Arguments:**

        - `y`: torch.Tensor containing a single univariate time series, either 1D
          of shape (T,) or 2D of shape (T, 1).

        Raises AssertionError if `y` is not a tensor of one of those shapes.

        NOTE: there is no guard for degenerate data. With a single observation the
        normaliser ny * (ny - 1) is zero, and if the median squared distance is
        zero the kernel divides by zero.
        """
        y = self._as_series(y, "y")
        self.device = y.device
        self.y = y
        # torch.cdist expects batched inputs: (batch=1, T, feature dim=1).
        self.y_matrix = self.y.reshape(1, -1, 1)
        yy_sqrd = torch.pow(torch.cdist(self.y_matrix, self.y_matrix), 2)
        self.y_sigma = torch.median(yy_sqrd)
        ny = self.y.shape[0]
        # Subtracting the identity removes the k(y_i, y_i) = 1 diagonal, so the
        # sum runs over the ny * (ny - 1) off-diagonal pairs only.
        self.kyy = (
            torch.exp(-yy_sqrd / self.y_sigma) - torch.eye(ny, device=self.device)
        ).sum() / (ny * (ny - 1))

    def __call__(
        self,
        x: torch.Tensor,
    ):
        """
        Return kxx + kyy - 2 * kxy for simulated output `x`.

        `x` must be a torch.Tensor of shape (T,) or (T, 1); otherwise an
        AssertionError is raised. The bandwidth computed from y is reused.
        """
        x = self._as_series(x, "x")
        nx = x.shape[0]
        x_matrix = x.reshape(1, -1, 1)
        kxx = torch.exp(-torch.pow(torch.cdist(x_matrix, x_matrix), 2) / self.y_sigma)
        # Same diagonal removal as for kyy.
        kxx = (kxx - torch.eye(nx, device=self.device)).sum() / (nx * (nx - 1))
        kxy = torch.exp(
            -torch.pow(torch.cdist(x_matrix, self.y_matrix), 2) / self.y_sigma
        )
        # The cross term is a plain mean over all (x_i, y_j) pairs.
        kxy = kxy.mean()
        return kxx + self.kyy - 2 * kxy

    @staticmethod
    def _as_series(series: torch.Tensor, name: str) -> torch.Tensor:
        """Validate that `series` is a (T,) or (T, 1) tensor and return it as 1D."""
        assert isinstance(
            series, torch.Tensor
        ), f"{name} is assumed to be a torch.Tensor here"
        if len(series.shape) == 1:
            return series
        # The accepted 2D layout is a single column, (T, 1): that is what the
        # shape[1] == 1 check below enforces.
        assert (
            len(series.shape) == 2
        ), f"If not a 1D Tensor, {name} must be at most 2D of shape (T, 1)"
        assert (
            series.shape[1] == 1
        ), f"This class assumes {name} is a single univariate time series. This appears to be a batch of data."
        return series.reshape(-1)

__init__(y) ¤

Computes MMD between data y and simulated output x (to be passed during __call__).

Arguments:

  • y: torch.Tensor containing a single univariate time series.
Source code in blackbirds/losses.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(self, y: torch.Tensor):
    """
    Computes MMD between data y and simulated output x (to be passed during __call__).

    **Arguments:**

    - `y`: torch.Tensor containing a single univariate time series, either 1D
      of shape (T,) or 2D of shape (T, 1).

    Raises AssertionError if `y` is not a tensor of one of those shapes.
    """
    assert isinstance(y, torch.Tensor), "y is assumed to be a torch.Tensor here"
    if len(y.shape) != 1:
        # The accepted 2D layout is a single column, (T, 1): that is what the
        # shape[1] == 1 check below enforces.
        assert (
            len(y.shape) == 2
        ), "If not a 1D Tensor, y must be at most 2D of shape (T, 1)"
        assert (
            y.shape[1] == 1
        ), "This class assumes y is a single univariate time series. This appears to be a batch of data."
        y = y.reshape(-1)
    self.device = y.device
    self.y = y
    # torch.cdist expects batched inputs: (batch=1, T, feature dim=1).
    self.y_matrix = self.y.reshape(1, -1, 1)
    yy_sqrd = torch.pow(torch.cdist(self.y_matrix, self.y_matrix), 2)
    # Kernel bandwidth: median over all squared pairwise distances in y
    # (the zero diagonal is included).
    self.y_sigma = torch.median(yy_sqrd)
    ny = self.y.shape[0]
    # Subtracting the identity removes the k(y_i, y_i) = 1 diagonal, so the
    # sum runs over the ny * (ny - 1) off-diagonal pairs only.
    self.kyy = (
        torch.exp(-yy_sqrd / self.y_sigma) - torch.eye(ny, device=self.device)
    ).sum() / (ny * (ny - 1))