|
|
|
@@ -136,40 +136,40 @@ static int config_input(AVFilterLink *inlink) |
|
|
|
AVFilterContext *context = inlink->dst; |
|
|
|
DnnProcessingContext *ctx = context->priv; |
|
|
|
DNNReturnType result; |
|
|
|
DNNData dnn_data; |
|
|
|
DNNData model_input; |
|
|
|
|
|
|
|
result = ctx->model->get_input(ctx->model->model, &dnn_data, ctx->model_inputname); |
|
|
|
result = ctx->model->get_input(ctx->model->model, &model_input, ctx->model_inputname); |
|
|
|
if (result != DNN_SUCCESS) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n"); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
|
|
|
|
// the design is to add explicit scale filter before this filter |
|
|
|
if (dnn_data.height != -1 && dnn_data.height != inlink->h) { |
|
|
|
if (model_input.height != -1 && model_input.height != inlink->h) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n", |
|
|
|
dnn_data.height, inlink->h); |
|
|
|
model_input.height, inlink->h); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
if (dnn_data.width != -1 && dnn_data.width != inlink->w) { |
|
|
|
if (model_input.width != -1 && model_input.width != inlink->w) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n", |
|
|
|
dnn_data.width, inlink->w); |
|
|
|
model_input.width, inlink->w); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
|
|
|
|
if (dnn_data.channels != 3) { |
|
|
|
if (model_input.channels != 3) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "the model requires input channels %d\n", |
|
|
|
dnn_data.channels); |
|
|
|
model_input.channels); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
if (dnn_data.dt != DNN_FLOAT && dnn_data.dt != DNN_UINT8) { |
|
|
|
if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) { |
|
|
|
av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); |
|
|
|
return AVERROR(EIO); |
|
|
|
} |
|
|
|
|
|
|
|
ctx->input.width = inlink->w; |
|
|
|
ctx->input.height = inlink->h; |
|
|
|
ctx->input.channels = dnn_data.channels; |
|
|
|
ctx->input.dt = dnn_data.dt; |
|
|
|
ctx->input.channels = model_input.channels; |
|
|
|
ctx->input.dt = model_input.dt; |
|
|
|
|
|
|
|
result = (ctx->model->set_input_output)(ctx->model->model, |
|
|
|
&ctx->input, ctx->model_inputname, |
|
|
|
@@ -201,28 +201,28 @@ static int config_output(AVFilterLink *outlink) |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
static int copy_from_frame_to_dnn(DNNData *dnn_data, const AVFrame *in) |
|
|
|
static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame) |
|
|
|
{ |
|
|
|
// extend this function to support more formats |
|
|
|
av_assert0(in->format == AV_PIX_FMT_RGB24 || in->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_data->dt == DNN_FLOAT) { |
|
|
|
float *dnn_input = dnn_data->data; |
|
|
|
for (int i = 0; i < in->height; i++) { |
|
|
|
for(int j = 0; j < in->width * 3; j++) { |
|
|
|
int k = i * in->linesize[0] + j; |
|
|
|
int t = i * in->width * 3 + j; |
|
|
|
dnn_input[t] = in->data[0][k] / 255.0f; |
|
|
|
av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_input->dt == DNN_FLOAT) { |
|
|
|
float *dnn_input_data = dnn_input->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k] / 255.0f; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_input = dnn_data->data; |
|
|
|
av_assert0(dnn_data->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < in->height; i++) { |
|
|
|
for(int j = 0; j < in->width * 3; j++) { |
|
|
|
int k = i * in->linesize[0] + j; |
|
|
|
int t = i * in->width * 3 + j; |
|
|
|
dnn_input[t] = in->data[0][k]; |
|
|
|
uint8_t *dnn_input_data = dnn_input->data; |
|
|
|
av_assert0(dnn_input->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
dnn_input_data[t] = frame->data[0][k]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -230,28 +230,28 @@ static int copy_from_frame_to_dnn(DNNData *dnn_data, const AVFrame *in) |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
static int copy_from_dnn_to_frame(AVFrame *out, const DNNData *dnn_data) |
|
|
|
static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output) |
|
|
|
{ |
|
|
|
// extend this function to support more formats |
|
|
|
av_assert0(out->format == AV_PIX_FMT_RGB24 || out->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_data->dt == DNN_FLOAT) { |
|
|
|
float *dnn_output = dnn_data->data; |
|
|
|
for (int i = 0; i < out->height; i++) { |
|
|
|
for(int j = 0; j < out->width * 3; j++) { |
|
|
|
int k = i * out->linesize[0] + j; |
|
|
|
int t = i * out->width * 3 + j; |
|
|
|
out->data[0][k] = av_clip_uintp2((int)(dnn_output[t] * 255.0f), 8); |
|
|
|
av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); |
|
|
|
|
|
|
|
if (dnn_output->dt == DNN_FLOAT) { |
|
|
|
float *dnn_output_data = dnn_output->data; |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
uint8_t *dnn_output = dnn_data->data; |
|
|
|
av_assert0(dnn_data->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < out->height; i++) { |
|
|
|
for(int j = 0; j < out->width * 3; j++) { |
|
|
|
int k = i * out->linesize[0] + j; |
|
|
|
int t = i * out->width * 3 + j; |
|
|
|
out->data[0][k] = dnn_output[t]; |
|
|
|
uint8_t *dnn_output_data = dnn_output->data; |
|
|
|
av_assert0(dnn_output->dt == DNN_UINT8); |
|
|
|
for (int i = 0; i < frame->height; i++) { |
|
|
|
for(int j = 0; j < frame->width * 3; j++) { |
|
|
|
int k = i * frame->linesize[0] + j; |
|
|
|
int t = i * frame->width * 3 + j; |
|
|
|
frame->data[0][k] = dnn_output_data[t]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|