You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1317 lines
38KB

  1. /*
  2. * This file is part of FFmpeg.
  3. *
  4. * FFmpeg is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * FFmpeg is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with FFmpeg; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "libavutil/avassert.h"
  19. #include "libavutil/pixfmt.h"
  20. #include "cbs.h"
  21. #include "cbs_internal.h"
  22. #include "cbs_av1.h"
  23. #include "internal.h"
  24. static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
  25. const char *name, uint32_t *write_to,
  26. uint32_t range_min, uint32_t range_max)
  27. {
  28. uint32_t zeroes, bits_value, value;
  29. int position;
  30. if (ctx->trace_enable)
  31. position = get_bits_count(gbc);
  32. zeroes = 0;
  33. while (1) {
  34. if (get_bits_left(gbc) < 1) {
  35. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
  36. "%s: bitstream ended.\n", name);
  37. return AVERROR_INVALIDDATA;
  38. }
  39. if (get_bits1(gbc))
  40. break;
  41. ++zeroes;
  42. }
  43. if (zeroes >= 32) {
  44. value = MAX_UINT_BITS(32);
  45. } else {
  46. if (get_bits_left(gbc) < zeroes) {
  47. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
  48. "%s: bitstream ended.\n", name);
  49. return AVERROR_INVALIDDATA;
  50. }
  51. bits_value = get_bits_long(gbc, zeroes);
  52. value = bits_value + (UINT32_C(1) << zeroes) - 1;
  53. }
  54. if (ctx->trace_enable) {
  55. char bits[65];
  56. int i, j, k;
  57. if (zeroes >= 32) {
  58. while (zeroes > 32) {
  59. k = FFMIN(zeroes - 32, 32);
  60. for (i = 0; i < k; i++)
  61. bits[i] = '0';
  62. bits[i] = 0;
  63. ff_cbs_trace_syntax_element(ctx, position, name,
  64. NULL, bits, 0);
  65. zeroes -= k;
  66. position += k;
  67. }
  68. }
  69. for (i = 0; i < zeroes; i++)
  70. bits[i] = '0';
  71. bits[i++] = '1';
  72. if (zeroes < 32) {
  73. for (j = 0; j < zeroes; j++)
  74. bits[i++] = (bits_value >> (zeroes - j - 1) & 1) ? '1' : '0';
  75. }
  76. bits[i] = 0;
  77. ff_cbs_trace_syntax_element(ctx, position, name,
  78. NULL, bits, value);
  79. }
  80. if (value < range_min || value > range_max) {
  81. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  82. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  83. name, value, range_min, range_max);
  84. return AVERROR_INVALIDDATA;
  85. }
  86. *write_to = value;
  87. return 0;
  88. }
  89. static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
  90. const char *name, uint32_t value,
  91. uint32_t range_min, uint32_t range_max)
  92. {
  93. uint32_t v;
  94. int position, zeroes;
  95. if (value < range_min || value > range_max) {
  96. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  97. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  98. name, value, range_min, range_max);
  99. return AVERROR_INVALIDDATA;
  100. }
  101. if (ctx->trace_enable)
  102. position = put_bits_count(pbc);
  103. if (value == 0) {
  104. zeroes = 0;
  105. put_bits(pbc, 1, 1);
  106. } else {
  107. zeroes = av_log2(value + 1);
  108. v = value - (1 << zeroes) + 1;
  109. put_bits(pbc, zeroes + 1, 1);
  110. put_bits(pbc, zeroes, v);
  111. }
  112. if (ctx->trace_enable) {
  113. char bits[65];
  114. int i, j;
  115. i = 0;
  116. for (j = 0; j < zeroes; j++)
  117. bits[i++] = '0';
  118. bits[i++] = '1';
  119. for (j = 0; j < zeroes; j++)
  120. bits[i++] = (v >> (zeroes - j - 1) & 1) ? '1' : '0';
  121. bits[i++] = 0;
  122. ff_cbs_trace_syntax_element(ctx, position, name, NULL,
  123. bits, value);
  124. }
  125. return 0;
  126. }
  127. static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
  128. const char *name, uint64_t *write_to)
  129. {
  130. uint64_t value;
  131. int position, err, i;
  132. if (ctx->trace_enable)
  133. position = get_bits_count(gbc);
  134. value = 0;
  135. for (i = 0; i < 8; i++) {
  136. int subscript[2] = { 1, i };
  137. uint32_t byte;
  138. err = ff_cbs_read_unsigned(ctx, gbc, 8, "leb128_byte[i]", subscript,
  139. &byte, 0x00, 0xff);
  140. if (err < 0)
  141. return err;
  142. value |= (uint64_t)(byte & 0x7f) << (i * 7);
  143. if (!(byte & 0x80))
  144. break;
  145. }
  146. if (ctx->trace_enable)
  147. ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
  148. *write_to = value;
  149. return 0;
  150. }
  151. static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
  152. const char *name, uint64_t value)
  153. {
  154. int position, err, len, i;
  155. uint8_t byte;
  156. len = (av_log2(value) + 7) / 7;
  157. if (ctx->trace_enable)
  158. position = put_bits_count(pbc);
  159. for (i = 0; i < len; i++) {
  160. int subscript[2] = { 1, i };
  161. byte = value >> (7 * i) & 0x7f;
  162. if (i < len - 1)
  163. byte |= 0x80;
  164. err = ff_cbs_write_unsigned(ctx, pbc, 8, "leb128_byte[i]", subscript,
  165. byte, 0x00, 0xff);
  166. if (err < 0)
  167. return err;
  168. }
  169. if (ctx->trace_enable)
  170. ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
  171. return 0;
  172. }
  173. static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
  174. uint32_t n, const char *name,
  175. const int *subscripts, uint32_t *write_to)
  176. {
  177. uint32_t w, m, v, extra_bit, value;
  178. int position;
  179. av_assert0(n > 0);
  180. if (ctx->trace_enable)
  181. position = get_bits_count(gbc);
  182. w = av_log2(n) + 1;
  183. m = (1 << w) - n;
  184. if (get_bits_left(gbc) < w) {
  185. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid non-symmetric value at "
  186. "%s: bitstream ended.\n", name);
  187. return AVERROR_INVALIDDATA;
  188. }
  189. if (w - 1 > 0)
  190. v = get_bits(gbc, w - 1);
  191. else
  192. v = 0;
  193. if (v < m) {
  194. value = v;
  195. } else {
  196. extra_bit = get_bits1(gbc);
  197. value = (v << 1) - m + extra_bit;
  198. }
  199. if (ctx->trace_enable) {
  200. char bits[33];
  201. int i;
  202. for (i = 0; i < w - 1; i++)
  203. bits[i] = (v >> i & 1) ? '1' : '0';
  204. if (v >= m)
  205. bits[i++] = extra_bit ? '1' : '0';
  206. bits[i] = 0;
  207. ff_cbs_trace_syntax_element(ctx, position,
  208. name, subscripts, bits, value);
  209. }
  210. *write_to = value;
  211. return 0;
  212. }
  213. static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
  214. uint32_t n, const char *name,
  215. const int *subscripts, uint32_t value)
  216. {
  217. uint32_t w, m, v, extra_bit;
  218. int position;
  219. if (value > n) {
  220. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  221. "%"PRIu32", but must be in [0,%"PRIu32"].\n",
  222. name, value, n);
  223. return AVERROR_INVALIDDATA;
  224. }
  225. if (ctx->trace_enable)
  226. position = put_bits_count(pbc);
  227. w = av_log2(n) + 1;
  228. m = (1 << w) - n;
  229. if (put_bits_left(pbc) < w)
  230. return AVERROR(ENOSPC);
  231. if (value < m) {
  232. v = value;
  233. put_bits(pbc, w - 1, v);
  234. } else {
  235. v = m + ((value - m) >> 1);
  236. extra_bit = (value - m) & 1;
  237. put_bits(pbc, w - 1, v);
  238. put_bits(pbc, 1, extra_bit);
  239. }
  240. if (ctx->trace_enable) {
  241. char bits[33];
  242. int i;
  243. for (i = 0; i < w - 1; i++)
  244. bits[i] = (v >> i & 1) ? '1' : '0';
  245. if (value >= m)
  246. bits[i++] = extra_bit ? '1' : '0';
  247. bits[i] = 0;
  248. ff_cbs_trace_syntax_element(ctx, position,
  249. name, subscripts, bits, value);
  250. }
  251. return 0;
  252. }
  253. static int cbs_av1_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc,
  254. uint32_t range_min, uint32_t range_max,
  255. const char *name, uint32_t *write_to)
  256. {
  257. uint32_t value;
  258. int position, i;
  259. char bits[33];
  260. av_assert0(range_min <= range_max && range_max - range_min < sizeof(bits) - 1);
  261. if (ctx->trace_enable)
  262. position = get_bits_count(gbc);
  263. for (i = 0, value = range_min; value < range_max;) {
  264. if (get_bits_left(gbc) < 1) {
  265. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
  266. "%s: bitstream ended.\n", name);
  267. return AVERROR_INVALIDDATA;
  268. }
  269. if (get_bits1(gbc)) {
  270. bits[i++] = '1';
  271. ++value;
  272. } else {
  273. bits[i++] = '0';
  274. break;
  275. }
  276. }
  277. if (ctx->trace_enable) {
  278. bits[i] = 0;
  279. ff_cbs_trace_syntax_element(ctx, position,
  280. name, NULL, bits, value);
  281. }
  282. *write_to = value;
  283. return 0;
  284. }
  285. static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pbc,
  286. uint32_t range_min, uint32_t range_max,
  287. const char *name, uint32_t value)
  288. {
  289. int len;
  290. av_assert0(range_min <= range_max && range_max - range_min < 32);
  291. if (value < range_min || value > range_max) {
  292. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  293. "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
  294. name, value, range_min, range_max);
  295. return AVERROR_INVALIDDATA;
  296. }
  297. if (value == range_max)
  298. len = range_max - range_min;
  299. else
  300. len = value - range_min + 1;
  301. if (put_bits_left(pbc) < len)
  302. return AVERROR(ENOSPC);
  303. if (ctx->trace_enable) {
  304. char bits[33];
  305. int i;
  306. for (i = 0; i < len; i++) {
  307. if (range_min + i == value)
  308. bits[i] = '0';
  309. else
  310. bits[i] = '1';
  311. }
  312. bits[i] = 0;
  313. ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
  314. name, NULL, bits, value);
  315. }
  316. if (len > 0)
  317. put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
  318. return 0;
  319. }
  320. static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
  321. uint32_t range_max, const char *name,
  322. const int *subscripts, uint32_t *write_to)
  323. {
  324. uint32_t value;
  325. int position, err;
  326. uint32_t max_len, len, range_offset, range_bits;
  327. if (ctx->trace_enable)
  328. position = get_bits_count(gbc);
  329. av_assert0(range_max > 0);
  330. max_len = av_log2(range_max - 1) - 3;
  331. err = cbs_av1_read_increment(ctx, gbc, 0, max_len,
  332. "subexp_more_bits", &len);
  333. if (err < 0)
  334. return err;
  335. if (len) {
  336. range_bits = 2 + len;
  337. range_offset = 1 << range_bits;
  338. } else {
  339. range_bits = 3;
  340. range_offset = 0;
  341. }
  342. if (len < max_len) {
  343. err = ff_cbs_read_unsigned(ctx, gbc, range_bits,
  344. "subexp_bits", NULL, &value,
  345. 0, MAX_UINT_BITS(range_bits));
  346. if (err < 0)
  347. return err;
  348. } else {
  349. err = cbs_av1_read_ns(ctx, gbc, range_max - range_offset,
  350. "subexp_final_bits", NULL, &value);
  351. if (err < 0)
  352. return err;
  353. }
  354. value += range_offset;
  355. if (ctx->trace_enable)
  356. ff_cbs_trace_syntax_element(ctx, position,
  357. name, subscripts, "", value);
  358. *write_to = value;
  359. return err;
  360. }
  361. static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
  362. uint32_t range_max, const char *name,
  363. const int *subscripts, uint32_t value)
  364. {
  365. int position, err;
  366. uint32_t max_len, len, range_offset, range_bits;
  367. if (value > range_max) {
  368. av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
  369. "%"PRIu32", but must be in [0,%"PRIu32"].\n",
  370. name, value, range_max);
  371. return AVERROR_INVALIDDATA;
  372. }
  373. if (ctx->trace_enable)
  374. position = put_bits_count(pbc);
  375. av_assert0(range_max > 0);
  376. max_len = av_log2(range_max - 1) - 3;
  377. if (value < 8) {
  378. range_bits = 3;
  379. range_offset = 0;
  380. len = 0;
  381. } else {
  382. range_bits = av_log2(value);
  383. len = range_bits - 2;
  384. if (len > max_len) {
  385. // The top bin is combined with the one below it.
  386. av_assert0(len == max_len + 1);
  387. --range_bits;
  388. len = max_len;
  389. }
  390. range_offset = 1 << range_bits;
  391. }
  392. err = cbs_av1_write_increment(ctx, pbc, 0, max_len,
  393. "subexp_more_bits", len);
  394. if (err < 0)
  395. return err;
  396. if (len < max_len) {
  397. err = ff_cbs_write_unsigned(ctx, pbc, range_bits,
  398. "subexp_bits", NULL,
  399. value - range_offset,
  400. 0, MAX_UINT_BITS(range_bits));
  401. if (err < 0)
  402. return err;
  403. } else {
  404. err = cbs_av1_write_ns(ctx, pbc, range_max - range_offset,
  405. "subexp_final_bits", NULL,
  406. value - range_offset);
  407. if (err < 0)
  408. return err;
  409. }
  410. if (ctx->trace_enable)
  411. ff_cbs_trace_syntax_element(ctx, position,
  412. name, subscripts, "", value);
  413. return err;
  414. }
  415. static int cbs_av1_tile_log2(int blksize, int target)
  416. {
  417. int k;
  418. for (k = 0; (blksize << k) < target; k++);
  419. return k;
  420. }
  421. static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
  422. unsigned int a, unsigned int b)
  423. {
  424. unsigned int diff, m;
  425. if (!seq->enable_order_hint)
  426. return 0;
  427. diff = a - b;
  428. m = 1 << seq->order_hint_bits_minus_1;
  429. diff = (diff & (m - 1)) - (diff & m);
  430. return diff;
  431. }
  432. static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
  433. {
  434. GetBitContext tmp = *gbc;
  435. size_t size = 0;
  436. for (int i = 0; get_bits_left(&tmp) >= 8; i++) {
  437. if (get_bits(&tmp, 8))
  438. size = i;
  439. }
  440. return size;
  441. }
/*
 * Helper macros used by cbs_av1_syntax_template.c. The template is
 * #included twice below: once with the read definitions of xf/xsu/... and
 * once with the write definitions.
 */

