You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

645 lines
17KB

  1. #include "StochasticGrammar.h"
  2. #include <random>
  3. #if 0
  4. // Disabled (#if 0) legacy code: a file-global generator seeded with 57 and a float distribution over 0..1.0 backing Random::get(). Eventually get rid of this global random generator.
  5. std::default_random_engine generator{57};
  6. std::uniform_real_distribution<float> distribution{0, 1.0};
  7. float Random::get()
  8. {
  9. return distribution(generator);
  10. }
  11. #endif
  12. /***************************************************************************************************************
  13. *
  14. * ProductionRuleKeys
  15. *
  16. ***************************************************************************************************************/
  17. void ProductionRuleKeys::breakDown(GKEY key, GKEY * outKeys)
  18. {
  19. switch (key) {
  20. // terminal keys expand to themselves
  21. case sg_w2:
  22. case sg_w:
  23. case sg_h:
  24. case sg_q:
  25. case sg_e:
  26. case sg_e3:
  27. case sg_sx:
  28. case sg_68:
  29. case sg_78:
  30. case sg_98:
  31. case sg_dq:
  32. case sg_dh:
  33. case sg_de:
  34. //case sg_hdq:
  35. //case sg_qhe:
  36. *outKeys++ = key;
  37. *outKeys++ = sg_invalid;
  38. break;
  39. case sg_798:
  40. *outKeys++ = sg_78;
  41. *outKeys++ = sg_98;
  42. *outKeys++ = sg_invalid;
  43. break;
  44. case sg_ww:
  45. *outKeys++ = sg_w;
  46. *outKeys++ = sg_w;
  47. *outKeys++ = sg_invalid;
  48. break;
  49. case sg_hh:
  50. *outKeys++ = sg_h;
  51. *outKeys++ = sg_h;
  52. *outKeys++ = sg_invalid;
  53. break;
  54. case sg_qq:
  55. *outKeys++ = sg_q;
  56. *outKeys++ = sg_q;
  57. *outKeys++ = sg_invalid;
  58. break;
  59. case sg_sxsx:
  60. *outKeys++ = sg_sx;
  61. *outKeys++ = sg_sx;
  62. *outKeys++ = sg_invalid;
  63. break;
  64. case sg_ee:
  65. *outKeys++ = sg_e;
  66. *outKeys++ = sg_e;
  67. *outKeys++ = sg_invalid;
  68. break;
  69. case sg_e3e3e3:
  70. *outKeys++ = sg_e3;
  71. *outKeys++ = sg_e3;
  72. *outKeys++ = sg_e3;
  73. *outKeys++ = sg_invalid;
  74. break;
  75. case sg_hdq:
  76. *outKeys++ = sg_h;
  77. *outKeys++ = sg_dq;
  78. *outKeys++ = sg_invalid;
  79. break;
  80. case sg_hq:
  81. *outKeys++ = sg_h;
  82. *outKeys++ = sg_q;
  83. *outKeys++ = sg_invalid;
  84. break;
  85. case sg_qh:
  86. *outKeys++ = sg_q;
  87. *outKeys++ = sg_h;
  88. *outKeys++ = sg_invalid;
  89. break;
  90. case sg_qhe:
  91. *outKeys++ = sg_q;
  92. *outKeys++ = sg_h;
  93. *outKeys++ = sg_e;
  94. *outKeys++ = sg_invalid;
  95. break;
  96. case sg_q78:
  97. *outKeys++ = sg_q;
  98. *outKeys++ = sg_78;
  99. *outKeys++ = sg_invalid;
  100. break;
  101. case sg_qe68:
  102. *outKeys++ = sg_q;
  103. *outKeys++ = sg_e;
  104. *outKeys++ = sg_68;
  105. *outKeys++ = sg_invalid;
  106. break;
  107. default:
  108. assert(false);
  109. }
  110. }
  111. const char * ProductionRuleKeys::toString(GKEY key)
  112. {
  113. const char * ret;
  114. switch (key) {
  115. case sg_w2: ret = "2xw"; break;
  116. case sg_ww: ret = "w,w"; break;
  117. case sg_w: ret = "w"; break;
  118. case sg_h: ret = "h"; break;
  119. case sg_hh: ret = "h,h"; break;
  120. case sg_q: ret = "q"; break;
  121. case sg_qq: ret = "q,q"; break;
  122. case sg_e: ret = "e"; break;
  123. case sg_ee: ret = "e,e"; break;
  124. case sg_e3e3e3: ret = "3e,3e,3e"; break;
  125. case sg_e3: ret = "3e"; break;
  126. case sg_sx: ret = "sx"; break;
  127. case sg_sxsx: ret = "sx, sx"; break;
  128. case sg_68: ret = "<6/8>"; break;
  129. case sg_78: ret = "<7/8>"; break;
  130. case sg_98: ret = "<9/8>"; break;
  131. case sg_798: ret = "7+9/8"; break;
  132. case sg_dq: ret = "q."; break;
  133. case sg_dh: ret = "h."; break;
  134. case sg_de: ret = "e."; break;
  135. case sg_hdq: ret = "h,q."; break;
  136. case sg_qhe: ret = "q,h,e"; break;
  137. case sg_qh: ret = "q,h"; break;
  138. case sg_hq: ret = "h,q"; break;
  139. case sg_q78: ret = "q,<7/8>"; break;
  140. case sg_qe68: ret = "q,e,<6/8>"; break;
  141. default:
  142. printf("can't print key %d\n", key);
  143. assert(false);
  144. ret = "error";
  145. }
  146. return ret;
  147. }
  148. int ProductionRuleKeys::getDuration(GKEY key)
  149. {
  150. int ret;
  151. assert((PPQ % 3) == 0);
  152. switch (key) {
  153. case sg_798:
  154. case sg_w2: ret = 2 * 4 * PPQ; break;
  155. case sg_ww: ret = 2 * 4 * PPQ; break;
  156. case sg_w: ret = 4 * PPQ; break;
  157. case sg_h: ret = 2 * PPQ; break;
  158. case sg_hh: ret = 2 * 2 * PPQ; break;
  159. case sg_q: ret = 1 * PPQ; break;
  160. case sg_qq: ret = 2 * PPQ; break;
  161. case sg_e:
  162. assert((PPQ % 2) == 0);
  163. ret = PPQ / 2;
  164. break;
  165. case sg_ee: ret = PPQ; break;
  166. case sg_sxsx: ret = PPQ / 2; break;
  167. case sg_sx:
  168. assert((PPQ % 4) == 0);
  169. ret = PPQ / 4;
  170. break;
  171. case sg_e3e3e3: ret = PPQ; break;
  172. case sg_e3:
  173. assert(PPQ % 3 == 0);
  174. ret = PPQ / 3;
  175. break;
  176. case sg_68: ret = 6 * (PPQ / 2); break;
  177. case sg_78: ret = 7 * (PPQ / 2); break;
  178. case sg_q78:
  179. case sg_qe68:
  180. case sg_98:
  181. ret = 9 * (PPQ / 2); break;
  182. case sg_dq: ret = 3 * PPQ / 2; break;
  183. case sg_dh: ret = 3 * PPQ; break;
  184. case sg_de: ret = 3 * PPQ / 4; break;
  185. case sg_hdq: ret = 2 * PPQ + 3 * PPQ / 2; break;
  186. case sg_qhe: ret = PPQ * 3 + PPQ / 2; break;
  187. case sg_hq:
  188. case sg_qh: ret = PPQ * 3; break;
  189. default:
  190. #ifdef _MSC_VER
  191. printf("can't get dur key %d\n", key);
  192. #endif
  193. assert(false);
  194. ret = 0;
  195. }
  196. return ret;
  197. }
  198. /***************************************************************************************************************
  199. *
  200. * ProductionRule
  201. *
  202. ***************************************************************************************************************/
  203. // generate production, return code for what happened
  204. int ProductionRule::_evaluateRule(const ProductionRule& rule, float random)
  205. {
  206. assert(random >= 0 && random <= 1);
  207. int i = 0;
  208. for (bool done2 = false; !done2; ++i) {
  209. assert(i < numEntries);
  210. //printf("prob[%d] is %d\n", i, rule.entries[i].probability);
  211. if (rule.entries[i].probability >= random) {
  212. GKEY code = rule.entries[i].code;
  213. //printf("rule fired on code abs val=%d\n", code);
  214. return code;
  215. }
  216. }
  217. assert(false); // no rule fired
  218. return 0;
  219. }
  220. void ProductionRule::evaluate(EvaluationState& es, int ruleToEval)
  221. {
  222. //printf("\n evaluate called on rule #%d\n", ruleToEval);
  223. const ProductionRule& rule = es.rules[ruleToEval];
  224. #ifdef _MSC_VER
  225. assert(rule._isValid(ruleToEval));
  226. #endif
  227. GKEY result = _evaluateRule(rule, es.r());
  228. if (result == sg_invalid) // request to terminate recursion
  229. {
  230. GKEY code = ruleToEval; // our "real" terminal code is our table index
  231. //printf("production rule #%d terminated\n", ruleToEval);
  232. //printf("rule terminated! execute code %s\n", ProductionRuleKeys::toString(code));
  233. es.writeSymbol(code);
  234. } else {
  235. //printf("production rule #%d expanded to %d\n", ruleToEval, result);
  236. // need to expand,then eval all of the expanded codes
  237. GKEY buffer[ProductionRuleKeys::bufferSize];
  238. ProductionRuleKeys::breakDown(result, buffer);
  239. for (GKEY * p = buffer; *p != sg_invalid; ++p) {
  240. //printf("expanding rule #%d with %d\n", ruleToEval, *p);
  241. evaluate(es, *p);
  242. }
  243. //printf("done expanding %d\n", ruleToEval);
  244. }
  245. }
  246. // is the data self consistent, and appropriate for index
  247. #if defined(_MSC_VER) && defined(_DEBUG)
  248. bool ProductionRule::_isValid(int index) const
  249. {
  250. if (index == sg_invalid) {
  251. printf("rule not allowed in first slot\n");
  252. return false;
  253. }
  254. if (entries[0] == ProductionRuleEntry()) {
  255. printf("rule at index %d is ininitizlied. bad graph (%s)\n",
  256. index,
  257. ProductionRuleKeys::toString(index));
  258. return false;
  259. }
  260. float last = -1;
  261. bool foundTerminator = false;
  262. for (int i = 0; !foundTerminator; ++i) {
  263. if (i >= numEntries) {
  264. printf("entries not terminated index=%d 'i' is too big: %d\n", index, i);
  265. return false;
  266. }
  267. const ProductionRuleEntry& e = entries[i];
  268. if (e.probability > 1.0f) {
  269. printf("probability %f > 1 \n", e.probability);
  270. return false;
  271. }
  272. if (e.probability == 0.f) {
  273. printf("zero probability in rule\n");
  274. return false;
  275. }
  276. if (e.probability <= last) // probabilities grow
  277. {
  278. printf("probability not growing is %f was %f\n", e.probability, last);
  279. return false;
  280. }
  281. if (e.probability == 1.0f) {
  282. foundTerminator = true; // must have a 255 to end it
  283. if (e.code == index) {
  284. printf("rule terminates on self: recursion not allowed\n");
  285. return false;
  286. }
  287. }
  288. if (e.code < sg_invalid || e.code > sg_last) {
  289. printf("rule[%d] entry[%d] had invalid code: %d\n", index, i, e.code);
  290. return false;
  291. }
  292. // if we are terminating recursion, then by definition our duration is correct
  293. if (e.code != sg_invalid) {
  294. // otherwise, make sure the entry has the right duration
  295. int entryDuration = ProductionRuleKeys::getDuration(e.code);
  296. int ruleDuration = ProductionRuleKeys::getDuration(index);
  297. if (entryDuration != ruleDuration) {
  298. printf("production rule[%d] (name %s) duration mismatch (time not conserved) rule dur = %d entry dur %d\n",
  299. index, ProductionRuleKeys::toString(index), ruleDuration, entryDuration);
  300. return false;
  301. }
  302. }
  303. last = e.probability;
  304. }
  305. return true;
  306. }
  307. #endif
  308. #ifdef _DEBUG
  309. bool ProductionRule::isGrammarValid(const ProductionRule * rules, int numRules, GKEY firstRule)
  310. {
  311. //printf("is grammar valid, numRules = %d first = %d\n", numRules, firstRule);
  312. if (firstRule < sg_first) {
  313. printf("first rule index (%d) bad\n", firstRule);
  314. return false;
  315. }
  316. if (numRules != (sg_last + 1)) {
  317. printf("bad number of rules\n");
  318. return false;
  319. }
  320. const ProductionRule& r = rules[firstRule];
  321. if (!r._isValid(firstRule)) {
  322. return false;
  323. }
  324. // now, make sure every entry goes to something real
  325. bool foundTerminator = false;
  326. for (int i = 0; !foundTerminator; ++i) {
  327. const ProductionRuleEntry& e = r.entries[i];
  328. if (e.probability == 1.0f)
  329. foundTerminator = true; // must have a 255 to end it
  330. GKEY _newKey = e.code;
  331. if (_newKey != sg_invalid) {
  332. GKEY outKeys[4];
  333. ProductionRuleKeys::breakDown(_newKey, outKeys);
  334. for (GKEY * p = outKeys; *p != sg_invalid; ++p) {
  335. if (!isGrammarValid(rules, numRules, *p)) {
  336. printf("followed rules to bad one\n");
  337. return false;
  338. }
  339. }
  340. }
  341. }
  342. return true;
  343. }
  344. #endif
  345. /*
  346. StochasticGrammarDictionary
  347. maybe move this to a test file?
  348. **/
  349. static ProductionRule _rules0[fullRuleTableSize];
  350. static ProductionRule _rules1[fullRuleTableSize];
  351. static ProductionRule _rules2[fullRuleTableSize];
  352. static ProductionRule _rules3[fullRuleTableSize];
  353. bool StochasticGrammarDictionary::_didInitRules = false;
  354. void StochasticGrammarDictionary::initRules()
  355. {
  356. initRule0(_rules0);
  357. initRule1(_rules1);
  358. initRule2(_rules2);
  359. initRule3(_rules3);
  360. }
  361. // super dumb one - makes quarter notes
  362. void StochasticGrammarDictionary::initRule0(ProductionRule * rules)
  363. {
  364. // break w2 into w,w prob 100
  365. {
  366. ProductionRule& r = rules[sg_w2];
  367. r.entries[0].probability = 1.0f;
  368. r.entries[0].code = sg_ww;
  369. }
  370. // break w into h, h
  371. {
  372. ProductionRule& r = rules[sg_w];
  373. r.entries[0].probability = 1.0f;
  374. r.entries[0].code = sg_hh;
  375. }
  376. // break h into q,q
  377. {
  378. ProductionRule&r = rules[sg_h];
  379. r.entries[0].probability = 1.0f;
  380. r.entries[0].code = sg_qq;
  381. }
  382. // stop on q
  383. rules[sg_q].makeTerminal();
  384. }
  385. void StochasticGrammarDictionary::initRule1(ProductionRule * rules)
  386. {
  387. // break w2 into w,w prob 100
  388. {
  389. ProductionRule& r = rules[sg_w2];
  390. r.entries[0].probability = 1.0f;
  391. r.entries[0].code = sg_ww;
  392. }
  393. // break w into h, h
  394. {
  395. ProductionRule& r = rules[sg_w];
  396. r.entries[0].probability = 1.0f;
  397. r.entries[0].code = sg_hh;
  398. }
  399. // break h into q,q, or h
  400. {
  401. ProductionRule&r = rules[sg_h];
  402. r.entries[0].probability = .75f;
  403. r.entries[0].code = sg_qq;
  404. r.entries[1].probability = 1.0f;
  405. r.entries[1].code = sg_invalid;
  406. }
  407. // stop on q, or make e
  408. {
  409. ProductionRule&r = rules[sg_q];
  410. r.entries[0].probability = .3f;
  411. r.entries[0].code = sg_ee;
  412. r.entries[1].probability = 1.0f;
  413. r.entries[1].code = sg_invalid;
  414. }
  415. // stop on e, or make sx
  416. {
  417. ProductionRule&r = rules[sg_e];
  418. r.entries[0].probability = .3f;
  419. r.entries[0].code = sg_sxsx;
  420. r.entries[1].probability = 1.0f;
  421. r.entries[1].code = sg_invalid;
  422. }
  423. rules[sg_sx].makeTerminal();
  424. }
  425. void StochasticGrammarDictionary::initRule2(ProductionRule * rules)
  426. {
  427. // break w2 into 7+9/8 prob 100
  428. {
  429. ProductionRule& r = rules[sg_w2];
  430. r.entries[0].probability = 1.0f;
  431. r.entries[0].code = sg_798;
  432. }
  433. // 9/8 -> different combos
  434. {
  435. ProductionRule& r = rules[sg_98];
  436. r.entries[0].probability = .5f;
  437. r.entries[0].code = sg_q78;
  438. r.entries[1].probability = 1.0f;
  439. r.entries[1].code = sg_qe68;
  440. }
  441. // 6/8 ->
  442. {
  443. ProductionRule& r = rules[sg_68];
  444. r.entries[0].probability = .5f;
  445. r.entries[0].code = sg_hq;
  446. r.entries[1].probability = 1.0f;
  447. r.entries[1].code = sg_qh;
  448. }
  449. //78 -> different combos
  450. {
  451. ProductionRule& r = rules[sg_78];
  452. r.entries[0].probability = .5f;
  453. r.entries[0].code = sg_qhe;
  454. r.entries[1].probability = 1.0f;
  455. r.entries[1].code = sg_hdq;
  456. }
  457. // terminate on these
  458. rules[sg_hdq].makeTerminal();
  459. rules[sg_qhe].makeTerminal();
  460. rules[sg_q].makeTerminal();
  461. rules[sg_dq].makeTerminal();
  462. rules[sg_h].makeTerminal();
  463. rules[sg_e].makeTerminal();
  464. }
  465. // 3 is like 1, but with some trips
  466. void StochasticGrammarDictionary::initRule3(ProductionRule * rules)
  467. {
  468. // break w2 into w,w prob 100
  469. {
  470. ProductionRule& r = rules[sg_w2];
  471. r.entries[0].probability = 1.0f;
  472. r.entries[0].code = sg_ww;
  473. }
  474. // break w into h, h
  475. {
  476. ProductionRule& r = rules[sg_w];
  477. r.entries[0].probability = 1.0f;
  478. r.entries[0].code = sg_hh;
  479. }
  480. // break h into q,q, or h
  481. {
  482. ProductionRule&r = rules[sg_h];
  483. r.entries[0].probability = .75f;
  484. r.entries[0].code = sg_qq;
  485. r.entries[1].probability = 1.0f;
  486. r.entries[1].code = sg_invalid;
  487. }
  488. // stop on q, or make e, or make trips
  489. {
  490. ProductionRule&r = rules[sg_q];
  491. r.entries[0].probability = .3f;
  492. r.entries[0].code = sg_ee;
  493. r.entries[1].probability = .7f;
  494. r.entries[1].code = sg_e3e3e3;
  495. r.entries[2].probability = 1.0f;
  496. r.entries[2].code = sg_invalid;
  497. }
  498. // expand trip 8ths
  499. rules[sg_e3].makeTerminal();
  500. // stop on e, or make sx,
  501. {
  502. ProductionRule&r = rules[sg_e];
  503. r.entries[0].probability = .3f;
  504. r.entries[0].code = sg_sxsx;
  505. r.entries[1].probability = 1.0f;
  506. r.entries[1].code = sg_invalid;
  507. }
  508. rules[sg_sx].makeTerminal();
  509. }
  510. int StochasticGrammarDictionary::getNumGrammars()
  511. {
  512. return 4;
  513. }
  514. StochasticGrammarDictionary::Grammar StochasticGrammarDictionary::getGrammar(int index)
  515. {
  516. if (!_didInitRules)
  517. initRules();
  518. assert(index >= 0 && index < getNumGrammars());
  519. Grammar ret;
  520. ret.firstRule = sg_w2;
  521. ret.numRules = fullRuleTableSize;
  522. switch (index) {
  523. case 0:
  524. ret.rules = _rules0;
  525. break;
  526. case 1:
  527. ret.rules = _rules1;
  528. break;
  529. case 2:
  530. ret.rules = _rules2;
  531. break;
  532. case 3:
  533. ret.rules = _rules3;
  534. break;
  535. default:
  536. assert(false);
  537. }
  538. return ret;
  539. }