critic_gradients.m (committed by Ramsha Narmeen)
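function gradients = critic_gradients(network, input_data1, input_data2, target_Q_values)
    % Critic gradients for a two-input dlnetwork; must be called through dlfeval.
    lambda = 0.99;  % constant scale applied to the loss
    % forward (rather than predict) runs layers in training mode while gradients are traced
    predictions = forward(network, input_data1, input_data2);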
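    % One target per prediction: average the target Q-values across columns
    target_Q_values = reshape(target_Q_values, length(predictions), []);
    targets = dlarray(mean(target_Q_values, 2), 'BC');
    % Q-values are unbounded regression targets, so use (half) mean squared error, not crossentropy
    loss = lambda * mse(predictions, targets);
    % Differentiate the loss itself so a minimizing optimizer (e.g. adamupdate) descends it
    gradients = dlgradient(loss, network.Learnables);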
end
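A minimal usage sketch, assuming a hypothetical two-input critic (a 4-feature observation and a 2-feature action) and random target Q-values; the layer names, sizes, batch size, number of target columns, and learning rate are illustrative assumptions, not part of the original. Because `critic_gradients` calls `dlgradient`, it has to be evaluated through `dlfeval`; `adamupdate` then applies the returned gradient table to the network.

```matlab
% Hypothetical two-input critic: observation (4 features) + action (2 features) -> scalar Q
obsPath = [featureInputLayer(4, 'Name', 'obs'), fullyConnectedLayer(16, 'Name', 'fcObs')];
actPath = [featureInputLayer(2, 'Name', 'act'), fullyConnectedLayer(16, 'Name', 'fcAct')];
common  = [additionLayer(2, 'Name', 'add'), reluLayer('Name', 'relu'), ...
           fullyConnectedLayer(1, 'Name', 'q')];

lgraph = layerGraph(obsPath);
lgraph = addLayers(lgraph, actPath);
lgraph = addLayers(lgraph, common);
lgraph = connectLayers(lgraph, 'fcObs', 'add/in1');
lgraph = connectLayers(lgraph, 'fcAct', 'add/in2');
net = dlnetwork(lgraph);  % assumes net.InputNames is {'obs', 'act'} (order of addition)

% Random batch: inputs as formatted dlarray (channel x batch)
batchSize = 8;
X1 = dlarray(rand(4, batchSize, 'single'), 'CB');
X2 = dlarray(rand(2, batchSize, 'single'), 'CB');
targetQ = rand(batchSize, 3, 'single');  % hypothetical: 3 target estimates per sample

% A few Adam steps using the gradients from critic_gradients.m (must be on the path)
avgGrad = [];
avgSqGrad = [];
learnRate = 1e-3;
for iteration = 1:10
    gradients = dlfeval(@critic_gradients, net, X1, X2, targetQ);
    [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, avgGrad, avgSqGrad, ...
                                           iteration, learnRate);
end
```

The inputs are wrapped as formatted `dlarray` objects because `dlnetwork` evaluates only `dlarray` data, and the `'CB'` labels tell the feature input layers which dimension holds the batch.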