/*
 * Copyright (c) 2018 Sergey Lavrushkin
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
/**
 * @file
 * Filter implementing image super-resolution using deep convolutional networks.
 * https://arxiv.org/abs/1501.00092
 */
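
/*
 * The network operates on the luma plane and does not resize the frame
 * itself, so the input should be upscaled first, e.g. with bicubic scaling.
 * A typical invocation (illustrative; filenames and the x2 factor are
 * examples) might be:
 *
 *     ffmpeg -i input.mp4 -vf "scale=2*iw:2*ih:flags=bicubic,srcnn" output.mp4
 */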
#include "avfilter.h"
#include "formats.h"
#include "internal.h"
#include "libavutil/opt.h"
#include "libavformat/avio.h"
#include "dnn_interface.h"
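
/* Filter state: the loaded DNN backend module and model, plus a single
 * float buffer that serves as both network input and output. */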
typedef struct SRCNNContext {
    const AVClass *class;

    char *model_filename;
    float *input_output_buf;
    DNNBackendType backend_type;
    DNNModule *dnn_module;
    DNNModel *model;
    DNNData input_output;
} SRCNNContext;
#define OFFSET(x) offsetof(SRCNNContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
static const AVOption srcnn_options[] = {
    { "dnn_backend", "DNN backend used for model execution", OFFSET(backend_type), AV_OPT_TYPE_FLAGS, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
    { "native", "native backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 0 }, 0, 0, FLAGS, "backend" },
#if (CONFIG_LIBTENSORFLOW == 1)
    { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 1 }, 0, 0, FLAGS, "backend" },
#endif
    { "model_filename", "path to model file specifying network architecture and its parameters", OFFSET(model_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
    { NULL }
};

AVFILTER_DEFINE_CLASS(srcnn);
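
/* Instantiate the requested DNN backend and load the model: either the
 * user-supplied file or, when none is given, the built-in network for
 * x2 upsampling. */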
static av_cold int init(AVFilterContext *context)
{
    SRCNNContext *srcnn_context = context->priv;

    srcnn_context->dnn_module = ff_get_dnn_module(srcnn_context->backend_type);
    if (!srcnn_context->dnn_module) {
        av_log(context, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
        return AVERROR(ENOMEM);
    }
    if (!srcnn_context->model_filename) {
        av_log(context, AV_LOG_VERBOSE, "model file for network was not specified, using default network for x2 upsampling\n");
        srcnn_context->model = (srcnn_context->dnn_module->load_default_model)(DNN_SRCNN);
    } else {
        srcnn_context->model = (srcnn_context->dnn_module->load_model)(srcnn_context->model_filename);
    }
    if (!srcnn_context->model) {
        av_log(context, AV_LOG_ERROR, "could not load DNN model\n");
        return AVERROR(EIO);
    }

    return 0;
}
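
/* The network processes a single plane, so only 8-bit formats whose first
 * plane is luma (planar YUV and grayscale) are accepted. */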
static int query_formats(AVFilterContext *context)
{
    const enum AVPixelFormat pixel_formats[] = {AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P,
                                                AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_GRAY8,
                                                AV_PIX_FMT_NONE};
    AVFilterFormats *formats_list;

    formats_list = ff_make_format_list(pixel_formats);
    if (!formats_list) {
        av_log(context, AV_LOG_ERROR, "could not create formats list\n");
        return AVERROR(ENOMEM);
    }

    return ff_set_common_formats(context, formats_list);
}
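
/* Allocate one float per luma sample and register the buffer as both the
 * input and the output of the network, so inference runs in place. */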
static int config_props(AVFilterLink *inlink)
{
    AVFilterContext *context = inlink->dst;
    SRCNNContext *srcnn_context = context->priv;
    DNNReturnType result;

    srcnn_context->input_output_buf = av_malloc(inlink->h * inlink->w * sizeof(float));
    if (!srcnn_context->input_output_buf) {
        av_log(context, AV_LOG_ERROR, "could not allocate memory for input/output buffer\n");
        return AVERROR(ENOMEM);
    }

    srcnn_context->input_output.data = srcnn_context->input_output_buf;
    srcnn_context->input_output.width = inlink->w;
    srcnn_context->input_output.height = inlink->h;
    srcnn_context->input_output.channels = 1;

    result = (srcnn_context->model->set_input_output)(srcnn_context->model->model, &srcnn_context->input_output, &srcnn_context->input_output);
    if (result != DNN_SUCCESS) {
        av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
        return AVERROR(EIO);
    }

    return 0;
}
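
/* Per-slice jobs for converting between the frame's 8-bit luma plane and
 * the network's float buffer; each job handles a horizontal band. */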
typedef struct ThreadData {
    uint8_t *out;
    int out_linesize, height, width;
} ThreadData;
static int uint8_to_float(AVFilterContext *context, void *arg, int jobnr, int nb_jobs)
{
    SRCNNContext *srcnn_context = context->priv;
    const ThreadData *td = arg;
    const int slice_start = (td->height *  jobnr     ) / nb_jobs;
    const int slice_end   = (td->height * (jobnr + 1)) / nb_jobs;
    const uint8_t *src = td->out + slice_start * td->out_linesize;
    float *dst = srcnn_context->input_output_buf + slice_start * td->width;
    int y, x;

    for (y = slice_start; y < slice_end; ++y) {
        for (x = 0; x < td->width; ++x) {
            dst[x] = (float)src[x] / 255.0f;
        }
        src += td->out_linesize;
        dst += td->width;
    }

    return 0;
}
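
/* Inverse conversion: scale the network output back to 8 bits, clamping to
 * [0, 1] since the network can overshoot the valid range. */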
static int float_to_uint8(AVFilterContext *context, void *arg, int jobnr, int nb_jobs)
{
    SRCNNContext *srcnn_context = context->priv;
    const ThreadData *td = arg;
    const int slice_start = (td->height *  jobnr     ) / nb_jobs;
    const int slice_end   = (td->height * (jobnr + 1)) / nb_jobs;
    const float *src = srcnn_context->input_output_buf + slice_start * td->width;
    uint8_t *dst = td->out + slice_start * td->out_linesize;
    int y, x;

    for (y = slice_start; y < slice_end; ++y) {
        for (x = 0; x < td->width; ++x) {
            /* Clamp to [0, 1] before the cast: negative network output
             * would otherwise be undefined when converted to uint8_t. */
            dst[x] = (uint8_t)(255.0f * av_clipf(src[x], 0.0f, 1.0f));
        }
        src += td->width;
        dst += td->out_linesize;
    }

    return 0;
}
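
/* Copy the input frame, run its luma plane through the network in place,
 * and send the result downstream. Chroma planes are carried over unchanged
 * by av_frame_copy(). */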
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *context = inlink->dst;
    SRCNNContext *srcnn_context = context->priv;
    AVFilterLink *outlink = context->outputs[0];
    AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
    ThreadData td;
    int nb_threads;
    DNNReturnType dnn_result;

    if (!out) {
        av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    av_frame_copy_props(out, in);
    av_frame_copy(out, in);
    av_frame_free(&in);

    td.out = out->data[0];
    td.out_linesize = out->linesize[0];
    td.height = out->height;
    td.width = out->width;

    nb_threads = ff_filter_get_nb_threads(context);
    context->internal->execute(context, uint8_to_float, &td, NULL, FFMIN(td.height, nb_threads));

    dnn_result = (srcnn_context->dnn_module->execute_model)(srcnn_context->model);
    if (dnn_result != DNN_SUCCESS) {
        av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
        av_frame_free(&out);
        return AVERROR(EIO);
    }

    context->internal->execute(context, float_to_uint8, &td, NULL, FFMIN(td.height, nb_threads));

    return ff_filter_frame(outlink, out);
}
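
/* Release the model and backend module (if one was created) along with the
 * input/output buffer. */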
static av_cold void uninit(AVFilterContext *context)
{
    SRCNNContext *srcnn_context = context->priv;

    if (srcnn_context->dnn_module) {
        (srcnn_context->dnn_module->free_model)(&srcnn_context->model);
        av_freep(&srcnn_context->dnn_module);
    }
    av_freep(&srcnn_context->input_output_buf);
}
static const AVFilterPad srcnn_inputs[] = {
    {
        .name         = "default",
        .type         = AVMEDIA_TYPE_VIDEO,
        .config_props = config_props,
        .filter_frame = filter_frame,
    },
    { NULL }
};

static const AVFilterPad srcnn_outputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
    },
    { NULL }
};
AVFilter ff_vf_srcnn = {
    .name          = "srcnn",
    .description   = NULL_IF_CONFIG_SMALL("Apply super resolution convolutional neural network to the input. Use bicubic upsampling with corresponding scaling factor before."),
    .priv_size     = sizeof(SRCNNContext),
    .init          = init,
    .uninit        = uninit,
    .query_formats = query_formats,
    .inputs        = srcnn_inputs,
    .outputs       = srcnn_outputs,
    .priv_class    = &srcnn_class,
    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
};