You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

423 lines
16KB

  1. /*
  2. * Copyright (c) 2018 Sergey Lavrushkin
  3. *
  4. * This file is part of FFmpeg.
  5. *
  6. * FFmpeg is free software; you can redistribute it and/or
  7. * modify it under the terms of the GNU Lesser General Public
  8. * License as published by the Free Software Foundation; either
  9. * version 2.1 of the License, or (at your option) any later version.
  10. *
  11. * FFmpeg is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  14. * Lesser General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU Lesser General Public
  17. * License along with FFmpeg; if not, write to the Free Software
  18. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  19. */
  20. /**
  21. * @file
  22. * Filter implementing image super-resolution using deep convolutional networks.
  23. * https://arxiv.org/abs/1501.00092
  24. */
  25. #include "avfilter.h"
  26. #include "formats.h"
  27. #include "internal.h"
  28. #include "libavutil/opt.h"
  29. #if HAVE_UNISTD_H
  30. #include <unistd.h>
  31. #endif
  32. #include "vf_srcnn.h"
  33. #include "libavformat/avio.h"
/// One convolutional layer of the network.
/// Kernel weights are laid out as
/// [output_channels][size][size][input_channels] (see convolve()),
/// with one bias per output channel.
typedef struct Convolution
{
    double* kernel;
    double* biases;
    int32_t size, input_channels, output_channels;
} Convolution;
/// Private filter state: the three network layers plus reusable
/// per-frame working buffers sized in config_props().
typedef struct SRCNNContext {
    const AVClass *class;

    /// SRCNN convolutions (conv1 -> conv2 -> conv3, each followed by ReLU)
    struct Convolution conv1, conv2, conv3;
    /// Path to binary file with kernels specifications (option "config_file";
    /// NULL selects built-in default weights)
    char* config_file_path;
    /// Buffers for network input/output and feature maps; input_output_buf
    /// is reused for both the network input and the final layer's output
    double* input_output_buf;
    double* conv1_buf;
    double* conv2_buf;
} SRCNNContext;
#define OFFSET(x) offsetof(SRCNNContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
/// Filter options: only the path to the binary weights file. When it is
/// not set, init() falls back to built-in weights for x2 upsampling.
static const AVOption srcnn_options[] = {
    { "config_file", "path to configuration file with network parameters", OFFSET(config_file_path), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
    { NULL }
};

AVFILTER_DEFINE_CLASS(srcnn);
/// Return AVERROR(EIO) from the enclosing function (closing the AVIO
/// context) when the configuration file is smaller than the bytes the
/// network description requires so far.
/// NOTE(review): expands to a bare `if` without do{}while(0), and call
/// sites omit the trailing semicolon — do not use inside an unbraced
/// if/else. Relies on a variable named `context` in scope for av_log().
#define CHECK_FILE_SIZE(file_size, srcnn_size, avio_context) if (srcnn_size > file_size){ \
    av_log(context, AV_LOG_ERROR, "error reading configuration file\n");\
    avio_closep(&avio_context); \
    return AVERROR(EIO); \
}