// Emit a trace header line for the syntax structure being processed.
#define HEADER(name) do { \
        ff_cbs_trace_header(ctx, name); \
    } while (0)

// Evaluate call, store its status in the caller's err, return on failure.
#define CHECK(call) do { \
        err = (call); \
        if (err < 0) \
            return err; \
    } while (0)

// FUNC(name) expands to cbs_av1_<READWRITE>_<name>, with READWRITE set to
// "read" or "write" by the section that includes the template.
#define FUNC_NAME(rw, codec, name) cbs_ ## codec ## _ ## rw ## _ ## name
#define FUNC_AV1(rw, name) FUNC_NAME(rw, av1, name)
#define FUNC(name) FUNC_AV1(READWRITE, name)

// Build an int array literal { subs, indices... } for trace subscripts, or
// NULL when subs is 0.
#define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)

// Shorthands for fields of `current`: fb = full-range unsigned, fc =
// unsigned with explicit range, flag = one bit, su = signed.
#define fb(width, name) \
        xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0)
#define fc(width, name, range_min, range_max) \
        xf(width, name, current->name, range_min, range_max, 0)
#define flag(name) fb(1, name)
#define su(width, name) \
        xsu(width, name, current->name, 0)

// Subscripted variants of the shorthands above.
#define fbs(width, name, subs, ...) \
        xf(width, name, current->name, 0, MAX_UINT_BITS(width), subs, __VA_ARGS__)
