|
|
|
@@ -104,12 +104,20 @@ static int query_formats(AVFilterContext *context) |
|
|
|
{ |
|
|
|
static const enum AVPixelFormat pix_fmts[] = { |
|
|
|
AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24, |
|
|
|
AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32, |
|
|
|
AV_PIX_FMT_NONE |
|
|
|
}; |
|
|
|
AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts); |
|
|
|
return ff_set_common_formats(context, fmts_list); |
|
|
|
} |
|
|
|
|
|
|
|
/*
 * Log an AV_LOG_ERROR message reporting that the frame's pixel format does
 * not match the model's input channel count.
 *
 * The macro takes no arguments; it expands in place and uses the names
 * `ctx`, `fmt` and `model_input` from the scope of the call site, so all
 * three must be visible wherever it is used.
 *
 * The body is wrapped in do { } while (0) and carries no trailing semicolon,
 * so `LOG_FORMAT_CHANNEL_MISMATCH();` is exactly one statement and stays
 * valid inside an unbraced if/else.
 */
#define LOG_FORMAT_CHANNEL_MISMATCH()                       \
    do {                                                    \
        av_log(ctx, AV_LOG_ERROR,                           \
               "the frame's format %s does not match "      \
               "the model input channel %d\n",              \
               av_get_pix_fmt_name(fmt),                    \
               model_input->channels);                      \
    } while (0)
|
|
|
|
|
|
|
static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink) |
|
|
|
{ |
|
|
|
AVFilterContext *ctx = inlink->dst; |
|
|
|
@@ -131,17 +139,34 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin |
|
|
|
case AV_PIX_FMT_RGB24: |
|
|
|
case AV_PIX_FMT_BGR24: |
|
|
|
if (model_input->channels != 3) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "the frame's input format %s does not match " |
|
|
|
"the model input channels %d\n", |
|
|
|
av_get_pix_fmt_name(fmt), |
|
|
|
model_input->channels); |
|
|
|
LOG_FORMAT_CHANNEL_MISMATCH(); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
break; |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAY8: |
|
|
|
if (model_input->channels != 1) { |
|
|
|
LOG_FORMAT_CHANNEL_MISMATCH(); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
if (model_input->dt != DNN_UINT8) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n"); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAYF32: |
|
|
|
if (model_input->channels != 1) { |
|
|
|
LOG_FORMAT_CHANNEL_MISMATCH(); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
if (model_input->dt != DNN_FLOAT) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n"); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
return 0; |
|
|
|
default: |
|
|
|
av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt)); |
|
|
|
return AVERROR(EIO); |
|
|
|
@@ -206,28 +231,58 @@ static int config_output(AVFilterLink *outlink) |
|
|
|
|
|
|
|
static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame) |
|
|
|
{ |
|
|
|
// extend this function to support more formats |
|
|
|
av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_input->dt == DNN_FLOAT) { |
|
|
|
float *dnn_input_data = dnn_input->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k] / 255.0f; |
|
|
|
switch (frame->format) { |
|
|
|
case AV_PIX_FMT_RGB24: |
|
|
|
case AV_PIX_FMT_BGR24: |
|
|
|
if (dnn_input->dt == DNN_FLOAT) { |
|
|
|
float *dnn_input_data = dnn_input->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k] / 255.0f; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_input_data = dnn_input->data; |
|
|
|
av_assert0(dnn_input->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_input_data = dnn_input->data; |
|
|
|
av_assert0(dnn_input->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k]; |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAY8: |
|
|
|
{ |
|
|
|
uint8_t *dnn_input_data = dnn_input->data; |
|
|
|
av_assert0(dnn_input->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAYF32: |
|
|
|
{ |
|
|
|
float *dnn_input_data = dnn_input->data; |
|
|
|
av_assert0(dnn_input->dt == DNN_FLOAT); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width; j++) { |
|
|
|
int k = i * frame->linesize[0] + j * sizeof(float); |
|
|
|
int t = i * frame->width + j; |
|
|
|
dnn_input_data[t] = *(float*)(frame->data[0] + k); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return 0; |
|
|
|
default: |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
@@ -235,28 +290,58 @@ static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame) |
|
|
|
|
|
|
|
static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output) |
|
|
|
{ |
|
|
|
// extend this function to support more formats |
|
|
|
av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_output->dt == DNN_FLOAT) { |
|
|
|
float *dnn_output_data = dnn_output->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); |
|
|
|
switch (frame->format) { |
|
|
|
case AV_PIX_FMT_RGB24: |
|
|
|
case AV_PIX_FMT_BGR24: |
|
|
|
if (dnn_output->dt == DNN_FLOAT) { |
|
|
|
float *dnn_output_data = dnn_output->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_output_data = dnn_output->data; |
|
|
|
av_assert0(dnn_output->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = dnn_output_data[t]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAY8: |
|
|
|
{ |
|
|
|
uint8_t *dnn_output_data = dnn_output->data; |
|
|
|
av_assert0(dnn_output->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width + j; |
|
|
|
frame->data[0][k] = dnn_output_data[t]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_output_data = dnn_output->data; |
|
|
|
av_assert0(dnn_output->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = dnn_output_data[t]; |
|
|
|
return 0; |
|
|
|
case AV_PIX_FMT_GRAYF32: |
|
|
|
{ |
|
|
|
float *dnn_output_data = dnn_output->data; |
|
|
|
av_assert0(dnn_output->dt == DNN_FLOAT); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width; j++) { |
|
|
|
int k = i * frame->linesize[0] + j * sizeof(float); |
|
|
|
int t = i * frame->width + j; |
|
|
|
*(float*)(frame->data[0] + k) = dnn_output_data[t]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return 0; |
|
|
|
default: |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
@@ -278,7 +363,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) |
|
|
|
av_frame_free(&in); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
av_assert0(ctx->output.channels == 3); |
|
|
|
|
|
|
|
out = ff_get_video_buffer(outlink, outlink->w, outlink->h); |
|
|
|
if (!out) { |
|
|
|
|