SIR
Bases: Model
Source code in blackbirds/models/sir.py
__init__(graph, n_timesteps, device='cpu')
Implements a differentiable SIR model on a graph.
Arguments:
- graph: a networkx graph
- n_timesteps: the number of timesteps to run the model for
- device: device to use (e.g. "cpu" or "cuda:0")
initialize(params)
Initializes the model, setting the appropriate number of initial infections.
Arguments:
- params: a tensor of shape (3,) containing the log10 of each of the following: the initial fraction of infected agents, beta, and gamma
observe(x)
Returns the total number of infected and recovered agents per timestep.
Arguments:
- x: a tensor of shape (3, n_agents) containing the infected, susceptible, and recovered counts.
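The aggregation described above can be sketched in a few lines of PyTorch. This is a minimal illustration and not the library's implementation: the function name `observe_sketch` is hypothetical, and the row order (infected, susceptible, recovered) is taken from the argument description.

```python
import torch

def observe_sketch(x: torch.Tensor):
    """Hypothetical helper: total infected and recovered agents for one state."""
    # Row order assumed from the docstring: infected, susceptible, recovered.
    infected, recovered = x[0], x[2]
    return infected.sum(), recovered.sum()

# Four agents: agents 0 and 3 are infected, agent 1 is susceptible, agent 2 has recovered.
x = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
total_infected, total_recovered = observe_sketch(x)
```

Summing with `torch.sum` keeps the totals differentiable with respect to `x`, which matters when the observation feeds a loss.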
sample_bernoulli_gs(probs, tau=0.1)
Samples from a Bernoulli distribution in a differentiable way using Gumbel-Softmax.
Arguments:
- probs: a tensor of shape (n,) containing the probabilities of success for each trial
- tau: the temperature of the Gumbel-Softmax distribution
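One way such a sampler could be written is with `torch.nn.functional.gumbel_softmax` over two-class logits. The sketch below is an assumption about the approach, not a copy of the library's code: the name `sample_bernoulli_gs_sketch`, the clamping constant `eps`, and the choice of `hard=True` (straight-through estimator) are all illustrative.

```python
import torch
import torch.nn.functional as F

def sample_bernoulli_gs_sketch(probs: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch: Bernoulli draws that stay differentiable w.r.t. probs."""
    # Clamp so that log() never sees exactly 0 (an addition made by this sketch).
    eps = 1e-6
    p = probs.clamp(eps, 1.0 - eps)
    # Two-class logits per trial: column 0 is "success", column 1 is "failure".
    logits = torch.stack((p, 1.0 - p), dim=-1).log()
    # hard=True returns one-hot rows in the forward pass while gradients
    # flow through the underlying soft sample (straight-through).
    samples = F.gumbel_softmax(logits, tau=tau, hard=True)
    return samples[:, 0]

torch.manual_seed(0)
probs = torch.full((5,), 0.3, requires_grad=True)
draws = sample_bernoulli_gs_sketch(probs)  # each entry is (numerically) 0 or 1
draws.sum().backward()                     # populates probs.grad
```

A lower `tau` makes the soft sample closer to one-hot but tends to increase gradient variance, so the default of 0.1 is a trade-off rather than a universal choice.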
step(params, x)
Runs the model forward for one timestep.
Arguments:
- params: a tensor of shape (3,) containing the log10 of each of the following: the initial fraction of infected agents, beta, and gamma
- x: a tensor of shape (3, n_agents) containing the infected, susceptible, and recovered counts.
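To make the update concrete, here is a simplified sketch of a single SIR timestep on a graph. It rests on several assumptions that are not stated in this page: the infection probability `1 - exp(-beta * infected_neighbors / n_neighbors)`, the recovery probability `gamma`, and an explicit `edge_index` argument are illustrative choices, and the function name `sir_step_sketch` is hypothetical. For brevity it samples with plain `torch.bernoulli` instead of Gumbel-Softmax, so this version is not differentiable.

```python
import torch

def sir_step_sketch(params, x, edge_index):
    """Hypothetical, non-differentiable sketch of one SIR update."""
    # params holds log10 values, so undo the log first.
    _, beta, gamma = 10.0 ** params
    infected, susceptible, recovered = x[0], x[1], x[2]
    src, dst = edge_index  # each undirected edge is listed in both directions
    n_agents = x.shape[1]
    # Number of infected neighbours and total neighbours per node.
    infected_neighbors = torch.zeros(n_agents).index_add_(0, dst, infected[src])
    n_neighbors = torch.zeros(n_agents).index_add_(0, dst, torch.ones(src.shape[0]))
    n_neighbors = n_neighbors.clamp(min=1.0)
    # Assumed rule: only susceptible nodes can be infected.
    p_infect = susceptible * (1.0 - torch.exp(-beta * infected_neighbors / n_neighbors))
    # Assumed rule: each infected node recovers with probability gamma.
    p_recover = (gamma * infected).clamp(0.0, 1.0)
    new_infected = torch.bernoulli(p_infect)
    new_recovered = torch.bernoulli(p_recover)
    return torch.stack((infected + new_infected - new_recovered,
                        susceptible - new_infected,
                        recovered + new_recovered))

# Line graph 0-1-2-3 with agent 0 initially infected.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 1.0, 1.0],
                  [0.0, 0.0, 0.0, 0.0]])
params = torch.tensor([0.25, 0.5, 0.1]).log10()
torch.manual_seed(0)
x_next = sir_step_sketch(params, x, edge_index)
```

Every agent stays in exactly one compartment after the update, and agents 2 and 3 cannot become infected in this step because they have no infected neighbours.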
SIRMessagePassing
Bases: MessagePassing
Class used to pass messages between agents about their infected status.
Source code in blackbirds/models/sir.py
forward(edge_index, infected, susceptible)
Computes the sum of the product between the node's susceptibility and the neighbors' infected status.
Arguments:
- edge_index: a tensor of shape (2, n_edges) containing the edge indices
- infected: a tensor of shape (n_nodes,) containing the infected status of each node
- susceptible: a tensor of shape (n_nodes,) containing the susceptible status of each node
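The quantity described above can be reproduced without a message-passing framework, using a scatter-add over the edge list. This is a sketch under stated assumptions: the function name `infected_neighbor_sum` is hypothetical, and it assumes row 0 of `edge_index` holds source nodes and row 1 holds target nodes, with each undirected edge listed in both directions.

```python
import torch

def infected_neighbor_sum(edge_index, infected, susceptible):
    """Hypothetical sketch: per-node sum of susceptible[node] * infected[neighbor]."""
    src, dst = edge_index  # assumed convention: messages flow from src to dst
    # One message per edge: the sender's infected status, gated by the
    # receiver's susceptibility.
    messages = infected[src] * susceptible[dst]
    out = torch.zeros_like(infected)
    # Sum the incoming messages at each target node.
    return out.index_add_(0, dst, messages)

# Triangle 0-1-2: agents 0 and 1 are infected, agent 2 is susceptible.
edge_index = torch.tensor([[0, 1, 0, 2, 1, 2],
                           [1, 0, 2, 0, 2, 1]])
infected = torch.tensor([1.0, 1.0, 0.0])
susceptible = torch.tensor([0.0, 0.0, 1.0])
result = infected_neighbor_sum(edge_index, infected, susceptible)
# result is tensor([0., 0., 2.]): only agent 2 is susceptible, and it has two infected neighbours.
```

Non-susceptible nodes receive 0 regardless of how many infected neighbours they have, which is what lets the result feed directly into an infection probability.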