#define fcs(width, name, range_min, range_max, subs, ...) \
        xf(width, name, current->name, range_min, range_max, subs, __VA_ARGS__)
#define flags(name, subs, ...) \
        xf(1, name, current->name, 0, 1, subs, __VA_ARGS__)
#define sus(width, name, subs, ...) \
        xsu(width, name, current->name, subs, __VA_ARGS__)

// A field whose only allowed value is `value` (range [value, value]).
#define fixed(width, name, value) do { \
        av_unused uint32_t fixed_value = value; \
        xf(width, name, fixed_value, value, value, 0); \
    } while (0)


/*
 * Read instantiation: elements are parsed from the GetBitContext `rw` and
 * stored into the named variable / field of `current`.
 */
#define READ
#define READWRITE read
#define RWContext GetBitContext

#define xf(width, name, var, range_min, range_max, subs, ...) do { \
        uint32_t value; \
        CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
                                   SUBSCRIPTS(subs, __VA_ARGS__), \
                                   &value, range_min, range_max)); \
        var = value; \
    } while (0)

#define xsu(width, name, var, subs, ...) do { \
        int32_t value; \
        CHECK(ff_cbs_read_signed(ctx, rw, width, #name, \
                                 SUBSCRIPTS(subs, __VA_ARGS__), &value, \
                                 MIN_INT_BITS(width), \
                                 MAX_INT_BITS(width))); \
        var = value; \
    } while (0)