/// Return AVERROR(ENOMEM) from the enclosing function when `call` fails
/// (non-zero), first running `end_call` for cleanup (`end_call` may be an
/// empty argument). Same expansion caveats as CHECK_FILE_SIZE above.
#define CHECK_ALLOCATION(call, end_call) if (call){ \
    av_log(context, AV_LOG_ERROR, "could not allocate memory for convolutions\n"); \
    end_call; \
    return AVERROR(ENOMEM); \
}
  68. static int allocate_read_conv_data(Convolution* conv, AVIOContext* config_file_context)
  69. {
  70. int32_t kernel_size = conv->output_channels * conv->size * conv->size * conv->input_channels;
  71. int32_t i;
  72. conv->kernel = av_malloc(kernel_size * sizeof(double));
  73. if (!conv->kernel){
  74. return AVERROR(ENOMEM);
  75. }
  76. for (i = 0; i < kernel_size; ++i){
  77. conv->kernel[i] = av_int2double(avio_rl64(config_file_context));
  78. }
  79. conv->biases = av_malloc(conv->output_channels * sizeof(double));
  80. if (!conv->biases){
  81. return AVERROR(ENOMEM);
  82. }
  83. for (i = 0; i < conv->output_channels; ++i){
  84. conv->biases[i] = av_int2double(avio_rl64(config_file_context));
  85. }
  86. return 0;
  87. }
  88. static int allocate_copy_conv_data(Convolution* conv, const double* kernel, const double* biases)
  89. {
  90. int32_t kernel_size = conv->output_channels * conv->size * conv->size * conv->input_channels;
  91. conv->kernel = av_malloc(kernel_size * sizeof(double));
  92. if (!conv->kernel){
  93. return AVERROR(ENOMEM);
  94. }
  95. memcpy(conv->kernel, kernel, kernel_size * sizeof(double));
  96. conv->biases = av_malloc(conv->output_channels * sizeof(double));
  97. if (!conv->kernel){
  98. return AVERROR(ENOMEM);
  99. }
  100. memcpy(conv->biases, biases, conv->output_channels * sizeof(double));
  101. return 0;
  102. }
/**
 * Initialize the filter: load the three convolution layers either from a
 * user-supplied binary configuration file or, when no file is given, from
 * built-in default weights for x2 upsampling.
 *
 * File layout: per layer a small little-endian int32 header (conv1: size,
 * output_channels; conv2: input_channels, size, output_channels; conv3:
 * input_channels, size) followed by the doubles read later by
 * allocate_read_conv_data(). srcnn_size accumulates the expected byte
 * count so short or oversized files are rejected.
 *
 * @return 0 on success, AVERROR(EIO) on file errors,
 *         AVERROR(ENOMEM) on allocation failure (via CHECK_ALLOCATION)
 */
