computeactor_gradients.m 314 B
Ishtiaq Ahmad committed
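function gradients = computeactor_gradients(network,critic, input_data,loss)
    % Gradients of the actor objective with respect to the actor's learnables.
    % Call this through dlfeval so that dlgradient can trace the computation.
    state = dlarray(input_data,'BC');     % dimension 1 = batch (B), dimension 2 = channel (C)
    pred = predict(network, state);       % actor output for the given input
    grad = predict(critic, state, pred);  % critic output for the (input, actor output) pair; not a gradient
    scalarValue = grad+sum(loss);
    % dlgradient needs a real scalar, so only the last element is differentiated.
    gradients = dlgradient(dlarray(real(scalarValue(end)),'BC'), network.Learnables);
end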