#define uvlc(name, range_min, range_max) do { \
        uint32_t value; \
        CHECK(cbs_av1_read_uvlc(ctx, rw, #name, \
                                &value, range_min, range_max)); \
        current->name = value; \
    } while (0)

#define ns(max_value, name, subs, ...) do { \
        uint32_t value; \
        CHECK(cbs_av1_read_ns(ctx, rw, max_value, #name, \
                              SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
        current->name = value; \
    } while (0)

#define increment(name, min, max) do { \
        uint32_t value; \
        CHECK(cbs_av1_read_increment(ctx, rw, min, max, #name, &value)); \
        current->name = value; \
    } while (0)

#define subexp(name, max, subs, ...) do { \
        uint32_t value; \
        CHECK(cbs_av1_read_subexp(ctx, rw, max, #name, \
                                  SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
        current->name = value; \
    } while (0)

// delta_q: a one-bit delta_coded flag, then (if set) a 7-bit signed delta;
// an absent delta reads as 0.
#define delta_q(name) do { \
        uint8_t delta_coded; \
        int8_t delta_q; \
        xf(1, name.delta_coded, delta_coded, 0, 1, 0); \
        if (delta_coded) \
            xsu(1 + 6, name.delta_q, delta_q, 0); \
        else \
            delta_q = 0; \
        current->name = delta_q; \
    } while (0)

#define leb128(name) do { \
        uint64_t value; \
        CHECK(cbs_av1_read_leb128(ctx, rw, #name, &value)); \
        current->name = value; \
    } while (0)

// When reading, an inferred field is simply assigned its inferred value.
#define infer(name, value) do { \
        current->name = value; \
    } while (0)

// Non-zero when the read position is not on a byte boundary.
#define byte_alignment(rw) (get_bits_count(rw) % 8)

#include "cbs_av1_syntax_template.c"

#undef READ
#undef READWRITE
#undef RWContext
#undef xf
#undef xsu
#undef uvlc
#undef ns
#undef increment
#undef subexp
#undef delta_q
#undef leb128
#undef infer
#undef byte_alignment


/*
 * Write instantiation: elements are taken from the named variable / field
 * of `current` and written to the PutBitContext `rw`.
 */
#define WRITE
#define READWRITE write
#define RWContext PutBitContext

#define xf(width, name, var, range_min, range_max, subs, ...) do { \
        CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
                                    SUBSCRIPTS(subs, __VA_ARGS__), \
                                    var, range_min, range_max)); \
    } while (0)

#define xsu(width, name, var, subs, ...) do { \
        CHECK(ff_cbs_write_signed(ctx, rw, width, #name, \
                                  SUBSCRIPTS(subs, __VA_ARGS__), var, \
                                  MIN_INT_BITS(width), \
                                  MAX_INT_BITS(width))); \
    } while (0)

#define uvlc(name, range_min, range_max) do { \
        CHECK(cbs_av1_write_uvlc(ctx, rw, #name, current->name, \
                                 range_min, range_max)); \
    } while (0)

#define ns(max_value, name, subs, ...) do { \
        CHECK(cbs_av1_write_ns(ctx, rw, max_value, #name, \
                               SUBSCRIPTS(subs, __VA_ARGS__), \
                               current->name)); \
    } while (0)

#define increment(name, min, max) do { \
        CHECK(cbs_av1_write_increment(ctx, rw, min, max, #name, \
                                      current->name)); \
    } while (0)

#define subexp(name, max, subs, ...) do { \
        CHECK(cbs_av1_write_subexp(ctx, rw, max, #name, \
                                   SUBSCRIPTS(subs, __VA_ARGS__), \
                                   current->name)); \
    } while (0)

