You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

254 lines
8.2KB

  1. /*
  2. * Copyright (c) 2018 Sergey Lavrushkin
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. /**
  21. * @file
  22. * Filter implementing image super-resolution using deep convolutional networks.
  23. * https://arxiv.org/abs/1501.00092
  24. */
  25. #include "avfilter.h"
  26. #include "formats.h"
  27. #include "internal.h"
  28. #include "libavutil/opt.h"
  29. #include "libavformat/avio.h"
  30. #include "dnn_interface.h"
/**
 * Per-instance filter state for the srcnn filter.
 */
typedef struct SRCNNContext {
    const AVClass *class;
    // Value of the "model_filename" option; NULL selects the built-in default network.
    char* model_filename;
    // Flat float buffer (width * height, single channel) shared as both
    // network input and output; allocated in config_props().
    float* input_output_buf;
    // DNN backend (TensorFlow or native), obtained in init().
    DNNModule* dnn_module;
    // Loaded network; owned by this context, freed via dnn_module->free_model.
    DNNModel* model;
    // Descriptor that binds input_output_buf (data/width/height/channels) to the model.
    DNNData input_output;
} SRCNNContext;
/* Shorthand for option offsets into SRCNNContext. */
#define OFFSET(x) offsetof(SRCNNContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
/* The only option: an optional path to a model file; when unset, init()
 * falls back to the built-in default x2 upsampling network. */
