You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
3.7KB

  1. /*
  2. * Copyright (c) 2020
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. #include <stdio.h>
  21. #include <string.h>
  22. #include <math.h>
  23. #include "libavfilter/dnn/dnn_backend_native_layer_mathunary.h"
  24. #include "libavutil/avassert.h"
  25. #define EPS 0.00001
  26. static float get_expected(float f, DNNMathUnaryOperation op)
  27. {
  28. switch (op)
  29. {
  30. case DMUO_ABS:
  31. return (f >= 0) ? f : -f;
  32. case DMUO_SIN:
  33. return sin(f);
  34. case DMUO_COS:
  35. return cos(f);
  36. case DMUO_TAN:
  37. return tan(f);
  38. case DMUO_ASIN:
  39. return asin(f);
  40. case DMUO_ACOS:
  41. return acos(f);
  42. case DMUO_ATAN:
  43. return atan(f);
  44. case DMUO_SINH:
  45. return sinh(f);
  46. case DMUO_COSH:
  47. return cosh(f);
  48. case DMUO_TANH:
  49. return tanh(f);
  50. case DMUO_ASINH:
  51. return asinh(f);
  52. case DMUO_ACOSH:
  53. return acosh(f);
  54. case DMUO_ATANH:
  55. return atanh(f);
  56. case DMUO_CEIL:
  57. return ceil(f);
  58. case DMUO_FLOOR:
  59. return floor(f);
  60. case DMUO_ROUND:
  61. return round(f);
  62. default:
  63. av_assert0(!"not supported yet");
  64. return 0.f;
  65. }
  66. }
  67. static int test(DNNMathUnaryOperation op)
  68. {
  69. DnnLayerMathUnaryParams params;
  70. DnnOperand operands[2];
  71. int32_t input_indexes[1];
  72. float input[1*1*3*3] = {
  73. 0.1, 0.5, 0.75, -3, 2.5, 2, -2.1, 7.8, 100};
  74. float *output;
  75. params.un_op = op;
  76. operands[0].data = input;
  77. operands[0].dims[0] = 1;
  78. operands[0].dims[1] = 1;
  79. operands[0].dims[2] = 3;
  80. operands[0].dims[3] = 3;
  81. operands[1].data = NULL;
  82. input_indexes[0] = 0;
  83. ff_dnn_execute_layer_math_unary(operands, input_indexes, 1, &params, NULL);
  84. output = operands[1].data;
  85. for (int i = 0; i < sizeof(input) / sizeof(float); ++i) {
  86. float expected_output = get_expected(input[i], op);
  87. int output_nan = isnan(output[i]);
  88. int expected_nan = isnan(expected_output);
  89. if ((!output_nan && !expected_nan && fabs(output[i] - expected_output) > EPS) ||
  90. (output_nan && !expected_nan) || (!output_nan && expected_nan)) {
  91. printf("at index %d, output: %f, expected_output: %f\n", i, output[i], expected_output);
  92. av_freep(&output);
  93. return 1;
  94. }
  95. }
  96. av_freep(&output);
  97. return 0;
  98. }
  99. int main(int agrc, char **argv)
  100. {
  101. if (test(DMUO_ABS))
  102. return 1;
  103. if (test(DMUO_SIN))
  104. return 1;
  105. if (test(DMUO_COS))
  106. return 1;
  107. if (test(DMUO_TAN))
  108. return 1;
  109. if (test(DMUO_ASIN))
  110. return 1;
  111. if (test(DMUO_ACOS))
  112. return 1;
  113. if (test(DMUO_ATAN))
  114. return 1;
  115. if (test(DMUO_SINH))
  116. return 1;
  117. if (test(DMUO_COSH))
  118. return 1;
  119. if (test(DMUO_TANH))
  120. return 1;
  121. if (test(DMUO_ASINH))
  122. return 1;
  123. if (test(DMUO_ACOSH))
  124. return 1;
  125. if (test(DMUO_ATANH))
  126. return 1;
  127. if (test(DMUO_CEIL))
  128. return 1;
  129. if (test(DMUO_FLOOR))
  130. return 1;
  131. if (test(DMUO_ROUND))
  132. return 1;
  133. return 0;
  134. }