        self.activation = nn.ReLU()  # nonlinearity applied after the LSTM state and the hidden layer
        # Initialize the LSTM layer (unidirectional, batch-first)
        self.lstm = PyroModule[nn.LSTM](input_size, hidden_size, num_layers,
                                        batch_first=True, bidirectional=False)
        # batch_first=True -> input shape (batch, seq_len, input_size)
        self.linear = PyroModule[nn.Linear](hidden_size, 128)  # sized for the unidirectional LSTM
        self.fc = PyroModule[nn.Linear](128, num_classes)
        # Place Normal priors on the weights and biases of each layer
        # (priors cover layer 0 only; if num_layers > 1, deeper LSTM layers keep point-estimate weights)
        # Input-to-hidden weights and biases
        self.lstm.weight_ih_l0 = PyroSample(dist.Normal(0., prior_scale).expand([4 * hidden_size, input_size]).to_event(2))
        self.lstm.bias_ih_l0 = PyroSample(dist.Normal(0., prior_scale).expand([4 * hidden_size]).to_event(1))
        # Hidden-to-hidden weights and biases
        self.lstm.weight_hh_l0 = PyroSample(dist.Normal(0., prior_scale).expand([4 * hidden_size, hidden_size]).to_event(2))
        self.lstm.bias_hh_l0 = PyroSample(dist.Normal(0., prior_scale).expand([4 * hidden_size]).to_event(1))
        self.linear.weight = PyroSample(dist.Normal(0., prior_scale).expand([128, hidden_size]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., prior_scale).expand([128]).to_event(1))
        self.fc.weight = PyroSample(dist.Normal(0., prior_scale).expand([num_classes, 128]).to_event(2))
        self.fc.bias = PyroSample(dist.Normal(0., prior_scale).expand([num_classes]).to_event(1))
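        # Note: PyroSample swaps each parameter for a fresh draw from its prior
        # on every forward pass; .to_event(k) declares the rightmost k dims as
        # event dims, so each weight tensor is treated as a single joint sample.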
    def forward(self, x, y=None, noise_shape=0.5):
        # Initial hidden and cell states, created on the same device as the input
        # (torch.autograd.Variable is deprecated; plain tensors suffice)
        h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        output, (hn, cn) = self.lstm(x, (h_0, c_0))  # run the LSTM over the sequence
        hn = hn[-1]  # final hidden state of the last layer, shape (batch, hidden_size)
        out = self.activation(hn)
        out = self.linear(out)
        out = self.activation(out)
        mu = self.fc(out)  # predictive mean, shape (batch, num_classes)
        sigma = pyro.sample("sigma", dist.Gamma(noise_shape, 1))  # infer the response noise
        with pyro.plate("data", x.shape[0]):
            # .to_event(1) treats the num_classes outputs as one event per datapoint,
            # so the plate dimension lines up with the batch dimension of mu
            obs = pyro.sample("obs", dist.Normal(mu, sigma * sigma).to_event(1), obs=y)
        return mu
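
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the class above is
# named BayesianLSTM and is constructed as
# BayesianLSTM(input_size, hidden_size, num_layers, num_classes, prior_scale);
# those names are hypothetical -- adapt them to the actual class definition.
# ---------------------------------------------------------------------------
import torch
import pyro
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

pyro.clear_param_store()
model = BayesianLSTM(input_size=8, hidden_size=32, num_layers=1,
                     num_classes=1, prior_scale=1.0)   # hypothetical ctor args
guide = AutoDiagonalNormal(model)  # mean-field Gaussian over all PyroSample sites
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

x = torch.randn(64, 10, 8)  # (batch, seq_len, input_size)
y = torch.randn(64, 1)      # (batch, num_classes) regression targets
for step in range(1000):
    loss = svi.step(x, y)   # one gradient step on the negative ELBO

# Posterior predictive: draw weight samples from the guide, average predictions
predictive = Predictive(model, guide=guide, num_samples=100)
samples = predictive(x)           # dict of sampled sites, including "obs"
y_mean = samples["obs"].mean(0)   # predictive mean
y_std = samples["obs"].std(0)     # predictive standard deviation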