// delta_q: delta_coded is derived from whether the stored delta is non-zero.
#define delta_q(name) do { \
        xf(1, name.delta_coded, current->name != 0, 0, 1, 0); \
        if (current->name) \
            xsu(1 + 6, name.delta_q, current->name, 0); \
    } while (0)

#define leb128(name) do { \
        CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name)); \
    } while (0)

// When writing, a mismatching inferred field only produces a warning.
#define infer(name, value) do { \
        if (current->name != (value)) { \
            av_log(ctx->log_ctx, AV_LOG_WARNING, "Warning: " \
                   "%s does not match inferred value: " \
                   "%"PRId64", but should be %"PRId64".\n", \
                   #name, (int64_t)current->name, (int64_t)(value)); \
        } \
    } while (0)

// Non-zero when the write position is not on a byte boundary.
#define byte_alignment(rw) (put_bits_count(rw) % 8)

#include "cbs_av1_syntax_template.c"

#undef WRITE
#undef READWRITE
#undef RWContext
#undef xf
#undef xsu
#undef uvlc
#undef ns
#undef increment
#undef subexp
#undef delta_q
#undef leb128
#undef infer
#undef byte_alignment
  610. static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
  611. CodedBitstreamFragment *frag,
  612. int header)
  613. {
  614. GetBitContext gbc;
  615. uint8_t *data;
  616. size_t size;
  617. uint64_t obu_length;
  618. int pos, err, trace;
  619. // Don't include this parsing in trace output.
  620. trace = ctx->trace_enable;
  621. ctx->trace_enable = 0;
  622. data = frag->data;
  623. size = frag->data_size;
  624. if (INT_MAX / 8 < size) {
  625. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid fragment: "
  626. "too large (%"SIZE_SPECIFIER" bytes).\n", size);
  627. err = AVERROR_INVALIDDATA;
  628. goto fail;
  629. }
  630. while (size > 0) {
  631. AV1RawOBUHeader header;
  632. uint64_t obu_size;
  633. init_get_bits(&gbc, data, 8 * size);
  634. err = cbs_av1_read_obu_header(ctx, &gbc, &header);
  635. if (err < 0)
  636. goto fail;
  637. if (get_bits_left(&gbc) < 8) {
  638. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
  639. "too short (%"SIZE_SPECIFIER" bytes).\n", size);
  640. err = AVERROR_INVALIDDATA;
  641. goto fail;
  642. }
  643. if (header.obu_has_size_field) {
  644. err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
  645. if (err < 0)
  646. goto fail;
  647. } else
  648. obu_size = size - 1 - header.obu_extension_flag;
  649. pos = get_bits_count(&gbc);
  650. av_assert0(pos % 8 == 0 && pos / 8 <= size);
  651. obu_length = pos / 8 + obu_size;
  652. if (size < obu_length) {
  653. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
  654. "%"PRIu64", but only %"SIZE_SPECIFIER" bytes remaining in fragment.\n",
  655. obu_length, size);
  656. err = AVERROR_INVALIDDATA;
  657. goto fail;
  658. }
  659. err = ff_cbs_insert_unit_data(ctx, frag, -1, header.obu_type,
  660. data, obu_length, frag->data_ref);
  661. if (err < 0)
  662. goto fail;
  663. data += obu_length;
  664. size -= obu_length;
  665. }
  666. err = 0;
  667. fail:
  668. ctx->trace_enable = trace;
  669. return err;
  670. }
  671. static void cbs_av1_free_tile_data(AV1RawTileData *td)
  672. {
  673. av_buffer_unref(&td->data_ref);
  674. }
  675. static void cbs_av1_free_padding(AV1RawPadding *pd)
  676. {
  677. av_buffer_unref(&pd->payload_ref);
  678. }
  679. static void cbs_av1_free_metadata(AV1RawMetadata *md)
  680. {
  681. switch (md->metadata_type) {
  682. case AV1_METADATA_TYPE_ITUT_T35:
  683. av_buffer_unref(&md->metadata.itut_t35.payload_ref);
  684. break;
  685. }
  686. }
  687. static void cbs_av1_free_obu(void *unit, uint8_t *content)
  688. {
  689. AV1RawOBU *obu = (AV1RawOBU*)content;
  690. switch (obu->header.obu_type) {
  691. case AV1_OBU_TILE_GROUP:
  692. cbs_av1_free_tile_data(&obu->obu.tile_group.tile_data);
  693. break;
  694. case AV1_OBU_FRAME:
  695. cbs_av1_free_tile_data(&obu->obu.frame.tile_group.tile_data);
  696. break;
  697. case AV1_OBU_TILE_LIST:
  698. cbs_av1_free_tile_data(&obu->obu.tile_list.tile_data);
  699. break;
  700. case AV1_OBU_METADATA:
  701. cbs_av1_free_metadata(&obu->obu.metadata);
  702. break;
  703. case AV1_OBU_PADDING:
  704. cbs_av1_free_padding(&obu->obu.padding);
  705. break;
  706. }
  707. av_freep(&obu);
  708. }
  709. static int cbs_av1_ref_tile_data(CodedBitstreamContext *ctx,
  710. CodedBitstreamUnit *unit,
  711. GetBitContext *gbc,
  712. AV1RawTileData *td)
  713. {
  714. int pos;
  715. pos = get_bits_count(gbc);
  716. if (pos >= 8 * unit->data_size) {
  717. av_log(ctx->log_ctx, AV_LOG_ERROR, "Bitstream ended before "
  718. "any data in tile group (%d bits read).\n", pos);
  719. return AVERROR_INVALIDDATA;
  720. }
  721. // Must be byte-aligned at this point.
  722. av_assert0(pos % 8 == 0);
  723. td->data_ref = av_buffer_ref(unit->data_ref);
  724. if (!td->data_ref)
  725. return AVERROR(ENOMEM);
  726. td->data = unit->data + pos / 8;
  727. td->data_size = unit->data_size - pos / 8;
  728. return 0;
  729. }
  730. static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
  731. CodedBitstreamUnit *unit)
  732. {
  733. CodedBitstreamAV1Context *priv = ctx->priv_data;
  734. AV1RawOBU *obu;
  735. GetBitContext gbc;
  736. int err, start_pos, end_pos;
  737. err = ff_cbs_alloc_unit_content(ctx, unit, sizeof(*obu),
  738. &cbs_av1_free_obu);
  739. if (err < 0)
  740. return err;
  741. obu = unit->content;
  742. err = init_get_bits(&gbc, unit->data, 8 * unit->data_size);
  743. if (err < 0)
  744. return err;
  745. err = cbs_av1_read_obu_header(ctx, &gbc, &obu->header);
  746. if (err < 0)
  747. return err;
  748. av_assert0(obu->header.obu_type == unit->type);
  749. if (obu->header.obu_has_size_field) {
  750. uint64_t obu_size;
  751. err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
  752. if (err < 0)
  753. return err;
  754. obu->obu_size = obu_size;
  755. } else {
  756. if (unit->data_size < 1 + obu->header.obu_extension_flag) {
  757. av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
  758. "unit too short (%"SIZE_SPECIFIER").\n", unit->data_size);
  759. return AVERROR_INVALIDDATA;
  760. }
  761. obu->obu_size = unit->data_size - 1 - obu->header.obu_extension_flag;
  762. }
  763. start_pos = get_bits_count(&gbc);
  764. if (obu->header.obu_extension_flag) {
  765. priv->temporal_id = obu->header.temporal_id;
  766. priv->spatial_id = obu->header.spatial_id;
  767. if (obu->header.obu_type != AV1_OBU_SEQUENCE_HEADER &&
  768. obu->header.obu_type != AV1_OBU_TEMPORAL_DELIMITER &&
  769. priv->operating_point_idc) {
  770. int in_temporal_layer =
  771. (priv->operating_point_idc >> priv->temporal_id ) & 1;
  772. int in_spatial_layer =
  773. (priv->operating_point_idc >> (priv->spatial_id + 8)) & 1;
  774. if (!in_temporal_layer || !in_spatial_layer) {
  775. // Decoding will drop this OBU at this operating point.
  776. }
  777. }
  778. } else {
  779. priv->temporal_id = 0;
  780. priv->spatial_id = 0;
  781. }
  782. switch (obu->header.obu_type) {
  783. case AV1_OBU_SEQUENCE_HEADER:
  784. {
  785. err = cbs_av1_read_sequence_header_obu(ctx, &gbc,
  786. &obu->obu.sequence_header);
  787. if (err < 0)
  788. return err;
  789. av_buffer_unref(&priv->sequence_header_ref);
  790. priv->sequence_header = NULL;
  791. priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
  792. if (!priv->sequence_header_ref)
  793. return AVERROR(ENOMEM);
  794. priv->sequence_header = &obu->obu.sequence_header;
  795. }
  796. break;
  797. case AV1_OBU_TEMPORAL_DELIMITER:
  798. {
  799. err = cbs_av1_read_temporal_delimiter_obu(ctx, &gbc);
  800. if (err < 0)
  801. return err;
  802. }
  803. break;
  804. case AV1_OBU_FRAME_HEADER:
  805. case AV1_OBU_REDUNDANT_FRAME_HEADER:
  806. {
  807. err = cbs_av1_read_frame_header_obu(ctx, &gbc,
  808. &obu->obu.frame_header,
  809. obu->header.obu_type ==
  810. AV1_OBU_REDUNDANT_FRAME_HEADER,
  811. unit->data_ref);
  812. if (err < 0)
  813. return err;
  814. }
  815. break;
  816. case AV1_OBU_TILE_GROUP:
  817. {
  818. err = cbs_av1_read_tile_group_obu(ctx, &gbc,
  819. &obu->obu.tile_group);
  820. if (err < 0)
  821. return err;
  822. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  823. &obu->obu.tile_group.tile_data);
  824. if (err < 0)
  825. return err;
  826. }
  827. break;
  828. case AV1_OBU_FRAME:
  829. {
  830. err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame,
  831. unit->data_ref);
  832. if (err < 0)
  833. return err;
  834. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  835. &obu->obu.frame.tile_group.tile_data);
  836. if (err < 0)
  837. return err;
  838. }
  839. break;
  840. case AV1_OBU_TILE_LIST:
  841. {
  842. err = cbs_av1_read_tile_list_obu(ctx, &gbc,
  843. &obu->obu.tile_list);
  844. if (err < 0)
  845. return err;
  846. err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
  847. &obu->obu.tile_list.tile_data);
  848. if (err < 0)
  849. return err;
  850. }
  851. break;
  852. case AV1_OBU_METADATA:
  853. {
  854. err = cbs_av1_read_metadata_obu(ctx, &gbc, &obu->obu.metadata);
  855. if (err < 0)
  856. return err;
  857. }
  858. break;
  859. case AV1_OBU_PADDING:
  860. {
  861. err = cbs_av1_read_padding_obu(ctx, &gbc, &obu->obu.padding);
  862. if (err < 0)
  863. return err;
  864. }
  865. break;
  866. default:
  867. return AVERROR(ENOSYS);
  868. }
  869. end_pos = get_bits_count(&gbc);
  870. av_assert0(end_pos <= unit->data_size * 8);
  871. if (obu->obu_size > 0 &&
  872. obu->header.obu_type != AV1_OBU_TILE_GROUP &&
  873. obu->header.obu_type != AV1_OBU_FRAME) {
  874. int nb_bits = obu->obu_size * 8 + start_pos - end_pos;
  875. if (nb_bits <= 0)
  876. return AVERROR_INVALIDDATA;
  877. err = cbs_av1_read_trailing_bits(ctx, &gbc, nb_bits);
  878. if (err < 0)
  879. return err;
  880. }
  881. return 0;
  882. }