static av_cold int init(AVFilterContext* context)
{
    SRCNNContext *srcnn_context = context->priv;
    AVIOContext* config_file_context;
    int64_t file_size, srcnn_size;

    /// Check specified configuration file name and read network weights from it
    if (!srcnn_context->config_file_path){
        av_log(context, AV_LOG_INFO, "configuration file for network was not specified, using default weights for x2 upsampling\n");

        /// Create convolution kernels and copy default weights
        /// (conv*_kernel / conv*_biases presumably come from vf_srcnn.h —
        /// NOTE(review): not visible in this file, confirm)
        srcnn_context->conv1.input_channels = 1;
        srcnn_context->conv1.output_channels = 64;
        srcnn_context->conv1.size = 9;
        CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv1, conv1_kernel, conv1_biases), )
        srcnn_context->conv2.input_channels = 64;
        srcnn_context->conv2.output_channels = 32;
        srcnn_context->conv2.size = 1;
        CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv2, conv2_kernel, conv2_biases), )
        srcnn_context->conv3.input_channels = 32;
        srcnn_context->conv3.output_channels = 1;
        srcnn_context->conv3.size = 5;
        CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv3, conv3_kernel, conv3_biases), )
    }
    /// NOTE(review): access() before avio_open() is a TOCTOU-style check;
    /// avio_open() alone would report the same failures.
    else if (access(srcnn_context->config_file_path, R_OK) != -1){
        if (avio_open(&config_file_context, srcnn_context->config_file_path, AVIO_FLAG_READ) < 0){
            av_log(context, AV_LOG_ERROR, "failed to open configuration file\n");
            return AVERROR(EIO);
        }

        file_size = avio_size(config_file_context);

        /// Create convolution kernels and read weights from file.
        /// 8 header bytes + (kernel elements + biases) * 8 bytes per double;
        /// note `a + b << 3` parses as `(a + b) << 3`, which is the intended
        /// total here (precedence works out, if not obviously).
        srcnn_context->conv1.input_channels = 1;
        srcnn_context->conv1.size = (int32_t)avio_rl32(config_file_context);
        srcnn_context->conv1.output_channels = (int32_t)avio_rl32(config_file_context);
        srcnn_size = 8 + (srcnn_context->conv1.output_channels * srcnn_context->conv1.size *
                          srcnn_context->conv1.size * srcnn_context->conv1.input_channels +
                          srcnn_context->conv1.output_channels << 3);
        CHECK_FILE_SIZE(file_size, srcnn_size, config_file_context)
        CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv1, config_file_context), avio_closep(&config_file_context))

        srcnn_context->conv2.input_channels = (int32_t)avio_rl32(config_file_context);
        srcnn_context->conv2.size = (int32_t)avio_rl32(config_file_context);
        srcnn_context->conv2.output_channels = (int32_t)avio_rl32(config_file_context);
        srcnn_size += 12 + (srcnn_context->conv2.output_channels * srcnn_context->conv2.size *
                            srcnn_context->conv2.size * srcnn_context->conv2.input_channels +
                            srcnn_context->conv2.output_channels << 3);
        CHECK_FILE_SIZE(file_size, srcnn_size, config_file_context)
        CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv2, config_file_context), avio_closep(&config_file_context))

        srcnn_context->conv3.input_channels = (int32_t)avio_rl32(config_file_context);
        srcnn_context->conv3.size = (int32_t)avio_rl32(config_file_context);
        srcnn_context->conv3.output_channels = 1;
        srcnn_size += 8 + (srcnn_context->conv3.output_channels * srcnn_context->conv3.size *
                           srcnn_context->conv3.size * srcnn_context->conv3.input_channels
                           + srcnn_context->conv3.output_channels << 3);
        /// Final layer: the file must match the expected size exactly
        if (file_size != srcnn_size){
            av_log(context, AV_LOG_ERROR, "error reading configuration file\n");
            avio_closep(&config_file_context);
            return AVERROR(EIO);
        }
        CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv3, config_file_context), avio_closep(&config_file_context))
        avio_closep(&config_file_context);
    }
    else{
        av_log(context, AV_LOG_ERROR, "specified configuration file does not exist or not readable\n");
        return AVERROR(EIO);
    }
    return 0;
}
  168. static int query_formats(AVFilterContext* context)
  169. {
  170. const enum AVPixelFormat pixel_formats[] = {AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P,
  171. AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_GRAY8,
  172. AV_PIX_FMT_NONE};
  173. AVFilterFormats *formats_list;
  174. formats_list = ff_make_format_list(pixel_formats);
  175. if (!formats_list){
  176. av_log(context, AV_LOG_ERROR, "could not create formats list\n");
  177. return AVERROR(ENOMEM);
  178. }
  179. return ff_set_common_formats(context, formats_list);
  180. }
  181. static int config_props(AVFilterLink* inlink)
  182. {
  183. AVFilterContext *context = inlink->dst;
  184. SRCNNContext *srcnn_context = context->priv;
  185. int min_dim;
  186. /// Check if input data width or height is too low
  187. min_dim = FFMIN(inlink->w, inlink->h);
  188. if (min_dim <= srcnn_context->conv1.size >> 1 || min_dim <= srcnn_context->conv2.size >> 1 || min_dim <= srcnn_context->conv3.size >> 1){
  189. av_log(context, AV_LOG_ERROR, "input width or height is too low\n");
  190. return AVERROR(EIO);
  191. }
  192. /// Allocate network buffers
  193. srcnn_context->input_output_buf = av_malloc(inlink->h * inlink->w * sizeof(double));
  194. srcnn_context->conv1_buf = av_malloc(inlink->h * inlink->w * srcnn_context->conv1.output_channels * sizeof(double));
  195. srcnn_context->conv2_buf = av_malloc(inlink->h * inlink->w * srcnn_context->conv2.output_channels * sizeof(double));
  196. if (!srcnn_context->input_output_buf || !srcnn_context->conv1_buf || !srcnn_context->conv2_buf){
  197. av_log(context, AV_LOG_ERROR, "could not allocate memory for srcnn buffers\n");
  198. return AVERROR(ENOMEM);
  199. }
  200. return 0;
  201. }
