You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

246 lines
7.9KB

/*
 * Copyright (c) 2018 Sergey Lavrushkin
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

/**
 * @file
 * Filter implementing image super-resolution using deep convolutional networks.
 * https://arxiv.org/abs/1501.00092
 */
  25. #include "avfilter.h"
  26. #include "formats.h"
  27. #include "internal.h"
  28. #include "libavutil/opt.h"
  29. #include "libavformat/avio.h"
  30. #include "dnn_interface.h"
  31. typedef struct SRCNNContext {
  32. const AVClass *class;
  33. char* model_filename;
  34. float* input_output_buf;
  35. DNNModule* dnn_module;
  36. DNNModel* model;
  37. DNNData input_output;
  38. } SRCNNContext;
  39. #define OFFSET(x) offsetof(SRCNNContext, x)
  40. #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
  41. static const AVOption srcnn_options[] = {
  42. { "model_filename", "path to model file specifying network architecture and its parameters", OFFSET(model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
  43. { NULL }
  44. };
  45. AVFILTER_DEFINE_CLASS(srcnn);
  46. static av_cold int init(AVFilterContext* context)
  47. {
  48. SRCNNContext* srcnn_context = context->priv;
  49. srcnn_context->dnn_module = ff_get_dnn_module(DNN_NATIVE);
  50. if (!srcnn_context->dnn_module){
  51. av_log(context, AV_LOG_ERROR, "could not create dnn module\n");
  52. return AVERROR(ENOMEM);
  53. }
  54. if (!srcnn_context->model_filename){
  55. av_log(context, AV_LOG_INFO, "model file for network was not specified, using default network for x2 upsampling\n");
  56. srcnn_context->model = (srcnn_context->dnn_module->load_default_model)(DNN_SRCNN);
  57. }
  58. else{
  59. srcnn_context->model = (srcnn_context->dnn_module->load_model)(srcnn_context->model_filename);
  60. }
  61. if (!srcnn_context->model){
  62. av_log(context, AV_LOG_ERROR, "could not load dnn model\n");
  63. return AVERROR(EIO);
  64. }
  65. return 0;
  66. }
  67. static int query_formats(AVFilterContext* context)
  68. {
  69. const enum AVPixelFormat pixel_formats[] = {AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P,
  70. AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_GRAY8,
  71. AV_PIX_FMT_NONE};
  72. AVFilterFormats* formats_list;
  73. formats_list = ff_make_format_list(pixel_formats);
  74. if (!formats_list){
  75. av_log(context, AV_LOG_ERROR, "could not create formats list\n");
  76. return AVERROR(ENOMEM);
  77. }
  78. return ff_set_common_formats(context, formats_list);
  79. }
  80. static int config_props(AVFilterLink* inlink)
  81. {
  82. AVFilterContext* context = inlink->dst;
  83. SRCNNContext* srcnn_context = context->priv;
  84. DNNReturnType result;
  85. srcnn_context->input_output_buf = av_malloc(inlink->h * inlink->w * sizeof(float));
  86. if (!srcnn_context->input_output_buf){
  87. av_log(context, AV_LOG_ERROR, "could not allocate memory for input/output buffer\n");
  88. return AVERROR(ENOMEM);
  89. }
  90. srcnn_context->input_output.data = srcnn_context->input_output_buf;
  91. srcnn_context->input_output.width = inlink->w;
  92. srcnn_context->input_output.height = inlink->h;
  93. srcnn_context->input_output.channels = 1;
  94. result = (srcnn_context->model->set_input_output)(srcnn_context->model->model, &srcnn_context->input_output, &srcnn_context->input_output);
  95. if (result != DNN_SUCCESS){
  96. av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
  97. return AVERROR(EIO);
  98. }
  99. else{
  100. return 0;
  101. }
  102. }
  103. typedef struct ThreadData{
  104. uint8_t* out;
  105. int out_linesize, height, width;
  106. } ThreadData;
  107. static int uint8_to_float(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  108. {
  109. SRCNNContext* srcnn_context = context->priv;
  110. const ThreadData* td = arg;
  111. const int slice_start = (td->height * jobnr ) / nb_jobs;
  112. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  113. const uint8_t* src = td->out + slice_start * td->out_linesize;
  114. float* dst = srcnn_context->input_output_buf + slice_start * td->width;
  115. int y, x;
  116. for (y = slice_start; y < slice_end; ++y){
  117. for (x = 0; x < td->width; ++x){
  118. dst[x] = (float)src[x] / 255.0f;
  119. }
  120. src += td->out_linesize;
  121. dst += td->width;
  122. }
  123. return 0;
  124. }
  125. static int float_to_uint8(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  126. {
  127. SRCNNContext* srcnn_context = context->priv;
  128. const ThreadData* td = arg;
  129. const int slice_start = (td->height * jobnr ) / nb_jobs;
  130. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  131. const float* src = srcnn_context->input_output_buf + slice_start * td->width;
  132. uint8_t* dst = td->out + slice_start * td->out_linesize;
  133. int y, x;
  134. for (y = slice_start; y < slice_end; ++y){
  135. for (x = 0; x < td->width; ++x){
  136. dst[x] = (uint8_t)(255.0f * FFMIN(src[x], 1.0f));
  137. }
  138. src += td->width;
  139. dst += td->out_linesize;
  140. }
  141. return 0;
  142. }
  143. static int filter_frame(AVFilterLink* inlink, AVFrame* in)
  144. {
  145. AVFilterContext* context = inlink->dst;
  146. SRCNNContext* srcnn_context = context->priv;
  147. AVFilterLink* outlink = context->outputs[0];
  148. AVFrame* out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
  149. ThreadData td;
  150. int nb_threads;
  151. DNNReturnType dnn_result;
  152. if (!out){
  153. av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
  154. av_frame_free(&in);
  155. return AVERROR(ENOMEM);
  156. }
  157. av_frame_copy_props(out, in);
  158. av_frame_copy(out, in);
  159. av_frame_free(&in);
  160. td.out = out->data[0];
  161. td.out_linesize = out->linesize[0];
  162. td.height = out->height;
  163. td.width = out->width;
  164. nb_threads = ff_filter_get_nb_threads(context);
  165. context->internal->execute(context, uint8_to_float, &td, NULL, FFMIN(td.height, nb_threads));
  166. dnn_result = (srcnn_context->dnn_module->execute_model)(srcnn_context->model);
  167. if (dnn_result != DNN_SUCCESS){
  168. av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
  169. return AVERROR(EIO);
  170. }
  171. context->internal->execute(context, float_to_uint8, &td, NULL, FFMIN(td.height, nb_threads));
  172. return ff_filter_frame(outlink, out);
  173. }
  174. static av_cold void uninit(AVFilterContext* context)
  175. {
  176. SRCNNContext* srcnn_context = context->priv;
  177. if (srcnn_context->dnn_module){
  178. (srcnn_context->dnn_module->free_model)(&srcnn_context->model);
  179. av_freep(&srcnn_context->dnn_module);
  180. }
  181. av_freep(&srcnn_context->input_output_buf);
  182. }
  183. static const AVFilterPad srcnn_inputs[] = {
  184. {
  185. .name = "default",
  186. .type = AVMEDIA_TYPE_VIDEO,
  187. .config_props = config_props,
  188. .filter_frame = filter_frame,
  189. },
  190. { NULL }
  191. };
  192. static const AVFilterPad srcnn_outputs[] = {
  193. {
  194. .name = "default",
  195. .type = AVMEDIA_TYPE_VIDEO,
  196. },
  197. { NULL }
  198. };
  199. AVFilter ff_vf_srcnn = {
  200. .name = "srcnn",
  201. .description = NULL_IF_CONFIG_SMALL("Apply super resolution convolutional neural network to the input. Use bicubic upsamping with corresponding scaling factor before."),
  202. .priv_size = sizeof(SRCNNContext),
  203. .init = init,
  204. .uninit = uninit,
  205. .query_formats = query_formats,
  206. .inputs = srcnn_inputs,
  207. .outputs = srcnn_outputs,
  208. .priv_class = &srcnn_class,
  209. .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
  210. };