(cls, train_toks, **kwargs)
class TadmMaxentClassifier(MaxentClassifier):
    """Maxent classifier whose weights are fitted by the external TADM tool."""

    @classmethod
    def train(cls, train_toks, **kwargs):
        """Train a classifier by writing events to disk and invoking TADM.

        The training tokens are encoded, written to a gzipped temporary
        events file, and handed to ``call_tadm``; the weights TADM writes
        to a second temporary file are parsed, rescaled from base-e to
        base-2, and used to build the classifier.  Both temporary files are
        removed before returning, including when an error is raised.

        :param train_toks: Training data passed to the encoding trainer and
            to ``write_tadm_file``.
        :param kwargs: Optional settings read by this method:

            - ``algorithm``: value for TADM's ``-method`` option
              (default ``"tao_lmvm"``).
            - ``trace``: verbosity; values below 3 add ``"2>&1"`` to the
              option list, otherwise ``-summary`` is added (default 3).
            - ``encoding``: feature encoding to use; when falsy, one is
              built with ``TadmEventMaxentFeatureEncoding.train``.
            - ``labels``: forwarded to the encoding trainer.
            - ``gaussian_prior_sigma``: when non-zero, ``sigma**2`` is
              passed as ``-l2`` (default 0).
            - ``count_cutoff``: forwarded to the encoding trainer
              (default 0).
            - ``max_iter``: when truthy, passed as ``-max_it``.
            - ``min_lldelta``: when truthy, its absolute value is passed
              as ``-fatol``.
        :return: ``cls(encoding, weights)``.
        """
        algorithm = kwargs.get("algorithm", "tao_lmvm")
        trace = kwargs.get("trace", 3)
        encoding = kwargs.get("encoding", None)
        labels = kwargs.get("labels", None)
        sigma = kwargs.get("gaussian_prior_sigma", 0)
        count_cutoff = kwargs.get("count_cutoff", 0)
        max_iter = kwargs.get("max_iter")
        ll_delta = kwargs.get("min_lldelta")

        # Construct an encoding from the training data.
        if not encoding:
            encoding = TadmEventMaxentFeatureEncoding.train(
                train_toks, count_cutoff, labels=labels
            )

        # Paths of the temporary files created so far; everything listed
        # here is deleted in the ``finally`` clause below.
        temp_paths = []
        try:
            # mkstemp returns an already-open OS-level descriptor.  The
            # files are reopened by name below, so close the descriptors
            # right away instead of leaking them.
            trainfile_fd, trainfile_name = tempfile.mkstemp(
                prefix="nltk-tadm-events-", suffix=".gz"
            )
            temp_paths.append(trainfile_name)
            os.close(trainfile_fd)

            weightfile_fd, weightfile_name = tempfile.mkstemp(
                prefix="nltk-tadm-weights-"
            )
            temp_paths.append(weightfile_name)
            os.close(weightfile_fd)

            trainfile = gzip_open_unicode(trainfile_name, "w")
            try:
                write_tadm_file(train_toks, encoding, trainfile)
            finally:
                # Close (and thereby flush) the events file even when
                # writing fails, so it can be removed afterwards.
                trainfile.close()

            options = []
            options.extend(["-monitor"])
            options.extend(["-method", algorithm])
            if sigma:
                options.extend(["-l2", "%.6f" % sigma**2])
            if max_iter:
                options.extend(["-max_it", "%d" % max_iter])
            if ll_delta:
                options.extend(["-fatol", "%.6f" % abs(ll_delta)])
            options.extend(["-events_in", trainfile_name])
            options.extend(["-params_out", weightfile_name])
            if trace < 3:
                # NOTE(review): "2>&1" is appended as an ordinary argument;
                # it only acts as a redirect if call_tadm runs the command
                # through a shell -- confirm against call_tadm.
                options.extend(["2>&1"])
            else:
                options.extend(["-summary"])

            call_tadm(options)

            with open(weightfile_name) as weightfile:
                weights = parse_tadm_weights(weightfile)
        finally:
            # Remove the temporary files on success and on failure alike.
            for temp_path in temp_paths:
                os.remove(temp_path)

        # Convert from base-e to base-2 weights.
        # (presumably ``weights`` is a numpy array, so this scales every
        # element -- verify against parse_tadm_weights)
        weights *= numpy.log2(numpy.e)

        # Build the classifier
        return cls(encoding, weights)
| 1553 | |
| 1554 | |
| 1555 | ###################################################################### |
nothing calls this directly
no test coverage detected