/// Shared state for the uint8<->double conversion slice jobs
typedef struct ThreadData{
    uint8_t* out;                     ///< base of the frame's first plane
    int out_linesize, height, width;  ///< plane stride and dimensions
} ThreadData;

/// Shared state for one convolution layer's slice jobs
typedef struct ConvThreadData
{
    const Convolution* conv;  ///< layer weights/biases to apply
    const double* input;      ///< height * width * input_channels samples
    double* output;           ///< height * width * output_channels samples
    int height, width;
} ConvThreadData;
  213. /// Convert uint8 data to double and scale it to use in network
  214. static int uint8_to_double(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  215. {
  216. SRCNNContext* srcnn_context = context->priv;
  217. const ThreadData* td = arg;
  218. const int slice_start = (td->height * jobnr ) / nb_jobs;
  219. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  220. const uint8_t* src = td->out + slice_start * td->out_linesize;
  221. double* dst = srcnn_context->input_output_buf + slice_start * td->width;
  222. int y, x;
  223. for (y = slice_start; y < slice_end; ++y){
  224. for (x = 0; x < td->width; ++x){
  225. dst[x] = (double)src[x] / 255.0;
  226. }
  227. src += td->out_linesize;
  228. dst += td->width;
  229. }
  230. return 0;
  231. }
  232. /// Convert double data from network to uint8 and scale it to output as filter result
  233. static int double_to_uint8(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  234. {
  235. SRCNNContext* srcnn_context = context->priv;
  236. const ThreadData* td = arg;
  237. const int slice_start = (td->height * jobnr ) / nb_jobs;
  238. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  239. const double* src = srcnn_context->input_output_buf + slice_start * td->width;
  240. uint8_t* dst = td->out + slice_start * td->out_linesize;
  241. int y, x;
  242. for (y = slice_start; y < slice_end; ++y){
  243. for (x = 0; x < td->width; ++x){
  244. dst[x] = (uint8_t)(255.0 * FFMIN(src[x], 1.0));
  245. }
  246. src += td->width;
  247. dst += td->out_linesize;
  248. }
  249. return 0;
  250. }
  251. #define CLAMP_TO_EDGE(x, w) ((x) < 0 ? 0 : ((x) >= (w) ? (w - 1) : (x)))
  252. static int convolve(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
  253. {
  254. const ConvThreadData* td = arg;
  255. const int slice_start = (td->height * jobnr ) / nb_jobs;
  256. const int slice_end = (td->height * (jobnr + 1)) / nb_jobs;
  257. const double* src = td->input;
  258. double* dst = td->output + slice_start * td->width * td->conv->output_channels;
  259. int y, x;
  260. int32_t n_filter, ch, kernel_y, kernel_x;
  261. int32_t radius = td->conv->size >> 1;
  262. int src_linesize = td->width * td->conv->input_channels;
  263. int filter_linesize = td->conv->size * td->conv->input_channels;
  264. int filter_size = td->conv->size * filter_linesize;
  265. for (y = slice_start; y < slice_end; ++y){
  266. for (x = 0; x < td->width; ++x){
  267. for (n_filter = 0; n_filter < td->conv->output_channels; ++n_filter){
  268. dst[n_filter] = td->conv->biases[n_filter];
  269. for (ch = 0; ch < td->conv->input_channels; ++ch){
  270. for (kernel_y = 0; kernel_y < td->conv->size; ++kernel_y){
  271. for (kernel_x = 0; kernel_x < td->conv->size; ++kernel_x){
  272. dst[n_filter] += src[CLAMP_TO_EDGE(y + kernel_y - radius, td->height) * src_linesize +
  273. CLAMP_TO_EDGE(x + kernel_x - radius, td->width) * td->conv->input_channels + ch] *
  274. td->conv->kernel[n_filter * filter_size + kernel_y * filter_linesize +
  275. kernel_x * td->conv->input_channels + ch];
  276. }
  277. }
  278. }
  279. dst[n_filter] = FFMAX(dst[n_filter], 0.0);
  280. }
  281. dst += td->conv->output_channels;
  282. }
  283. }
  284. return 0;
  285. }
/**
 * Process one frame: copy it to a fresh output buffer, run the first plane
 * (luma for YUV, the whole image for GRAY8) through the three-layer SRCNN
 * pipeline, and write the result back into that plane. Remaining planes
 * pass through untouched via av_frame_copy().
 *
 * All stages run sliced across up to FFMIN(height, nb_threads) jobs.
 */
static int filter_frame(AVFilterLink* inlink, AVFrame* in)
{
    AVFilterContext* context = inlink->dst;
    SRCNNContext* srcnn_context = context->priv;
    AVFilterLink* outlink = context->outputs[0];
    AVFrame* out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
    ThreadData td;
    ConvThreadData ctd;

    if (!out){
        av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    /// Copy metadata and all planes, then release the input frame.
    /// NOTE(review): return values of these copies are not checked.
    av_frame_copy_props(out, in);
    av_frame_copy(out, in);
    av_frame_free(&in);

    td.out = out->data[0];
    td.out_linesize = out->linesize[0];
    td.height = ctd.height = out->height;
    td.width = ctd.width = out->width;

    /// Stage 1: uint8 plane -> doubles in input_output_buf
    context->internal->execute(context, uint8_to_double, &td, NULL, FFMIN(td.height, context->graph->nb_threads));
    /// Stage 2: conv1, input_output_buf -> conv1_buf
    ctd.conv = &srcnn_context->conv1;
    ctd.input = srcnn_context->input_output_buf;
    ctd.output = srcnn_context->conv1_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, context->graph->nb_threads));
    /// Stage 3: conv2, conv1_buf -> conv2_buf
    ctd.conv = &srcnn_context->conv2;
    ctd.input = srcnn_context->conv1_buf;
    ctd.output = srcnn_context->conv2_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, context->graph->nb_threads));
    /// Stage 4: conv3, conv2_buf -> input_output_buf (buffer reused)
    ctd.conv = &srcnn_context->conv3;
    ctd.input = srcnn_context->conv2_buf;
    ctd.output = srcnn_context->input_output_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, context->graph->nb_threads));
    /// Stage 5: doubles back to the uint8 plane
    context->internal->execute(context, double_to_uint8, &td, NULL, FFMIN(td.height, context->graph->nb_threads));
    return ff_filter_frame(outlink, out);
}
  322. static av_cold void uninit(AVFilterContext* context)
  323. {
  324. SRCNNContext* srcnn_context = context->priv;
  325. /// Free convolution data
  326. av_freep(&srcnn_context->conv1.kernel);
  327. av_freep(&srcnn_context->conv1.biases);
  328. av_freep(&srcnn_context->conv2.kernel);
  329. av_freep(&srcnn_context->conv2.biases);
  330. av_freep(&srcnn_context->conv3.kernel);
  331. av_freep(&srcnn_context->conv3.kernel);
  332. /// Free network buffers
  333. av_freep(&srcnn_context->input_output_buf);
  334. av_freep(&srcnn_context->conv1_buf);
  335. av_freep(&srcnn_context->conv2_buf);
  336. }
/// Single video input pad; dimensions checked and buffers sized in
/// config_props, frames processed in filter_frame
static const AVFilterPad srcnn_inputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
        .config_props = config_props,
        .filter_frame = filter_frame,
    },
    { NULL }
};
/// Single video output pad (same dimensions as the input)
static const AVFilterPad srcnn_outputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
    },
    { NULL }
};
/// Filter registration: slice-threaded (all stages run via execute()),
/// timeline-capable. Upscaling itself is not performed here — the user is
/// expected to bicubic-upsample beforehand, as the description states.
AVFilter ff_vf_srcnn = {
    .name = "srcnn",
    .description = NULL_IF_CONFIG_SMALL("Apply super resolution convolutional neural network to the input. Use bicubic upsamping with corresponding scaling factor before."),
    .priv_size = sizeof(SRCNNContext),
    .init = init,
    .uninit = uninit,
    .query_formats = query_formats,
    .inputs = srcnn_inputs,
    .outputs = srcnn_outputs,
    .priv_class = &srcnn_class,
    .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
};