static const AVOption srcnn_options[] = {
    { "model_filename", "path to model file specifying network architecture and its parameters", OFFSET(model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
    { NULL }
};
AVFILTER_DEFINE_CLASS(srcnn);
  46. static av_cold int init(AVFilterContext* context)
  47. {
  48. SRCNNContext* srcnn_context = context->priv;
  49. srcnn_context->dnn_module = ff_get_dnn_module(DNN_TF);
  50. if (!srcnn_context->dnn_module){
  51. srcnn_context->dnn_module = ff_get_dnn_module(DNN_NATIVE);
  52. if (!srcnn_context->dnn_module){
  53. av_log(context, AV_LOG_ERROR, "could not create dnn module\n");
  54. return AVERROR(ENOMEM);
  55. }
  56. else{
  57. av_log(context, AV_LOG_INFO, "using native backend for DNN inference\n");
  58. }
  59. }
  60. else{
  61. av_log(context, AV_LOG_INFO, "using tensorflow backend for DNN inference\n");
  62. }
  63. if (!srcnn_context->model_filename){
  64. av_log(context, AV_LOG_INFO, "model file for network was not specified, using default network for x2 upsampling\n");
  65. srcnn_context->model = (srcnn_context->dnn_module->load_default_model)(DNN_SRCNN);
  66. }
  67. else{
  68. srcnn_context->model = (srcnn_context->dnn_module->load_model)(srcnn_context->model_filename);
  69. }
  70. if (!srcnn_context->model){
  71. av_log(context, AV_LOG_ERROR, "could not load dnn model\n");
  72. return AVERROR(EIO);
  73. }
  74. return 0;
  75. }
  76. static int query_formats(AVFilterContext* context)
  77. {
  78. const enum AVPixelFormat pixel_formats[] = {AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P,
  79. AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_GRAY8,
  80. AV_PIX_FMT_NONE};
  81. AVFilterFormats* formats_list;
  82. formats_list = ff_make_format_list(pixel_formats);
  83. if (!formats_list){
  84. av_log(context, AV_LOG_ERROR, "could not create formats list\n");
  85. return AVERROR(ENOMEM);
  86. }
  87. return ff_set_common_formats(context, formats_list);
  88. }
  89. static int config_props(AVFilterLink* inlink)
  90. {
  91. AVFilterContext* context = inlink->dst;
  92. SRCNNContext* srcnn_context = context->priv;
  93. DNNReturnType result;
  94. srcnn_context->input_output_buf = av_malloc(inlink->h * inlink->w * sizeof(float));
  95. if (!srcnn_context->input_output_buf){
  96. av_log(context, AV_LOG_ERROR, "could not allocate memory for input/output buffer\n");
  97. return AVERROR(ENOMEM);
  98. }
  99. srcnn_context->input_output.data = srcnn_context->input_output_buf;
  100. srcnn_context->input_output.width = inlink->w;
  101. srcnn_context->input_output.height = inlink->h;
  102. srcnn_context->input_output.channels = 1;
  103. result = (srcnn_context->model->set_input_output)(srcnn_context->model->model, &srcnn_context->input_output, &srcnn_context->input_output);
  104. if (result != DNN_SUCCESS){
  105. av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
  106. return AVERROR(EIO);
  107. }
  108. else{
  109. return 0;
  110. }
  111. }
/* Per-frame state shared with the slice workers: the luma plane of the
 * output frame plus its geometry. */
typedef struct ThreadData{
    uint8_t* out;                       // base of the frame's first (luma) plane
    int out_linesize, height, width;    // plane stride in bytes, frame dimensions
} ThreadData;
  116. static int uint8_to_float(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  117. {
  118. SRCNNContext* srcnn_context = context->priv;
  119. const ThreadData* td = arg;
  120. const int slice_start = (td->height * jobnr ) / nb_jobs;
  121. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  122. const uint8_t* src = td->out + slice_start * td->out_linesize;
  123. float* dst = srcnn_context->input_output_buf + slice_start * td->width;
  124. int y, x;
  125. for (y = slice_start; y < slice_end; ++y){
  126. for (x = 0; x < td->width; ++x){
  127. dst[x] = (float)src[x] / 255.0f;
  128. }
  129. src += td->out_linesize;
  130. dst += td->width;
  131. }
  132. return 0;
  133. }
  134. static int float_to_uint8(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  135. {
  136. SRCNNContext* srcnn_context = context->priv;
  137. const ThreadData* td = arg;
  138. const int slice_start = (td->height * jobnr ) / nb_jobs;
  139. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  140. const float* src = srcnn_context->input_output_buf + slice_start * td->width;
  141. uint8_t* dst = td->out + slice_start * td->out_linesize;
  142. int y, x;
  143. for (y = slice_start; y < slice_end; ++y){
  144. for (x = 0; x < td->width; ++x){
  145. dst[x] = (uint8_t)(255.0f * FFMIN(src[x], 1.0f));
  146. }
  147. src += td->width;
  148. dst += td->out_linesize;
  149. }
  150. return 0;
  151. }
  152. static int filter_frame(AVFilterLink* inlink, AVFrame* in)
  153. {
  154. AVFilterContext* context = inlink->dst;
  155. SRCNNContext* srcnn_context = context->priv;
  156. AVFilterLink* outlink = context->outputs[0];
  157. AVFrame* out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
  158. ThreadData td;
  159. int nb_threads;
  160. DNNReturnType dnn_result;
  161. if (!out){
  162. av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
  163. av_frame_free(&in);
  164. return AVERROR(ENOMEM);
  165. }
  166. av_frame_copy_props(out, in);
  167. av_frame_copy(out, in);
  168. av_frame_free(&in);
  169. td.out = out->data[0];
  170. td.out_linesize = out->linesize[0];
  171. td.height = out->height;
  172. td.width = out->width;
  173. nb_threads = ff_filter_get_nb_threads(context);
  174. context->internal->execute(context, uint8_to_float, &td, NULL, FFMIN(td.height, nb_threads));
  175. dnn_result = (srcnn_context->dnn_module->execute_model)(srcnn_context->model);
  176. if (dnn_result != DNN_SUCCESS){
  177. av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
  178. return AVERROR(EIO);
  179. }
  180. context->internal->execute(context, float_to_uint8, &td, NULL, FFMIN(td.height, nb_threads));
  181. return ff_filter_frame(outlink, out);
  182. }
  183. static av_cold void uninit(AVFilterContext* context)
  184. {
  185. SRCNNContext* srcnn_context = context->priv;
  186. if (srcnn_context->dnn_module){
  187. (srcnn_context->dnn_module->free_model)(&srcnn_context->model);
  188. av_freep(&srcnn_context->dnn_module);
  189. }
  190. av_freep(&srcnn_context->input_output_buf);
  191. }
/* Single video input: per-link buffer setup in config_props, frame
 * processing in filter_frame. */
static const AVFilterPad srcnn_inputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
        .config_props = config_props,
        .filter_frame = filter_frame,
    },
    { NULL }
};
/* Single video output; no output-side callbacks are needed. */
static const AVFilterPad srcnn_outputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
    },
    { NULL }
};
  208. AVFilter ff_vf_srcnn = {
  209. .name = "srcnn",
  210. .description = NULL_IF_CONFIG_SMALL("Apply super resolution convolutional neural network to the input. Use bicubic upsamping with corresponding scaling factor before."),
  211. .priv_size = sizeof(SRCNNContext),
  212. .init = init,
  213. .uninit = uninit,
  214. .query_formats = query_formats,
  215. .inputs = srcnn_inputs,
  216. .outputs = srcnn_outputs,
  217. .priv_class = &srcnn_class,
  218. .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
  219. };