/**
 * Serialise one AV1RawOBU unit into pbc.
 *
 * Steps:
 *  - write the OBU header, forcing obu_has_size_field on;
 *  - reserve 64 bits for obu_size;
 *  - write the per-type payload (plus trailing bits for types without
 *    tile data);
 *  - rewind to the reserved slot and write the real leb128 obu_size;
 *  - move the payload down so it directly follows the size field, then
 *    append any referenced tile data.
 *
 * NOTE(review): the memmove()/memcpy() below index priv->write_buffer using
 * byte counts taken from pbc, so this assumes pbc was initialised at the
 * start of priv->write_buffer by the caller. TODO confirm against
 * cbs_av1_write_unit().
 *
 * @return 0 on success, AVERROR(ENOSPC) if the output does not fit, or
 *         another negative AVERROR code on failure.
 */
static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
                             CodedBitstreamUnit *unit,
                             PutBitContext *pbc)
{
    CodedBitstreamAV1Context *priv = ctx->priv_data;
    AV1RawOBU *obu = unit->content;
    PutBitContext pbc_tmp;
    AV1RawTileData *td;
    size_t header_size;
    int err, start_pos, end_pos, data_pos;

    // OBUs in the normal bitstream format must contain a size field
    // in every OBU (in annex B it is optional, but we don't support
    // writing that).
    obu->header.obu_has_size_field = 1;

    err = cbs_av1_write_obu_header(ctx, pbc, &obu->header);
    if (err < 0)
        return err;

    // Always true given the assignment above, so pbc_tmp is always set
    // before it is used further down.
    if (obu->header.obu_has_size_field) {
        pbc_tmp = *pbc;
        // Add space for the size field to fill later.
        put_bits32(pbc, 0);
        put_bits32(pbc, 0);
    }

    // td is set only for OBU types whose payload ends with tile data.
    td = NULL;
    start_pos = put_bits_count(pbc);

    switch (obu->header.obu_type) {
    case AV1_OBU_SEQUENCE_HEADER:
        {
            err = cbs_av1_write_sequence_header_obu(ctx, pbc,
                                                    &obu->obu.sequence_header);
            if (err < 0)
                return err;

            // Keep the active sequence header alive through a reference on
            // this unit's content.
            av_buffer_unref(&priv->sequence_header_ref);
            priv->sequence_header = NULL;

            priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
            if (!priv->sequence_header_ref)
                return AVERROR(ENOMEM);
            priv->sequence_header = &obu->obu.sequence_header;
        }
        break;
    case AV1_OBU_TEMPORAL_DELIMITER:
        {
            err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc);
            if (err < 0)
                return err;
        }
        break;
    case AV1_OBU_FRAME_HEADER:
    case AV1_OBU_REDUNDANT_FRAME_HEADER:
        {
            err = cbs_av1_write_frame_header_obu(ctx, pbc,
                                                 &obu->obu.frame_header,
                                                 obu->header.obu_type ==
                                                 AV1_OBU_REDUNDANT_FRAME_HEADER,
                                                 NULL);
            if (err < 0)
                return err;
        }
        break;
    case AV1_OBU_TILE_GROUP:
        {
            err = cbs_av1_write_tile_group_obu(ctx, pbc,
                                               &obu->obu.tile_group);
            if (err < 0)
                return err;

            td = &obu->obu.tile_group.tile_data;
        }
        break;
    case AV1_OBU_FRAME:
        {
            err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL);
            if (err < 0)
                return err;

            td = &obu->obu.frame.tile_group.tile_data;
        }
        break;
    case AV1_OBU_TILE_LIST:
        {
            err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list);
            if (err < 0)
                return err;

            td = &obu->obu.tile_list.tile_data;
        }
        break;
    case AV1_OBU_METADATA:
        {
            err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata);
            if (err < 0)
                return err;
        }
        break;
    case AV1_OBU_PADDING:
        {
            err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding);
            if (err < 0)
                return err;
        }
        break;
    default:
        return AVERROR(ENOSYS);
    }

    // header_size: bytes of payload syntax written after the reserved slot.
    end_pos = put_bits_count(pbc);
    header_size = (end_pos - start_pos + 7) / 8;
    if (td) {
        // Tile-data OBUs get no trailing bits; the payload is the written
        // syntax followed by the raw tile data.
        obu->obu_size = header_size + td->data_size;
    } else if (header_size > 0) {
        // Add trailing bits and recalculate.
        err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8);
        if (err < 0)
            return err;
        end_pos = put_bits_count(pbc);
        obu->obu_size = header_size = (end_pos - start_pos + 7) / 8;
    } else {
        // Empty OBU.
        obu->obu_size = 0;
    }

    end_pos = put_bits_count(pbc);
    // Must now be byte-aligned.
    av_assert0(end_pos % 8 == 0);
    flush_put_bits(pbc);
    start_pos /= 8;
    end_pos   /= 8;

    // Rewind to the reserved slot and write the real obu_size there.
    *pbc = pbc_tmp;
    err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size);
    if (err < 0)
        return err;

    // The leb128 size occupies at most the reserved 8 bytes, so the payload
    // start (data_pos) is never after where it was written (start_pos).
    data_pos = put_bits_count(pbc) / 8;
    flush_put_bits(pbc);
    av_assert0(data_pos <= start_pos);

    if (8 * obu->obu_size > put_bits_left(pbc))
        return AVERROR(ENOSPC);

    if (obu->obu_size > 0) {
        // Source and destination may overlap, hence memmove(). The
        // skip_put_bytes() calls presumably advance pbc past bytes written
        // directly into the buffer.
        memmove(priv->write_buffer + data_pos,
                priv->write_buffer + start_pos, header_size);
        skip_put_bytes(pbc, header_size);

        if (td) {
            memcpy(priv->write_buffer + data_pos + header_size,
                   td->data, td->data_size);
            skip_put_bytes(pbc, td->data_size);
        }
    }

    return 0;
}
  1026. static int cbs_av1_write_unit(CodedBitstreamContext *ctx,
  1027. CodedBitstreamUnit *unit)
  1028. {
  1029. CodedBitstreamAV1Context *priv = ctx->priv_data;
  1030. PutBitContext pbc;
  1031. int err;
  1032. if (!priv->write_buffer) {
  1033. // Initial write buffer size is 1MB.
  1034. priv->write_buffer_size = 1024 * 1024;
  1035. reallocate_and_try_again:
  1036. err = av_reallocp(&priv->write_buffer, priv->write_buffer_size);
  1037. if (err < 0) {
  1038. av_log(ctx->log_ctx, AV_LOG_ERROR, "Unable to allocate a "
  1039. "sufficiently large write buffer (last attempt "
  1040. "%"SIZE_SPECIFIER" bytes).\n", priv->write_buffer_size);
  1041. return err;
  1042. }
  1043. }
  1044. init_put_bits(&pbc, priv->write_buffer, priv->write_buffer_size);
  1045. err = cbs_av1_write_obu(ctx, unit, &pbc);
  1046. if (err == AVERROR(ENOSPC)) {
  1047. // Overflow.
  1048. priv->write_buffer_size *= 2;
  1049. goto reallocate_and_try_again;
  1050. }
  1051. if (err < 0)
  1052. return err;
  1053. // Overflow but we didn't notice.
  1054. av_assert0(put_bits_count(&pbc) <= 8 * priv->write_buffer_size);
  1055. // OBU data must be byte-aligned.
  1056. av_assert0(put_bits_count(&pbc) % 8 == 0);
  1057. unit->data_size = put_bits_count(&pbc) / 8;
  1058. flush_put_bits(&pbc);
  1059. err = ff_cbs_alloc_unit_data(ctx, unit, unit->data_size);
  1060. if (err < 0)
  1061. return err;
  1062. memcpy(unit->data, priv->write_buffer, unit->data_size);
  1063. return 0;
  1064. }
  1065. static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
  1066. CodedBitstreamFragment *frag)
  1067. {
  1068. size_t size, pos;
  1069. int i;
  1070. size = 0;
  1071. for (i = 0; i < frag->nb_units; i++)
  1072. size += frag->units[i].data_size;
  1073. frag->data_ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
  1074. if (!frag->data_ref)
  1075. return AVERROR(ENOMEM);
  1076. frag->data = frag->data_ref->data;
  1077. memset(frag->data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);
  1078. pos = 0;
  1079. for (i = 0; i < frag->nb_units; i++) {
  1080. memcpy(frag->data + pos, frag->units[i].data,
  1081. frag->units[i].data_size);
  1082. pos += frag->units[i].data_size;
  1083. }
  1084. av_assert0(pos == size);
  1085. frag->data_size = size;
  1086. return 0;
  1087. }
  1088. static void cbs_av1_close(CodedBitstreamContext *ctx)
  1089. {
  1090. CodedBitstreamAV1Context *priv = ctx->priv_data;
  1091. av_buffer_unref(&priv->sequence_header_ref);
  1092. av_buffer_unref(&priv->frame_header_ref);
  1093. av_freep(&priv->write_buffer);
  1094. }
  1095. const CodedBitstreamType ff_cbs_type_av1 = {
  1096. .codec_id = AV_CODEC_ID_AV1,
  1097. .priv_data_size = sizeof(CodedBitstreamAV1Context),
  1098. .split_fragment = &cbs_av1_split_fragment,
  1099. .read_unit = &cbs_av1_read_unit,
  1100. .write_unit = &cbs_av1_write_unit,
  1101. .assemble_fragment = &cbs_av1_assemble_fragment,
  1102. .close = &cbs_av1_close,
  1103. };