You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

195 lines
5.1KB

  1. /*
  2. * Copyright (c) 2020
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. #include <stdio.h>
  21. #include <string.h>
  22. #include <math.h>
  23. #include "libavfilter/dnn/dnn_backend_native_layer_mathbinary.h"
  24. #include "libavutil/avassert.h"
  25. #define EPSON 0.00001
  26. static float get_expected(float f1, float f2, DNNMathBinaryOperation op)
  27. {
  28. switch (op)
  29. {
  30. case DMBO_SUB:
  31. return f1 - f2;
  32. case DMBO_ADD:
  33. return f1 + f2;
  34. default:
  35. av_assert0(!"not supported yet");
  36. return 0.f;
  37. }
  38. }
  39. static int test_broadcast_input0(DNNMathBinaryOperation op)
  40. {
  41. DnnLayerMathBinaryParams params;
  42. DnnOperand operands[2];
  43. int32_t input_indexes[1];
  44. float input[1*1*2*3] = {
  45. -3, 2.5, 2, -2.1, 7.8, 100
  46. };
  47. float *output;
  48. params.bin_op = op;
  49. params.input0_broadcast = 1;
  50. params.input1_broadcast = 0;
  51. params.v = 7.28;
  52. operands[0].data = input;
  53. operands[0].dims[0] = 1;
  54. operands[0].dims[1] = 1;
  55. operands[0].dims[2] = 2;
  56. operands[0].dims[3] = 3;
  57. operands[1].data = NULL;
  58. input_indexes[0] = 0;
  59. dnn_execute_layer_math_binary(operands, input_indexes, 1, &params);
  60. output = operands[1].data;
  61. for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
  62. float expected_output = get_expected(params.v, input[i], op);
  63. if (fabs(output[i] - expected_output) > EPSON) {
  64. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  65. op, i, output[i], expected_output, __FILE__, __LINE__);
  66. av_freep(&output);
  67. return 1;
  68. }
  69. }
  70. av_freep(&output);
  71. return 0;
  72. }
  73. static int test_broadcast_input1(DNNMathBinaryOperation op)
  74. {
  75. DnnLayerMathBinaryParams params;
  76. DnnOperand operands[2];
  77. int32_t input_indexes[1];
  78. float input[1*1*2*3] = {
  79. -3, 2.5, 2, -2.1, 7.8, 100
  80. };
  81. float *output;
  82. params.bin_op = op;
  83. params.input0_broadcast = 0;
  84. params.input1_broadcast = 1;
  85. params.v = 7.28;
  86. operands[0].data = input;
  87. operands[0].dims[0] = 1;
  88. operands[0].dims[1] = 1;
  89. operands[0].dims[2] = 2;
  90. operands[0].dims[3] = 3;
  91. operands[1].data = NULL;
  92. input_indexes[0] = 0;
  93. dnn_execute_layer_math_binary(operands, input_indexes, 1, &params);
  94. output = operands[1].data;
  95. for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
  96. float expected_output = get_expected(input[i], params.v, op);
  97. if (fabs(output[i] - expected_output) > EPSON) {
  98. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  99. op, i, output[i], expected_output, __FILE__, __LINE__);
  100. av_freep(&output);
  101. return 1;
  102. }
  103. }
  104. av_freep(&output);
  105. return 0;
  106. }
  107. static int test_no_broadcast(DNNMathBinaryOperation op)
  108. {
  109. DnnLayerMathBinaryParams params;
  110. DnnOperand operands[3];
  111. int32_t input_indexes[2];
  112. float input0[1*1*2*3] = {
  113. -3, 2.5, 2, -2.1, 7.8, 100
  114. };
  115. float input1[1*1*2*3] = {
  116. -1, 2, 3, -21, 8, 10.0
  117. };
  118. float *output;
  119. params.bin_op = op;
  120. params.input0_broadcast = 0;
  121. params.input1_broadcast = 0;
  122. operands[0].data = input0;
  123. operands[0].dims[0] = 1;
  124. operands[0].dims[1] = 1;
  125. operands[0].dims[2] = 2;
  126. operands[0].dims[3] = 3;
  127. operands[1].data = input1;
  128. operands[1].dims[0] = 1;
  129. operands[1].dims[1] = 1;
  130. operands[1].dims[2] = 2;
  131. operands[1].dims[3] = 3;
  132. operands[2].data = NULL;
  133. input_indexes[0] = 0;
  134. input_indexes[1] = 1;
  135. dnn_execute_layer_math_binary(operands, input_indexes, 2, &params);
  136. output = operands[2].data;
  137. for (int i = 0; i < sizeof(input0) / sizeof(float); i++) {
  138. float expected_output = get_expected(input0[i], input1[i], op);
  139. if (fabs(output[i] - expected_output) > EPSON) {
  140. printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
  141. op, i, output[i], expected_output, __FILE__, __LINE__);
  142. av_freep(&output);
  143. return 1;
  144. }
  145. }
  146. av_freep(&output);
  147. return 0;
  148. }
  149. static int test(DNNMathBinaryOperation op)
  150. {
  151. if (test_broadcast_input0(op))
  152. return 1;
  153. if (test_broadcast_input1(op))
  154. return 1;
  155. if (test_no_broadcast(op))
  156. return 1;
  157. return 0;
  158. }
  159. int main(int argc, char **argv)
  160. {
  161. if (test(DMBO_SUB))
  162. return 1;
  163. if (test(DMBO_ADD))
  164. return 1;
  165. return 0;
  166. }