Newer
Older
function gradients = critic_gradients(network, input_data1, input_data2,target_Q_values)
% CRITIC_GRADIENTS  Gradients of a scaled cross-entropy loss w.r.t. the
% learnable parameters of NETWORK.
%
%   gradients = critic_gradients(network, input_data1, input_data2, target_Q_values)
%
%   Inputs
%     network          - object accepted by predict() and exposing a
%                        .Learnables property (presumably a dlnetwork --
%                        confirm against caller).
%     input_data1,
%     input_data2      - the two inputs forwarded unchanged to predict().
%     target_Q_values  - target values; reshaped to length(predictions)
%                        rows and averaged across columns before use.
%
%   Output
%     gradients        - result of dlgradient for network.Learnables.
%
%   NOTE(review): dlgradient is used here, so this function is presumably
%   meant to be invoked through dlfeval -- confirm at the call site.

    % Fixed multiplier applied to the cross-entropy value.
    lossScale = 0.99;

    % Forward pass through the network with both inputs.
    qPredicted = predict(network, input_data1, input_data2);

    % Lay the targets out with one row per element along the longest
    % dimension of the predictions, then collapse the columns to their mean.
    rowCount     = length(qPredicted);
    targetMatrix = reshape(target_Q_values, rowCount, []);
    targetMean   = dlarray(mean(targetMatrix, 2), 'BC');

    % NOTE(review): the averaged targets are passed as the FIRST argument
    % and the predictions as the SECOND; verify this order against the
    % crossentropy documentation and the intended objective.
    scaledLoss = lossScale * crossentropy(targetMean, qPredicted);

    % Reduce to a single value and keep only the real part.
    totalLoss = real(sum(scaledLoss));

    % NOTE(review): the loss is negated before differentiation -- confirm
    % the sign convention expected by whoever applies these gradients.
    gradients = dlgradient(-totalLoss, network.Learnables);
end