Skip to content
Snippets Groups Projects
Commit ad9c3d89 authored by Petr Šádek's avatar Petr Šádek
Browse files

cuda torch interop progressing

parent fc98b1a2
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@
#include <CudaImpl/CudaUtils.hpp>
#include <CudaImpl/CudaTest.hpp>
#include <MorphImpl/CudaTorchTest.hpp>
#include <TorchImpl/Utils.hpp>
int main(void)
{
......@@ -22,7 +23,8 @@ int main(void)
&CudaDeviceSynchronize
},
&imageToFloatArray,
&floatArrayToImage
&floatArrayToImage,
&InvertTensor
);
return 0;
......
......@@ -42,6 +42,7 @@ private:
Morph::MorphImpl::CudaFuncs m_cuda;
ImageToFloatArrayFun m_imageToArray;
FloatArrayToImageFun m_arrayToImage;
InvertTensorFun m_invertTensorFun;
unique<WindowManager> m_windowManager;
// Events
MethodAttacher<WindowManagerError, CudaTorchTestApplication> m_windowManagerErrorAttacher;
......@@ -72,8 +73,15 @@ private:
void* m_cudaTextureRes = nullptr;
void* m_cudaArray = nullptr;
public:
CudaTorchTestApplication(Morph::MorphImpl::CudaFuncs cuda, ImageToFloatArrayFun imageToArray, FloatArrayToImageFun arrayToImage)
: m_cuda(cuda), m_imageToArray(imageToArray), m_arrayToImage(arrayToImage)
CudaTorchTestApplication(
Morph::MorphImpl::CudaFuncs cuda,
ImageToFloatArrayFun imageToArray,
FloatArrayToImageFun arrayToImage,
InvertTensorFun invertTensorFun)
: m_cuda(cuda),
m_imageToArray(imageToArray),
m_arrayToImage(arrayToImage),
m_invertTensorFun(invertTensorFun)
{
m_windowManager = unique<WindowManager>(WindowManager::Get());
......@@ -206,6 +214,9 @@ public:
// use cuda
{
m_imageToArray(m_cudaTextureRes, m_cudaArray, sceneTexture.dim().x, sceneTexture.dim().y);
m_cuda.DeviceSynchronize();
m_invertTensorFun(m_cudaArray, sceneTexture.dim().x, sceneTexture.dim().y, 3);
m_cuda.DeviceSynchronize();
m_arrayToImage(m_cudaTextureRes, m_cudaArray, sceneTexture.dim().x, sceneTexture.dim().y);
}
......@@ -266,9 +277,13 @@ private:
namespace MorphImpl {
void CudaTorchTest(CudaFuncs cudaFuncs, ImageToFloatArrayFun imageToArray, FloatArrayToImageFun arrayToImage) {
void CudaTorchTest(
CudaFuncs cudaFuncs,
ImageToFloatArrayFun imageToArray,
FloatArrayToImageFun arrayToImage,
InvertTensorFun invertTensorFun) {
Morph::CudaTorchTestApplication app(cudaFuncs, imageToArray, arrayToImage);
Morph::CudaTorchTestApplication app(cudaFuncs, imageToArray, arrayToImage, invertTensorFun);
app.Run();
}
......
......@@ -7,10 +7,16 @@
typedef void (*ImageToFloatArrayFun)(void*&, void*&, unsigned int, unsigned int);
typedef void (*FloatArrayToImageFun)(void*&, void*&, unsigned int, unsigned int);
typedef void (*InvertTensorFun)(void*, int, int, int);
namespace Morph { namespace MorphImpl {
void CudaTorchTest(CudaFuncs cudaFuncs, ImageToFloatArrayFun imageToArray, FloatArrayToImageFun arrayToImage);
void CudaTorchTest(
CudaFuncs cudaFuncs,
ImageToFloatArrayFun imageToArray,
FloatArrayToImageFun arrayToImage,
InvertTensorFun invertTensorFun
);
}}
......
......@@ -76,3 +76,10 @@ void EvalImage(std::string model, std::string in_location, std::string out_locat
WriteTensorTo_PNG(texture, out_location);
std::cout << "texture saved" << std::endl;
}
void InvertTensor(void* data, int x, int y, int z)
{
torch::Tensor t_data = torch::from_blob((float*)data, {x, y, z});
t_data.packed_accessor64<float,2>();
t_data = 1 - t_data;
}
\ No newline at end of file
......@@ -5,6 +5,6 @@
void EvalImage(std::string model, std::string in_location, std::string out_location);
void InvertTensor(void* data, int x, int y, int z);
#endif // STYLE_TRANSFER_TORCH_IMPL_HPP
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment