You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1269 lines
37KB

  1. /*
  2. * This file is part of FFmpeg.
  3. *
  4. * FFmpeg is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * FFmpeg is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with FFmpeg; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "libavutil/avassert.h"
  19. #include "libavutil/pixfmt.h"
  20. #include "cbs.h"
  21. #include "cbs_internal.h"
  22. #include "cbs_av1.h"
  23. #include "internal.h"
  24. static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
  25. const char *name, uint32_t *write_to,
  26. uint32_t range_min, uint32_t range_max)
  27. {
  28. uint32_t zeroes, bits_value, value;
  29. int position;
  30. if (ctx->trace_enable)
  31. position = get_bits_count(gbc);
  32. zeroes = 0;
  33. while (1) {
  34. if (get_bits_left(gbc) < 1) {
  35. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
  36. "%s: bitstream ended.\n", name);
  37. return AVERROR_INVALIDDATA;
  38. }
  39. if (get_bits1(gbc))
  40. break;
  41. ++zeroes;
  42. }
  43. if (zeroes >= 32) {
  44. value = MAX_UINT_BITS(32);
  45. } else {
  46. if (get_bits_left(gbc) < zeroes) {
  47. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
  48. "%s: bitstream ended.\n", name);
  49. return AVERROR_INVALIDDATA;
  50. }
  51. bits_value = get_bits_long(gbc, zeroes);
  52. value = bits_value + (UINT32_C(1) << zeroes) - 1;
  53. }
  54. if (ctx->trace_enable) {
  55. char bits[65];
  56. int i, j, k;
  57. if (zeroes >= 32) {
  58. while (zeroes > 32) {
  59. k = FFMIN(zeroes - 32, 32);
  60. for (i = 0; i < k; i++)
  61. bits[i] = '0';
  62. bits[i] = 0;
  63. ff_cbs_trace_syntax_element(ctx, position, name,
  64. NULL, bits, 0);
  65. zeroes -= k;
  66. position += k;
  67. }
  68. }
  69. for (i = 0; i < zeroes; i++)
  70. bits[i] = '0';
  71. bits[i++] = '1';
  72. if (zeroes < 32) {
  73. for (j = 0; j < zeroes; j++)
  74. bits[i++] = (bits_value >> (zeroes - j - 1) & 1) ? '1' : '0';
  75. }
  76. bits[i] = 0;
  77. ff_cbs_trace_syntax_element(ctx, position, name,
  78. NULL, bits, value);
  79. }
  80. if (value < range_min || value > range_max) {
  81. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  82. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  83. name, value, range_min, range_max);
  84. return AVERROR_INVALIDDATA;
  85. }
  86. *write_to = value;
  87. return 0;
  88. }
  89. static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
  90. const char *name, uint32_t value,
  91. uint32_t range_min, uint32_t range_max)
  92. {
  93. uint32_t v;
  94. int position, zeroes;
  95. if (value < range_min || value > range_max) {
  96. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  97. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  98. name, value, range_min, range_max);
  99. return AVERROR_INVALIDDATA;
  100. }
  101. if (ctx->trace_enable)
  102. position = put_bits_count(pbc);
  103. zeroes = av_log2(value + 1);
  104. v = value - (1U << zeroes) + 1;
  105. put_bits(pbc, zeroes, 0);
  106. put_bits(pbc, 1, 1);
  107. put_bits(pbc, zeroes, v);
  108. if (ctx->trace_enable) {
  109. char bits[65];
  110. int i, j;
  111. i = 0;
  112. for (j = 0; j < zeroes; j++)
  113. bits[i++] = '0';
  114. bits[i++] = '1';
  115. for (j = 0; j < zeroes; j++)
  116. bits[i++] = (v >> (zeroes - j - 1) & 1) ? '1' : '0';
  117. bits[i++] = 0;
  118. ff_cbs_trace_syntax_element(ctx, position, name, NULL,
  119. bits, value);
  120. }
  121. return 0;
  122. }
  123. static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
  124. const char *name, uint64_t *write_to)
  125. {
  126. uint64_t value;
  127. int position, err, i;
  128. if (ctx->trace_enable)
  129. position = get_bits_count(gbc);
  130. value = 0;
  131. for (i = 0; i < 8; i++) {
  132. int subscript[2] = { 1, i };
  133. uint32_t byte;
  134. err = ff_cbs_read_unsigned(ctx, gbc, 8, "leb128_byte[i]", subscript,
  135. &byte, 0x00, 0xff);
  136. if (err < 0)
  137. return err;
  138. value |= (uint64_t)(byte & 0x7f) << (i * 7);
  139. if (!(byte & 0x80))
  140. break;
  141. }
  142. if (value > UINT32_MAX)
  143. return AVERROR_INVALIDDATA;
  144. if (ctx->trace_enable)
  145. ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
  146. *write_to = value;
  147. return 0;
  148. }
  149. static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
  150. const char *name, uint64_t value)
  151. {
  152. int position, err, len, i;
  153. uint8_t byte;
  154. len = (av_log2(value) + 7) / 7;
  155. if (ctx->trace_enable)
  156. position = put_bits_count(pbc);
  157. for (i = 0; i < len; i++) {
  158. int subscript[2] = { 1, i };
  159. byte = value >> (7 * i) & 0x7f;
  160. if (i < len - 1)
  161. byte |= 0x80;
  162. err = ff_cbs_write_unsigned(ctx, pbc, 8, "leb128_byte[i]", subscript,
  163. byte, 0x00, 0xff);
  164. if (err < 0)
  165. return err;
  166. }
  167. if (ctx->trace_enable)
  168. ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
  169. return 0;
  170. }
  171. static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
  172. uint32_t n, const char *name,
  173. const int *subscripts, uint32_t *write_to)
  174. {
  175. uint32_t m, v, extra_bit, value;
  176. int position, w;
  177. av_assert0(n > 0);
  178. if (ctx->trace_enable)
  179. position = get_bits_count(gbc);
  180. w = av_log2(n) + 1;
  181. m = (1 << w) - n;
  182. if (get_bits_left(gbc) < w) {
  183. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid non-symmetric value at "
  184. "%s: bitstream ended.\n", name);
  185. return AVERROR_INVALIDDATA;
  186. }
  187. if (w - 1 > 0)
  188. v = get_bits(gbc, w - 1);
  189. else
  190. v = 0;
  191. if (v < m) {
  192. value = v;
  193. } else {
  194. extra_bit = get_bits1(gbc);
  195. value = (v << 1) - m + extra_bit;
  196. }
  197. if (ctx->trace_enable) {
  198. char bits[33];
  199. int i;
  200. for (i = 0; i < w - 1; i++)
  201. bits[i] = (v >> i & 1) ? '1' : '0';
  202. if (v >= m)
  203. bits[i++] = extra_bit ? '1' : '0';
  204. bits[i] = 0;
  205. ff_cbs_trace_syntax_element(ctx, position,
  206. name, subscripts, bits, value);
  207. }
  208. *write_to = value;
  209. return 0;
  210. }
  211. static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
  212. uint32_t n, const char *name,
  213. const int *subscripts, uint32_t value)
  214. {
  215. uint32_t w, m, v, extra_bit;
  216. int position;
  217. if (value > n) {
  218. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  219. "%"PRIu32", but must be in [0,%"PRIu32"].\n",
  220. name, value, n);
  221. return AVERROR_INVALIDDATA;
  222. }
  223. if (ctx->trace_enable)
  224. position = put_bits_count(pbc);
  225. w = av_log2(n) + 1;
  226. m = (1 << w) - n;
  227. if (put_bits_left(pbc) < w)
  228. return AVERROR(ENOSPC);
  229. if (value < m) {
  230. v = value;
  231. put_bits(pbc, w - 1, v);
  232. } else {
  233. v = m + ((value - m) >> 1);
  234. extra_bit = (value - m) & 1;
  235. put_bits(pbc, w - 1, v);
  236. put_bits(pbc, 1, extra_bit);
  237. }
  238. if (ctx->trace_enable) {
  239. char bits[33];
  240. int i;
  241. for (i = 0; i < w - 1; i++)
  242. bits[i] = (v >> i & 1) ? '1' : '0';
  243. if (value >= m)
  244. bits[i++] = extra_bit ? '1' : '0';
  245. bits[i] = 0;
  246. ff_cbs_trace_syntax_element(ctx, position,
  247. name, subscripts, bits, value);
  248. }
  249. return 0;
  250. }
  251. static int cbs_av1_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc,
  252. uint32_t range_min, uint32_t range_max,
  253. const char *name, uint32_t *write_to)
  254. {
  255. uint32_t value;
  256. int position, i;
  257. char bits[33];
  258. av_assert0(range_min <= range_max && range_max - range_min < sizeof(bits) - 1);
  259. if (ctx->trace_enable)
  260. position = get_bits_count(gbc);
  261. for (i = 0, value = range_min; value < range_max;) {
  262. if (get_bits_left(gbc) < 1) {
  263. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
  264. "%s: bitstream ended.\n", name);
  265. return AVERROR_INVALIDDATA;
  266. }
  267. if (get_bits1(gbc)) {
  268. bits[i++] = '1';
  269. ++value;
  270. } else {
  271. bits[i++] = '0';
  272. break;
  273. }
  274. }
  275. if (ctx->trace_enable) {
  276. bits[i] = 0;
  277. ff_cbs_trace_syntax_element(ctx, position,
  278. name, NULL, bits, value);
  279. }
  280. *write_to = value;
  281. return 0;
  282. }
  283. static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pbc,
  284. uint32_t range_min, uint32_t range_max,
  285. const char *name, uint32_t value)
  286. {
  287. int len;
  288. av_assert0(range_min <= range_max && range_max - range_min < 32);
  289. if (value < range_min || value > range_max) {
  290. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  291. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  292. name, value, range_min, range_max);
  293. return AVERROR_INVALIDDATA;
  294. }
  295. if (value == range_max)
  296. len = range_max - range_min;
  297. else
  298. len = value - range_min + 1;
  299. if (put_bits_left(pbc) < len)
  300. return AVERROR(ENOSPC);
  301. if (ctx->trace_enable) {
  302. char bits[33];
  303. int i;
  304. for (i = 0; i < len; i++) {
  305. if (range_min + i == value)
  306. bits[i] = '0';
  307. else
  308. bits[i] = '1';
  309. }
  310. bits[i] = 0;
  311. ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
  312. name, NULL, bits, value);
  313. }
  314. if (len > 0)
  315. put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
  316. return 0;
  317. }
  318. static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
  319. uint32_t range_max, const char *name,
  320. const int *subscripts, uint32_t *write_to)
  321. {
  322. uint32_t value;
  323. int position, err;
  324. uint32_t max_len, len, range_offset, range_bits;
  325. if (ctx->trace_enable)
  326. position = get_bits_count(gbc);
  327. av_assert0(range_max > 0);
  328. max_len = av_log2(range_max - 1) - 3;
  329. err = cbs_av1_read_increment(ctx, gbc, 0, max_len,
  330. "subexp_more_bits", &len);
  331. if (err < 0)
  332. return err;
  333. if (len) {
  334. range_bits = 2 + len;
  335. range_offset = 1 << range_bits;
  336. } else {
  337. range_bits = 3;
  338. range_offset = 0;
  339. }
  340. if (len < max_len) {
  341. err = ff_cbs_read_unsigned(ctx, gbc, range_bits,
  342. "subexp_bits", NULL, &value,
  343. 0, MAX_UINT_BITS(range_bits));
  344. if (err < 0)
  345. return err;
  346. } else {
  347. err = cbs_av1_read_ns(ctx, gbc, range_max - range_offset,
  348. "subexp_final_bits", NULL, &value);
  349. if (err < 0)
  350. return err;
  351. }
  352. value += range_offset;
  353. if (ctx->trace_enable)
  354. ff_cbs_trace_syntax_element(ctx, position,
  355. name, subscripts, "", value);
  356. *write_to = value;
  357. return err;
  358. }
  359. static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
  360. uint32_t range_max, const char *name,
  361. const int *subscripts, uint32_t value)
  362. {
  363. int position, err;
  364. uint32_t max_len, len, range_offset, range_bits;
  365. if (value > range_max) {
  366. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  367. "%"PRIu32", but must be in [0,%"PRIu32"].\n",
  368. name, value, range_max);
  369. return AVERROR_INVALIDDATA;
  370. }
  371. if (ctx->trace_enable)
  372. position = put_bits_count(pbc);
  373. av_assert0(range_max > 0);
  374. max_len = av_log2(range_max - 1) - 3;
  375. if (value < 8) {
  376. range_bits = 3;
  377. range_offset = 0;
  378. len = 0;
  379. } else {
  380. range_bits = av_log2(value);
  381. len = range_bits - 2;
  382. if (len > max_len) {
  383. // The top bin is combined with the one below it.
  384. av_assert0(len == max_len + 1);
  385. --range_bits;
  386. len = max_len;
  387. }
  388. range_offset = 1 << range_bits;
  389. }
  390. err = cbs_av1_write_increment(ctx, pbc, 0, max_len,
  391. "subexp_more_bits", len);
  392. if (err < 0)
  393. return err;
  394. if (len < max_len) {
  395. err = ff_cbs_write_unsigned(ctx, pbc, range_bits,
  396. "subexp_bits", NULL,
  397. value - range_offset,
  398. 0, MAX_UINT_BITS(range_bits));
  399. if (err < 0)
  400. return err;
  401. } else {
  402. err = cbs_av1_write_ns(ctx, pbc, range_max - range_offset,
  403. "subexp_final_bits", NULL,
  404. value - range_offset);
  405. if (err < 0)
  406. return err;
  407. }
  408. if (ctx->trace_enable)
  409. ff_cbs_trace_syntax_element(ctx, position,
  410. name, subscripts, "", value);
  411. return err;
  412. }
  413. static int cbs_av1_tile_log2(int blksize, int target)
  414. {
  415. int k;
  416. for (k = 0; (blksize << k) < target; k++);
  417. return k;
  418. }
  419. static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
  420. unsigned int a, unsigned int b)
  421. {
  422. unsigned int diff, m;
  423. if (!seq->enable_order_hint)
  424. return 0;
  425. diff = a - b;
  426. m = 1 << seq->order_hint_bits_minus_1;
  427. diff = (diff & (m - 1)) - (diff & m);
  428. return diff;
  429. }
  430. static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
  431. {
  432. GetBitContext tmp = *gbc;
  433. size_t size = 0;
  434. for (int i = 0; get_bits_left(&tmp) >= 8; i++) {
  435. if (get_bits(&tmp, 8))
  436. size = i;
  437. }
  438. return size;
  439. }
  440. #define HEADER(name) do { \
  441. ff_cbs_trace_header(ctx, name); \
  442. } while (0)
  443. #define CHECK(call) do { \
  444. err = (call); \
  445. if (err < 0) \
  446. return err; \
  447. } while (0)
  448. #define FUNC_NAME(rw, codec, name) cbs_ ## codec ## _ ## rw ## _ ## name
  449. #define FUNC_AV1(rw, name) FUNC_NAME(rw, av1, name)
  450. #define FUNC(name) FUNC_AV1(READWRITE, name)
  451. #define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)
  452. #define fb(width, name) \
  453. xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0, )
  454. #define fc(width, name, range_min, range_max) \
  455. xf(width, name, current->name, range_min, range_max, 0, )
  456. #define flag(name) fb(1, name)
  457. #define su(width, name) \
  458. xsu(width, name, current->name, 0, )
  459. #define fbs(width, name, subs, ...) \
  460. xf(width, name, current->name, 0, MAX_UINT_BITS(width), subs, __VA_ARGS__)
  461. #define fcs(width, name, range_min, range_max, subs, ...) \
  462. xf(width, name, current->name, range_min, range_max, subs, __VA_ARGS__)
  463. #define flags(name, subs, ...) \
  464. xf(1, name, current->name, 0, 1, subs, __VA_ARGS__)
  465. #define sus(width, name, subs, ...) \
  466. xsu(width, name, current->name, subs, __VA_ARGS__)
  467. #define fixed(width, name, value) do { \
  468. av_unused uint32_t fixed_value = value; \
  469. xf(width, name, fixed_value, value, value, 0, ); \
  470. } while (0)
  471. #define READ
  472. #define READWRITE read
  473. #define RWContext GetBitContext
  474. #define xf(width, name, var, range_min, range_max, subs, ...) do { \
  475. uint32_t value; \
  476. CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
  477. SUBSCRIPTS(subs, __VA_ARGS__), \
  478. &value, range_min, range_max)); \
  479. var = value; \
  480. } while (0)
  481. #define xsu(width, name, var, subs, ...) do { \
  482. int32_t value; \
  483. CHECK(ff_cbs_read_signed(ctx, rw, width, #name, \
  484. SUBSCRIPTS(subs, __VA_ARGS__), &value, \
  485. MIN_INT_BITS(width), \
  486. MAX_INT_BITS(width))); \
  487. var = value; \
  488. } while (0)
  489. #define uvlc(name, range_min, range_max) do { \
  490. uint32_t value; \
  491. CHECK(cbs_av1_read_uvlc(ctx, rw, #name, \
  492. &value, range_min, range_max)); \
  493. current->name = value; \
  494. } while (0)
  495. #define ns(max_value, name, subs, ...) do { \
  496. uint32_t value; \
  497. CHECK(cbs_av1_read_ns(ctx, rw, max_value, #name, \
  498. SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
  499. current->name = value; \
  500. } while (0)
  501. #define increment(name, min, max) do { \
  502. uint32_t value; \
  503. CHECK(cbs_av1_read_increment(ctx, rw, min, max, #name, &value)); \
  504. current->name = value; \
  505. } while (0)
  506. #define subexp(name, max, subs, ...) do { \
  507. uint32_t value; \
  508. CHECK(cbs_av1_read_subexp(ctx, rw, max, #name, \
  509. SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
  510. current->name = value; \
  511. } while (0)
  512. #define delta_q(name) do { \
  513. uint8_t delta_coded; \
  514. int8_t delta_q; \
  515. xf(1, name.delta_coded, delta_coded, 0, 1, 0, ); \
  516. if (delta_coded) \
  517. xsu(1 + 6, name.delta_q, delta_q, 0, ); \
  518. else \
  519. delta_q = 0; \
  520. current->name = delta_q; \
  521. } while (0)
  522. #define leb128(name) do { \
  523. uint64_t value; \
  524. CHECK(cbs_av1_read_leb128(ctx, rw, #name, &value)); \
  525. current->name = value; \
  526. } while (0)
  527. #define infer(name, value) do { \
  528. current->name = value; \
  529. } while (0)
  530. #define byte_alignment(rw) (get_bits_count(rw) % 8)
  531. #include "cbs_av1_syntax_template.c"
  532. #undef READ
  533. #undef READWRITE
  534. #undef RWContext
  535. #undef xf
  536. #undef xsu
  537. #undef uvlc
  538. #undef ns
  539. #undef increment
  540. #undef subexp
  541. #undef delta_q
  542. #undef leb128
  543. #undef infer
  544. #undef byte_alignment
  545. #define WRITE
  546. #define READWRITE write
  547. #define RWContext PutBitContext
  548. #define xf(width, name, var, range_min, range_max, subs, ...) do { \
  549. CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
  550. SUBSCRIPTS(subs, __VA_ARGS__), \
  551. var, range_min, range_max)); \
  552. } while (0)
  553. #define xsu(width, name, var, subs, ...) do { \
  554. CHECK(ff_cbs_write_signed(ctx, rw, width, #name, \
  555. SUBSCRIPTS(subs, __VA_ARGS__), var, \
  556. MIN_INT_BITS(width), \
  557. MAX_INT_BITS(width))); \
  558. } while (0)
  559. #define uvlc(name, range_min, range_max) do { \
  560. CHECK(cbs_av1_write_uvlc(ctx, rw, #name, current->name, \
  561. range_min, range_max)); \
  562. } while (0)
  563. #define ns(max_value, name, subs, ...) do { \
  564. CHECK(cbs_av1_write_ns(ctx, rw, max_value, #name, \
  565. SUBSCRIPTS(subs, __VA_ARGS__), \
  566. current->name)); \
  567. } while (0)
  568. #define increment(name, min, max) do { \
  569. CHECK(cbs_av1_write_increment(ctx, rw, min, max, #name, \
  570. current->name)); \
  571. } while (0)
  572. #define subexp(name, max, subs, ...) do { \
  573. CHECK(cbs_av1_write_subexp(ctx, rw, max, #name, \
  574. SUBSCRIPTS(subs, __VA_ARGS__), \
  575. current->name)); \
  576. } while (0)
  577. #define delta_q(name) do { \
  578. xf(1, name.delta_coded, current->name != 0, 0, 1, 0, ); \
  579. if (current->name) \
  580. xsu(1 + 6, name.delta_q, current->name, 0, ); \
  581. } while (0)
  582. #define leb128(name) do { \
  583. CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name)); \
  584. } while (0)
  585. #define infer(name, value) do { \
  586. if (current->name != (value)) { \
  587. av_log(ctx->log_ctx, AV_LOG_ERROR, \
  588. "%s does not match inferred value: " \
  589. "%"PRId64", but should be %"PRId64".\n", \
  590. #name, (int64_t)current->name, (int64_t)(value)); \
  591. return AVERROR_INVALIDDATA; \
  592. } \
  593. } while (0)
  594. #define byte_alignment(rw) (put_bits_count(rw) % 8)
  595. #include "cbs_av1_syntax_template.c"
  596. #undef WRITE
  597. #undef READWRITE
  598. #undef RWContext
  599. #undef xf
  600. #undef xsu
  601. #undef uvlc
  602. #undef ns
  603. #undef increment
  604. #undef subexp
  605. #undef delta_q
  606. #undef leb128
  607. #undef infer
  608. #undef byte_alignment
  609. static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
  610. CodedBitstreamFragment *frag,
  611. int header)
  612. {
  613. GetBitContext gbc;
  614. uint8_t *data;
  615. size_t size;
  616. uint64_t obu_length;
  617. int pos, err, trace;
  618. // Don't include this parsing in trace output.
  619. trace = ctx->trace_enable;
  620. ctx->trace_enable = 0;
  621. data = frag->data;
  622. size = frag->data_size;
  623. if (INT_MAX / 8 < size) {
  624. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid fragment: "
  625. "too large (%"SIZE_SPECIFIER" bytes).\n", size);
  626. err = AVERROR_INVALIDDATA;
  627. goto fail;
  628. }
  629. while (size > 0) {
  630. AV1RawOBUHeader header;
  631. uint64_t obu_size;
  632. init_get_bits(&gbc, data, 8 * size);
  633. err = cbs_av1_read_obu_header(ctx, &gbc, &header);
  634. if (err < 0)
  635. goto fail;
  636. if (header.obu_has_size_field) {
  637. if (get_bits_left(&gbc) < 8) {
  638. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
  639. "too short (%"SIZE_SPECIFIER" bytes).\n", size);
  640. err = AVERROR_INVALIDDATA;
  641. goto fail;
  642. }
  643. err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
  644. if (err < 0)
  645. goto fail;
  646. } else
  647. obu_size = size - 1 - header.obu_extension_flag;
  648. pos = get_bits_count(&gbc);
  649. av_assert0(pos % 8 == 0 && pos / 8 <= size);
  650. obu_length = pos / 8 + obu_size;
  651. if (size < obu_length) {
  652. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
  653. "%"PRIu64", but only %"SIZE_SPECIFIER" bytes remaining in fragment.\n",
  654. obu_length, size);
  655. err = AVERROR_INVALIDDATA;
  656. goto fail;
  657. }
  658. err = ff_cbs_insert_unit_data(frag, -1, header.obu_type,
  659. data, obu_length, frag->data_ref);
  660. if (err < 0)
  661. goto fail;
  662. data += obu_length;
  663. size -= obu_length;
  664. }
  665. err = 0;
  666. fail:
  667. ctx->trace_enable = trace;
  668. return err;
  669. }
  670. static int cbs_av1_ref_tile_data(CodedBitstreamContext *ctx,
  671. CodedBitstreamUnit *unit,
  672. GetBitContext *gbc,
  673. AV1RawTileData *td)
  674. {
  675. int pos;
  676. pos = get_bits_count(gbc);
  677. if (pos >= 8 * unit->data_size) {
  678. av_log(ctx->log_ctx, AV_LOG_ERROR, "Bitstream ended before "
  679. "any data in tile group (%d bits read).\n", pos);
  680. return AVERROR_INVALIDDATA;
  681. }
  682. // Must be byte-aligned at this point.
  683. av_assert0(pos % 8 == 0);
  684. td->data_ref = av_buffer_ref(unit->data_ref);
  685. if (!td->data_ref)
  686. return AVERROR(ENOMEM);
  687. td->data = unit->data + pos / 8;
  688. td->data_size = unit->data_size - pos / 8;
  689. return 0;
  690. }
  691. static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
  692. CodedBitstreamUnit *unit)
  693. {
  694. CodedBitstreamAV1Context *priv = ctx->priv_data;
  695. AV1RawOBU *obu;
  696. GetBitContext gbc;
  697. int err, start_pos, end_pos;
  698. err = ff_cbs_alloc_unit_content2(ctx, unit);
  699. if (err < 0)
  700. return err;
  701. obu = unit->content;
  702. err = init_get_bits(&gbc, unit->data, 8 * unit->data_size);
  703. if (err < 0)
  704. return err;
  705. err = cbs_av1_read_obu_header(ctx, &gbc, &obu->header);
  706. if (err < 0)
  707. return err;
  708. av_assert0(obu->header.obu_type == unit->type);
  709. if (obu->header.obu_has_size_field) {
  710. uint64_t obu_size;
  711. err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
  712. if (err < 0)
  713. return err;
  714. obu->obu_size = obu_size;
  715. } else {
  716. if (unit->data_size < 1 + obu->header.obu_extension_flag) {
  717. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
  718. "unit too short (%"SIZE_SPECIFIER").\n", unit->data_size);
  719. return AVERROR_INVALIDDATA;
  720. }
  721. obu->obu_size = unit->data_size - 1 - obu->header.obu_extension_flag;
  722. }
  723. start_pos = get_bits_count(&gbc);
  724. if (obu->header.obu_extension_flag) {
  725. if (obu->header.obu_type != AV1_OBU_SEQUENCE_HEADER &&
  726. obu->header.obu_type != AV1_OBU_TEMPORAL_DELIMITER &&
  727. priv->operating_point_idc) {
  728. int in_temporal_layer =
  729. (priv->operating_point_idc >> priv->temporal_id ) & 1;
  730. int in_spatial_layer =
  731. (priv->operating_point_idc >> (priv->spatial_id + 8)) & 1;
  732. if (!in_temporal_layer || !in_spatial_layer) {
  733. // Decoding will drop this OBU at this operating point.
  734. }
  735. }
  736. }
  737. switch (obu->header.obu_type) {
  738. case AV1_OBU_SEQUENCE_HEADER:
  739. {
  740. err = cbs_av1_read_sequence_header_obu(ctx, &gbc,
  741. &obu->obu.sequence_header);
  742. if (err < 0)
  743. return err;
  744. av_buffer_unref(&priv->sequence_header_ref);
  745. priv->sequence_header = NULL;
  746. priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
  747. if (!priv->sequence_header_ref)
  748. return AVERROR(ENOMEM);
  749. priv->sequence_header = &obu->obu.sequence_header;
  750. }
  751. break;
  752. case AV1_OBU_TEMPORAL_DELIMITER:
  753. {
  754. err = cbs_av1_read_temporal_delimiter_obu(ctx, &gbc);
  755. if (err < 0)
  756. return err;
  757. }
  758. break;
  759. case AV1_OBU_FRAME_HEADER:
  760. case AV1_OBU_REDUNDANT_FRAME_HEADER:
  761. {
  762. err = cbs_av1_read_frame_header_obu(ctx, &gbc,
  763. &obu->obu.frame_header,
  764. obu->header.obu_type ==
  765. AV1_OBU_REDUNDANT_FRAME_HEADER,
  766. unit->data_ref);
  767. if (err < 0)
  768. return err;
  769. }
  770. break;
  771. case AV1_OBU_TILE_GROUP:
  772. {
  773. err = cbs_av1_read_tile_group_obu(ctx, &gbc,
  774. &obu->obu.tile_group);
  775. if (err < 0)
  776. return err;
  777. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  778. &obu->obu.tile_group.tile_data);
  779. if (err < 0)
  780. return err;
  781. }
  782. break;
  783. case AV1_OBU_FRAME:
  784. {
  785. err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame,
  786. unit->data_ref);
  787. if (err < 0)
  788. return err;
  789. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  790. &obu->obu.frame.tile_group.tile_data);
  791. if (err < 0)
  792. return err;
  793. }
  794. break;
  795. case AV1_OBU_TILE_LIST:
  796. {
  797. err = cbs_av1_read_tile_list_obu(ctx, &gbc,
  798. &obu->obu.tile_list);
  799. if (err < 0)
  800. return err;
  801. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  802. &obu->obu.tile_list.tile_data);
  803. if (err < 0)
  804. return err;
  805. }
  806. break;
  807. case AV1_OBU_METADATA:
  808. {
  809. err = cbs_av1_read_metadata_obu(ctx, &gbc, &obu->obu.metadata);
  810. if (err < 0)
  811. return err;
  812. }
  813. break;
  814. case AV1_OBU_PADDING:
  815. {
  816. err = cbs_av1_read_padding_obu(ctx, &gbc, &obu->obu.padding);
  817. if (err < 0)
  818. return err;
  819. }
  820. break;
  821. default:
  822. return AVERROR(ENOSYS);
  823. }
  824. end_pos = get_bits_count(&gbc);
  825. av_assert0(end_pos <= unit->data_size * 8);
  826. if (obu->obu_size > 0 &&
  827. obu->header.obu_type != AV1_OBU_TILE_GROUP &&
  828. obu->header.obu_type != AV1_OBU_TILE_LIST &&
  829. obu->header.obu_type != AV1_OBU_FRAME) {
  830. int nb_bits = obu->obu_size * 8 + start_pos - end_pos;
  831. if (nb_bits <= 0)
  832. return AVERROR_INVALIDDATA;
  833. err = cbs_av1_read_trailing_bits(ctx, &gbc, nb_bits);
  834. if (err < 0)
  835. return err;
  836. }
  837. return 0;
  838. }
  839. static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
  840. CodedBitstreamUnit *unit,
  841. PutBitContext *pbc)
  842. {
  843. CodedBitstreamAV1Context *priv = ctx->priv_data;
  844. AV1RawOBU *obu = unit->content;
  845. PutBitContext pbc_tmp;
  846. AV1RawTileData *td;
  847. size_t header_size;
  848. int err, start_pos, end_pos, data_pos;
  849. // OBUs in the normal bitstream format must contain a size field
  850. // in every OBU (in annex B it is optional, but we don't support
  851. // writing that).
  852. obu->header.obu_has_size_field = 1;
  853. err = cbs_av1_write_obu_header(ctx, pbc, &obu->header);
  854. if (err < 0)
  855. return err;
  856. if (obu->header.obu_has_size_field) {
  857. pbc_tmp = *pbc;
  858. // Add space for the size field to fill later.
  859. put_bits32(pbc, 0);
  860. put_bits32(pbc, 0);
  861. }
  862. td = NULL;
  863. start_pos = put_bits_count(pbc);
  864. switch (obu->header.obu_type) {
  865. case AV1_OBU_SEQUENCE_HEADER:
  866. {
  867. err = cbs_av1_write_sequence_header_obu(ctx, pbc,
  868. &obu->obu.sequence_header);
  869. if (err < 0)
  870. return err;
  871. av_buffer_unref(&priv->sequence_header_ref);
  872. priv->sequence_header = NULL;
  873. priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
  874. if (!priv->sequence_header_ref)
  875. return AVERROR(ENOMEM);
  876. priv->sequence_header = &obu->obu.sequence_header;
  877. }
  878. break;
  879. case AV1_OBU_TEMPORAL_DELIMITER:
  880. {
  881. err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc);
  882. if (err < 0)
  883. return err;
  884. }
  885. break;
  886. case AV1_OBU_FRAME_HEADER:
  887. case AV1_OBU_REDUNDANT_FRAME_HEADER:
  888. {
  889. err = cbs_av1_write_frame_header_obu(ctx, pbc,
  890. &obu->obu.frame_header,
  891. obu->header.obu_type ==
  892. AV1_OBU_REDUNDANT_FRAME_HEADER,
  893. NULL);
  894. if (err < 0)
  895. return err;
  896. }
  897. break;
  898. case AV1_OBU_TILE_GROUP:
  899. {
  900. err = cbs_av1_write_tile_group_obu(ctx, pbc,
  901. &obu->obu.tile_group);
  902. if (err < 0)
  903. return err;
  904. td = &obu->obu.tile_group.tile_data;
  905. }
  906. break;
  907. case AV1_OBU_FRAME:
  908. {
  909. err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL);
  910. if (err < 0)
  911. return err;
  912. td = &obu->obu.frame.tile_group.tile_data;
  913. }
  914. break;
  915. case AV1_OBU_TILE_LIST:
  916. {
  917. err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list);
  918. if (err < 0)
  919. return err;
  920. td = &obu->obu.tile_list.tile_data;
  921. }
  922. break;
  923. case AV1_OBU_METADATA:
  924. {
  925. err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata);
  926. if (err < 0)
  927. return err;
  928. }
  929. break;
  930. case AV1_OBU_PADDING:
  931. {
  932. err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding);
  933. if (err < 0)
  934. return err;
  935. }
  936. break;
  937. default:
  938. return AVERROR(ENOSYS);
  939. }
  940. end_pos = put_bits_count(pbc);
  941. header_size = (end_pos - start_pos + 7) / 8;
  942. if (td) {
  943. obu->obu_size = header_size + td->data_size;
  944. } else if (header_size > 0) {
  945. // Add trailing bits and recalculate.
  946. err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8);
  947. if (err < 0)
  948. return err;
  949. end_pos = put_bits_count(pbc);
  950. obu->obu_size = header_size = (end_pos - start_pos + 7) / 8;
  951. } else {
  952. // Empty OBU.
  953. obu->obu_size = 0;
  954. }
  955. end_pos = put_bits_count(pbc);
  956. // Must now be byte-aligned.
  957. av_assert0(end_pos % 8 == 0);
  958. flush_put_bits(pbc);
  959. start_pos /= 8;
  960. end_pos /= 8;
  961. *pbc = pbc_tmp;
  962. err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size);
  963. if (err < 0)
  964. return err;
  965. data_pos = put_bits_count(pbc) / 8;
  966. flush_put_bits(pbc);
  967. av_assert0(data_pos <= start_pos);
  968. if (8 * obu->obu_size > put_bits_left(pbc))
  969. return AVERROR(ENOSPC);
  970. if (obu->obu_size > 0) {
  971. memmove(pbc->buf + data_pos,
  972. pbc->buf + start_pos, header_size);
  973. skip_put_bytes(pbc, header_size);
  974. if (td) {
  975. memcpy(pbc->buf + data_pos + header_size,
  976. td->data, td->data_size);
  977. skip_put_bytes(pbc, td->data_size);
  978. }
  979. }
  980. // OBU data must be byte-aligned.
  981. av_assert0(put_bits_count(pbc) % 8 == 0);
  982. return 0;
  983. }
  984. static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
  985. CodedBitstreamFragment *frag)
  986. {
  987. size_t size, pos;
  988. int i;
  989. size = 0;
  990. for (i = 0; i < frag->nb_units; i++)
  991. size += frag->units[i].data_size;
  992. frag->data_ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
  993. if (!frag->data_ref)
  994. return AVERROR(ENOMEM);
  995. frag->data = frag->data_ref->data;
  996. memset(frag->data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);
  997. pos = 0;
  998. for (i = 0; i < frag->nb_units; i++) {
  999. memcpy(frag->data + pos, frag->units[i].data,
  1000. frag->units[i].data_size);
  1001. pos += frag->units[i].data_size;
  1002. }
  1003. av_assert0(pos == size);
  1004. frag->data_size = size;
  1005. return 0;
  1006. }
  1007. static void cbs_av1_flush(CodedBitstreamContext *ctx)
  1008. {
  1009. CodedBitstreamAV1Context *priv = ctx->priv_data;
  1010. av_buffer_unref(&priv->frame_header_ref);
  1011. priv->sequence_header = NULL;
  1012. priv->frame_header = NULL;
  1013. memset(priv->ref, 0, sizeof(priv->ref));
  1014. priv->operating_point_idc = 0;
  1015. priv->seen_frame_header = 0;
  1016. }
  1017. static void cbs_av1_close(CodedBitstreamContext *ctx)
  1018. {
  1019. CodedBitstreamAV1Context *priv = ctx->priv_data;
  1020. av_buffer_unref(&priv->sequence_header_ref);
  1021. av_buffer_unref(&priv->frame_header_ref);
  1022. }
  1023. static void cbs_av1_free_metadata(void *unit, uint8_t *content)
  1024. {
  1025. AV1RawOBU *obu = (AV1RawOBU*)content;
  1026. AV1RawMetadata *md;
  1027. av_assert0(obu->header.obu_type == AV1_OBU_METADATA);
  1028. md = &obu->obu.metadata;
  1029. switch (md->metadata_type) {
  1030. case AV1_METADATA_TYPE_ITUT_T35:
  1031. av_buffer_unref(&md->metadata.itut_t35.payload_ref);
  1032. break;
  1033. }
  1034. }
  1035. static const CodedBitstreamUnitTypeDescriptor cbs_av1_unit_types[] = {
  1036. CBS_UNIT_TYPE_POD(AV1_OBU_SEQUENCE_HEADER, AV1RawOBU),
  1037. CBS_UNIT_TYPE_POD(AV1_OBU_TEMPORAL_DELIMITER, AV1RawOBU),
  1038. CBS_UNIT_TYPE_POD(AV1_OBU_FRAME_HEADER, AV1RawOBU),
  1039. CBS_UNIT_TYPE_POD(AV1_OBU_REDUNDANT_FRAME_HEADER, AV1RawOBU),
  1040. CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_GROUP, AV1RawOBU,
  1041. obu.tile_group.tile_data.data),
  1042. CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_FRAME, AV1RawOBU,
  1043. obu.frame.tile_group.tile_data.data),
  1044. CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_LIST, AV1RawOBU,
  1045. obu.tile_list.tile_data.data),
  1046. CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_PADDING, AV1RawOBU,
  1047. obu.padding.payload),
  1048. CBS_UNIT_TYPE_COMPLEX(AV1_OBU_METADATA, AV1RawOBU,
  1049. &cbs_av1_free_metadata),
  1050. CBS_UNIT_TYPE_END_OF_LIST
  1051. };
  1052. const CodedBitstreamType ff_cbs_type_av1 = {
  1053. .codec_id = AV_CODEC_ID_AV1,
  1054. .priv_data_size = sizeof(CodedBitstreamAV1Context),
  1055. .unit_types = cbs_av1_unit_types,
  1056. .split_fragment = &cbs_av1_split_fragment,
  1057. .read_unit = &cbs_av1_read_unit,
  1058. .write_unit = &cbs_av1_write_obu,
  1059. .assemble_fragment = &cbs_av1_assemble_fragment,
  1060. .flush = &cbs_av1_flush,
  1061. .close = &cbs_av1_close,
  1062. };