You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

300 lines
10KB

#ifndef KISSFFT_CLASS_HH
#define KISSFFT_CLASS_HH
#include <cmath>
#include <complex>
#include <cstddef>
#include <stdexcept>
#include <vector>
  4. namespace kissfft_utils {
  5. template <typename T_scalar>
  6. struct traits
  7. {
  8. typedef T_scalar scalar_type;
  9. typedef std::complex<scalar_type> cpx_type;
  10. void fill_twiddles( std::complex<T_scalar> * dst ,int nfft,bool inverse)
  11. {
  12. T_scalar phinc = (inverse?2:-2)* acos( (T_scalar) -1) / nfft;
  13. for (int i=0;i<nfft;++i)
  14. dst[i] = exp( std::complex<T_scalar>(0,i*phinc) );
  15. }
  16. void prepare(
  17. std::vector< std::complex<T_scalar> > & dst,
  18. int nfft,bool inverse,
  19. std::vector<int> & stageRadix,
  20. std::vector<int> & stageRemainder )
  21. {
  22. _twiddles.resize(nfft);
  23. fill_twiddles( &_twiddles[0],nfft,inverse);
  24. dst = _twiddles;
  25. //factorize
  26. //start factoring out 4's, then 2's, then 3,5,7,9,...
  27. int n= nfft;
  28. int p=4;
  29. do {
  30. while (n % p) {
  31. switch (p) {
  32. case 4: p = 2; break;
  33. case 2: p = 3; break;
  34. default: p += 2; break;
  35. }
  36. if (p*p>n)
  37. p=n;// no more factors
  38. }
  39. n /= p;
  40. stageRadix.push_back(p);
  41. stageRemainder.push_back(n);
  42. }while(n>1);
  43. }
  44. std::vector<cpx_type> _twiddles;
  45. const cpx_type twiddle(int i) { return _twiddles[i]; }
  46. };
  47. }
  48. template <typename T_Scalar,
  49. typename T_traits=kissfft_utils::traits<T_Scalar>
  50. >
  51. class kissfft
  52. {
  53. public:
  54. typedef T_traits traits_type;
  55. typedef typename traits_type::scalar_type scalar_type;
  56. typedef typename traits_type::cpx_type cpx_type;
  57. kissfft(int nfft,bool inverse,const traits_type & traits=traits_type() )
  58. :_nfft(nfft),_inverse(inverse),_traits(traits)
  59. {
  60. _traits.prepare(_twiddles, _nfft,_inverse ,_stageRadix, _stageRemainder);
  61. }
  62. void transform(const cpx_type * src , cpx_type * dst)
  63. {
  64. kf_work(0, dst, src, 1,1);
  65. }
  66. private:
  67. void kf_work( int stage,cpx_type * Fout, const cpx_type * f, size_t fstride,size_t in_stride)
  68. {
  69. int p = _stageRadix[stage];
  70. int m = _stageRemainder[stage];
  71. cpx_type * Fout_beg = Fout;
  72. cpx_type * Fout_end = Fout + p*m;
  73. if (m==1) {
  74. do{
  75. *Fout = *f;
  76. f += fstride*in_stride;
  77. }while(++Fout != Fout_end );
  78. }else{
  79. do{
  80. // recursive call:
  81. // DFT of size m*p performed by doing
  82. // p instances of smaller DFTs of size m,
  83. // each one takes a decimated version of the input
  84. kf_work(stage+1, Fout , f, fstride*p,in_stride);
  85. f += fstride*in_stride;
  86. }while( (Fout += m) != Fout_end );
  87. }
  88. Fout=Fout_beg;
  89. // recombine the p smaller DFTs
  90. switch (p) {
  91. case 2: kf_bfly2(Fout,fstride,m); break;
  92. case 3: kf_bfly3(Fout,fstride,m); break;
  93. case 4: kf_bfly4(Fout,fstride,m); break;
  94. case 5: kf_bfly5(Fout,fstride,m); break;
  95. default: kf_bfly_generic(Fout,fstride,m,p); break;
  96. }
  97. }
  98. // these were #define macros in the original kiss_fft
  99. void C_ADD( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a+b;}
  100. void C_MUL( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a*b;}
  101. void C_SUB( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a-b;}
  102. void C_ADDTO( cpx_type & c,const cpx_type & a) { c+=a;}
  103. void C_FIXDIV( cpx_type & ,int ) {} // NO-OP for float types
  104. scalar_type S_MUL( const scalar_type & a,const scalar_type & b) { return a*b;}
  105. scalar_type HALF_OF( const scalar_type & a) { return a*.5;}
  106. void C_MULBYSCALAR(cpx_type & c,const scalar_type & a) {c*=a;}
  107. void kf_bfly2( cpx_type * Fout, const size_t fstride, int m)
  108. {
  109. for (int k=0;k<m;++k) {
  110. cpx_type t = Fout[m+k] * _traits.twiddle(k*fstride);
  111. Fout[m+k] = Fout[k] - t;
  112. Fout[k] += t;
  113. }
  114. }
  115. void kf_bfly4( cpx_type * Fout, const size_t fstride, const size_t m)
  116. {
  117. cpx_type scratch[7];
  118. int negative_if_inverse = _inverse * -2 +1;
  119. for (size_t k=0;k<m;++k) {
  120. scratch[0] = Fout[k+m] * _traits.twiddle(k*fstride);
  121. scratch[1] = Fout[k+2*m] * _traits.twiddle(k*fstride*2);
  122. scratch[2] = Fout[k+3*m] * _traits.twiddle(k*fstride*3);
  123. scratch[5] = Fout[k] - scratch[1];
  124. Fout[k] += scratch[1];
  125. scratch[3] = scratch[0] + scratch[2];
  126. scratch[4] = scratch[0] - scratch[2];
  127. scratch[4] = cpx_type( scratch[4].imag()*negative_if_inverse , -scratch[4].real()* negative_if_inverse );
  128. Fout[k+2*m] = Fout[k] - scratch[3];
  129. Fout[k] += scratch[3];
  130. Fout[k+m] = scratch[5] + scratch[4];
  131. Fout[k+3*m] = scratch[5] - scratch[4];
  132. }
  133. }
  134. void kf_bfly3( cpx_type * Fout, const size_t fstride, const size_t m)
  135. {
  136. size_t k=m;
  137. const size_t m2 = 2*m;
  138. cpx_type *tw1,*tw2;
  139. cpx_type scratch[5];
  140. cpx_type epi3;
  141. epi3 = _twiddles[fstride*m];
  142. tw1=tw2=&_twiddles[0];
  143. do{
  144. C_FIXDIV(*Fout,3); C_FIXDIV(Fout[m],3); C_FIXDIV(Fout[m2],3);
  145. C_MUL(scratch[1],Fout[m] , *tw1);
  146. C_MUL(scratch[2],Fout[m2] , *tw2);
  147. C_ADD(scratch[3],scratch[1],scratch[2]);
  148. C_SUB(scratch[0],scratch[1],scratch[2]);
  149. tw1 += fstride;
  150. tw2 += fstride*2;
  151. Fout[m] = cpx_type( Fout->real() - HALF_OF(scratch[3].real() ) , Fout->imag() - HALF_OF(scratch[3].imag() ) );
  152. C_MULBYSCALAR( scratch[0] , epi3.imag() );
  153. C_ADDTO(*Fout,scratch[3]);
  154. Fout[m2] = cpx_type( Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );
  155. C_ADDTO( Fout[m] , cpx_type( -scratch[0].imag(),scratch[0].real() ) );
  156. ++Fout;
  157. }while(--k);
  158. }
  159. void kf_bfly5( cpx_type * Fout, const size_t fstride, const size_t m)
  160. {
  161. cpx_type *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
  162. size_t u;
  163. cpx_type scratch[13];
  164. cpx_type * twiddles = &_twiddles[0];
  165. cpx_type *tw;
  166. cpx_type ya,yb;
  167. ya = twiddles[fstride*m];
  168. yb = twiddles[fstride*2*m];
  169. Fout0=Fout;
  170. Fout1=Fout0+m;
  171. Fout2=Fout0+2*m;
  172. Fout3=Fout0+3*m;
  173. Fout4=Fout0+4*m;
  174. tw=twiddles;
  175. for ( u=0; u<m; ++u ) {
  176. C_FIXDIV( *Fout0,5); C_FIXDIV( *Fout1,5); C_FIXDIV( *Fout2,5); C_FIXDIV( *Fout3,5); C_FIXDIV( *Fout4,5);
  177. scratch[0] = *Fout0;
  178. C_MUL(scratch[1] ,*Fout1, tw[u*fstride]);
  179. C_MUL(scratch[2] ,*Fout2, tw[2*u*fstride]);
  180. C_MUL(scratch[3] ,*Fout3, tw[3*u*fstride]);
  181. C_MUL(scratch[4] ,*Fout4, tw[4*u*fstride]);
  182. C_ADD( scratch[7],scratch[1],scratch[4]);
  183. C_SUB( scratch[10],scratch[1],scratch[4]);
  184. C_ADD( scratch[8],scratch[2],scratch[3]);
  185. C_SUB( scratch[9],scratch[2],scratch[3]);
  186. C_ADDTO( *Fout0, scratch[7]);
  187. C_ADDTO( *Fout0, scratch[8]);
  188. scratch[5] = scratch[0] + cpx_type(
  189. S_MUL(scratch[7].real(),ya.real() ) + S_MUL(scratch[8].real() ,yb.real() ),
  190. S_MUL(scratch[7].imag(),ya.real()) + S_MUL(scratch[8].imag(),yb.real())
  191. );
  192. scratch[6] = cpx_type(
  193. S_MUL(scratch[10].imag(),ya.imag()) + S_MUL(scratch[9].imag(),yb.imag()),
  194. -S_MUL(scratch[10].real(),ya.imag()) - S_MUL(scratch[9].real(),yb.imag())
  195. );
  196. C_SUB(*Fout1,scratch[5],scratch[6]);
  197. C_ADD(*Fout4,scratch[5],scratch[6]);
  198. scratch[11] = scratch[0] +
  199. cpx_type(
  200. S_MUL(scratch[7].real(),yb.real()) + S_MUL(scratch[8].real(),ya.real()),
  201. S_MUL(scratch[7].imag(),yb.real()) + S_MUL(scratch[8].imag(),ya.real())
  202. );
  203. scratch[12] = cpx_type(
  204. -S_MUL(scratch[10].imag(),yb.imag()) + S_MUL(scratch[9].imag(),ya.imag()),
  205. S_MUL(scratch[10].real(),yb.imag()) - S_MUL(scratch[9].real(),ya.imag())
  206. );
  207. C_ADD(*Fout2,scratch[11],scratch[12]);
  208. C_SUB(*Fout3,scratch[11],scratch[12]);
  209. ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
  210. }
  211. }
  212. /* perform the butterfly for one stage of a mixed radix FFT */
  213. void kf_bfly_generic(
  214. cpx_type * Fout,
  215. const size_t fstride,
  216. int m,
  217. int p
  218. )
  219. {
  220. int u,k,q1,q;
  221. cpx_type * twiddles = &_twiddles[0];
  222. cpx_type t;
  223. int Norig = _nfft;
  224. cpx_type scratchbuf[p];
  225. for ( u=0; u<m; ++u ) {
  226. k=u;
  227. for ( q1=0 ; q1<p ; ++q1 ) {
  228. scratchbuf[q1] = Fout[ k ];
  229. C_FIXDIV(scratchbuf[q1],p);
  230. k += m;
  231. }
  232. k=u;
  233. for ( q1=0 ; q1<p ; ++q1 ) {
  234. int twidx=0;
  235. Fout[ k ] = scratchbuf[0];
  236. for (q=1;q<p;++q ) {
  237. twidx += fstride * k;
  238. if (twidx>=Norig) twidx-=Norig;
  239. C_MUL(t,scratchbuf[q] , twiddles[twidx] );
  240. C_ADDTO( Fout[ k ] ,t);
  241. }
  242. k += m;
  243. }
  244. }
  245. }
  246. int _nfft;
  247. bool _inverse;
  248. std::vector<cpx_type> _twiddles;
  249. std::vector<int> _stageRadix;
  250. std::vector<int> _stageRemainder;
  251. traits_type _traits;
  252. };
  253. #endif