diff --git a/src/fft_base_kernels.h b/src/fft_base_kernels.h index ac883e65fb42600613f43dfb1f8c43c45b73d0fc..d150e238ea20cb39d06c51c6a295a712653aff0a 100644 --- a/src/fft_base_kernels.h +++ b/src/fft_base_kernels.h @@ -71,6 +71,8 @@ static string baseKernels = string( "#ifndef M_PI\n" "#define M_PI 0x1.921fb54442d18p+1\n" "#endif\n" + "#define INTMULFULL(a,b) ((a)*(b)) \n" + "#define INTMADFULL(a,b,c) ((a)*(b)+(c)) \n" "#define complexMul(a,b) ((float2)(mad(-(a).y, (b).y, (a).x * (b).x), mad((a).y, (b).x, (a).x * (b).y)))\n" "\n" "#define cos_sinLUT1(res,dir,i,cossinLUT)\\\n" diff --git a/src/fft_kernelstring.cpp b/src/fft_kernelstring.cpp index 7b493704ac15ec0401b093a49b86de91c8891344..231666deb89947d76a93210985b457ae41b62dfc 100644 --- a/src/fft_kernelstring.cpp +++ b/src/fft_kernelstring.cpp @@ -261,7 +261,7 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm-1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n"); - kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n"); + kernelString += string(" offset = INTMAD( INTMAD(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n"); if(dataFormat == clFFT_InterleavedComplexFormat) { kernelString += string(" in += offset;\n"); @@ -282,7 +282,7 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF { kernelString += string(" ii = lId;\n"); kernelString += string(" jj = 0;\n"); - kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n"); + kernelString += string(" offset = INTMAD(groupId, ") + num2str(N) + string(", ii);\n"); if(dataFormat == clFFT_InterleavedComplexFormat) { kernelString += string(" in += offset;\n"); @@ -306,9 +306,9 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); - kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n"); - kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n"); + kernelString += string(" lMemStore = sMem + INTMAD( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" offset = INTMAD( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n"); + kernelString += string(" offset = INTMAD( offset, ") + num2str(N) + string(", ii );\n"); if(dataFormat == clFFT_InterleavedComplexFormat) { kernelString += string(" in += offset;\n"); @@ -343,7 +343,7 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n"); + kernelString += string(" lMemLoad = sMem + INTMAD( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n"); for( i = 0; i < numOuterIter; i++ ) { @@ -377,7 +377,7 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF } else { - kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n"); + kernelString += string(" offset = INTMAD( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n"); if(dataFormat == clFFT_InterleavedComplexFormat) { kernelString += string(" in += offset;\n"); @@ -393,7 +393,7 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF kernelString += string(" ii = lId & ") + num2str(N-1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemStore = sMem + INTMAD( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n"); for( i = 0; i < R0; i++ ) @@ -415,13 +415,13 @@ insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXF { kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n"); - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemLoad = sMem + INTMAD( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); } else { kernelString += string(" ii = 0;\n"); kernelString += string(" jj = lId;\n"); - kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n"); + kernelString += string(" lMemLoad = sMem + INTMUL( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n"); } @@ -478,10 +478,10 @@ insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr int numInnerIter = N / mem_coalesce_width; int numOuterIter = numXFormsPerWG / ( groupSize / mem_coalesce_width ); - kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemLoad = sMem + INTMAD( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemStore = sMem + INTMAD( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); for( i = 0; i < maxRadix; i++ ) { @@ -534,11 +534,11 @@ insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr } else { - kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemLoad = sMem + INTMAD( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n"); kernelString += string(" jj = lId >> ") + num2str((int) log2(N)) + string(";\n"); - kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); + kernelString += string(" lMemStore = sMem + INTMAD( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n"); for( i = 0; i < maxRadix; i++ ) { @@ -730,7 +730,7 @@ insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numW if(Nprev == 1) kernelString += string(" i = ii >> ") + num2str(logNcurr) + string(";\n"); else - kernelString += string(" i = mad24(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); + kernelString += string(" i = INTMAD(ii >> ") + num2str(logNcurr) + string(", ") + num2str(Nprev) + string(", ii & ") + num2str(Nprev - 1) + string(");\n"); } else { @@ -745,9 +745,9 @@ insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numW } if(numXFormsPerWG > 1) - kernelString += string(" i = mad24(jj, ") + num2str(incr) + string(", i);\n"); + kernelString += string(" i = INTMAD(jj, ") + num2str(incr) + string(", i);\n"); - kernelString += string(" lMemLoad = sMem + mad24(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); + kernelString += string(" lMemLoad = sMem + INTMAD(j, ") + num2str(numWorkItemsReq + offset) + string(", i);\n"); } static void @@ -757,7 +757,7 @@ insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int n kernelString += string(" lMemStore = sMem + ii;\n"); } else { - kernelString += string(" lMemStore = sMem + mad24(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n"); + kernelString += string(" lMemStore = sMem + INTMAD(jj, ") + num2str((numWorkItemsReq + offset)*Nr + midPad) + string(", ii);\n"); } } @@ -932,8 +932,8 @@ createLocalMemfftKernelString(cl_fft_plan *plan) void getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices) { - int baseRadix = min(n, 128); - +// int baseRadix = min(n, 128); + int baseRadix = min(n, 128); int numR = 0; int N = n; while(N > baseRadix) @@ -1321,14 +1321,14 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir { localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n"); localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); - localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n"); - localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n"); + localString += string("indexIn = INTMAD(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n*BS)) + string(");\n"); + localString += string("tid = INTMUL(groupId, ") + num2str(batchSize) + string(");\n"); localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); int stride = radix*Rinit; for(i = 0; i < passNum; i++) stride *= radixArr[i]; - localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n"); + localString += string("indexOut = INTMAD(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int) log2(n*BS)) + string("));\n"); localString += string("bNum = groupId;\n"); } else @@ -1336,14 +1336,14 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir int lgNumBlocksPerXForm = log2(numBlocksPerXForm); localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n"); localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n"); - localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n"); + localString += string("indexIn = INTMUL(bNum, ") + num2str(batchSize) + string(");\n"); localString += string("tid = indexIn;\n"); localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n"); localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n"); int stride = radix*Rinit; for(i = 0; i < passNum; i++) stride *= radixArr[i]; - localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n"); + localString += string("indexOut = INTMAD(i, ") + num2str(stride) + string(", j);\n"); localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n"); localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n"); } @@ -1353,7 +1353,7 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir localString += string("tid = lId;\n"); localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n"); localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n"); - localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n"); + localString += string("indexIn += INTMAD(j, ") + num2str(strideI) + string(", i);\n"); if(dataFormat == clFFT_SplitComplexFormat) { @@ -1391,7 +1391,7 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir // shuffle numIter = R1 / R2; - localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n"); + localString += string("indexIn = INTMAD(j, ") + num2str(threadsPerBlock*numIter) + string(", i);\n"); localString += string("lMemStore = sMem + tid;\n"); localString += string("lMemLoad = sMem + indexIn;\n"); for(k = 0; k < R1; k++) @@ -1439,8 +1439,8 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir if(strideO == 1) { - localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n"); - localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n"); + localString += string("lMemStore = sMem + INTMAD(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1/R2)) + string(");\n"); + localString += string("lMemLoad = sMem + INTMAD(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix+1) + string(", tid & ") + num2str(radix-1) + string(");\n"); for(int i = 0; i < R1/R2; i++) for(int j = 0; j < R2; j++) @@ -1498,7 +1498,7 @@ createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir } else { - localString += string("indexOut += mad24(j, ") + num2str(numIter*strideO) + string(", i);\n"); + localString += string("indexOut += INTMAD(j, ") + num2str(numIter*strideO) + string(", i);\n"); if(dataFormat == clFFT_SplitComplexFormat) { localString += string("out_real += indexOut;\n"); localString += string("out_imag += indexOut;\n"); diff --git a/src/fft_setup.cpp b/src/fft_setup.cpp index f6f2bbf9f72c974792ee7bac539726fb0d6966cb..3f0f2b17db56453797e2a20f068b17b21b973f97 100644 --- a/src/fft_setup.cpp +++ b/src/fft_setup.cpp @@ -438,7 +438,11 @@ patch_kernel_source: if(device_type == CL_DEVICE_TYPE_GPU) { gpu_found = 1; - err = clBuildProgram(plan->program, 1, &devices[i], "-cl-mad-enable -cl-single-precision-constant", NULL, NULL); + if (plan->n.x * plan->n.y * plan->n.z <= ( 1 << 24)) { + err = clBuildProgram(plan->program, 1, &devices[i], "-cl-mad-enable -cl-single-precision-constant -DINTMUL=mul24 -DINTMAD=mad24", NULL, NULL); + } else { + err = clBuildProgram(plan->program, 1, &devices[i], "-cl-mad-enable -cl-single-precision-constant -DINTMUL=INTMULFULL -DINTMAD=INTMADFULL", NULL, NULL); + } if (err != CL_SUCCESS) { char *build_log;