MCPcopy Index your code
hub / github.com/THUDM/GLM / encode

Method encode

tasks/superglue/pvp.py:477–526  ·  view source on GitHub ↗

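Encode an input example using this pattern-verbalizer pair. :param example: the input example to encode :param priming: whether to use this example for priming :param labeled: if ``priming=True``, whether the label should be appended to this example :return: the result of the parent class's ``encode`` when a continuous prompt is used or ``pattern_id < 2``; otherwise the sample built by ``build_sample`` from the encodings of both choices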

(self, example: InputExample, priming: bool = False, labeled: bool = False)

Source from the content-addressed store, hash-verified

475 return []
476
def encode(self, example: InputExample, priming: bool = False, labeled: bool = False):
    """
    Encode an input example using this pattern-verbalizer pair.

    When a continuous prompt is in use or ``pattern_id`` is below 2, the work
    is delegated to the parent class. Otherwise one input sequence is built
    per candidate (``choice1`` and ``choice2`` from ``example.meta``), each of
    the form ``"<choice1>" or "<choice2>"? <premise> [mask] <candidate>``, and
    the connective (" because" for a 'cause' question, " so" for an 'effect'
    question) is used as the answer to be scored.

    :param example: the input example to encode; ``text_a`` is the premise and
        ``meta`` must provide 'choice1', 'choice2' and 'question'
    :param priming: whether to use this example for priming
    :param labeled: if ``priming=True``, whether the label should be appended to this example
    :return: the parent's encoding on the delegated path; otherwise the sample
        returned by ``build_sample`` for the two candidate sequences
    """
    use_parent_encoding = self.continuous_prompt or self.pattern_id < 2
    if use_parent_encoding:
        return super().encode(example, priming=priming, labeled=labeled)
    assert priming or not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

    tokenizer = self.tokenizer
    premise = self.remove_final_punc(self.shortenable(example.text_a))
    # Candidates carry a leading space; it is stripped again where they are quoted.
    choice1 = " " + self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
    choice2 = " " + self.remove_final_punc(self.lowercase_first(example.meta['choice2']))

    question = example.meta['question']
    assert question in ['cause', 'effect']
    connective = " because" if question == 'cause' else " so"
    answer_ids = [get_verbalization_ids(connective, tokenizer, force_single_token=True)]
    if self.is_multi_token:
        answer_ids += [tokenizer.get_command('eop').Id]

    # One row per candidate: (ids, position ids, sep, target ids, loss masks).
    rows = []
    for candidate in (choice1, choice2):
        raw_parts = ['"', choice1[1:], '" or "', choice2[1:], '"?', premise, [self.mask], candidate]
        parts = []
        for piece in raw_parts:
            # Normalise every piece to a (content, flag) pair; bare pieces get flag False.
            content, flag = piece if isinstance(piece, tuple) else (piece, False)
            if not content:
                continue
            if isinstance(content, str):
                content = tokenizer.EncodeAsIds(content).tokenization
            parts.append((content, flag))

        self.num_truncated += self.truncate(parts, None, answer_ids, max_length=self.max_seq_length)
        tokens_a = [token_id for segment, _ in parts for token_id in segment]
        ids, _types, _paddings, position_ids, sep, target_ids, loss_masks = build_input_from_ids(
            tokens_a, None, answer_ids, self.max_seq_length, self.tokenizer, args=self.args,
            add_cls=True, add_sep=False, add_piece=True)
        rows.append((ids, position_ids, sep, target_ids, loss_masks))

    # Transpose the per-candidate rows into one list per field.
    ids_list, positions_list, sep_list, target_list, mask_list = map(list, zip(*rows))

    # Unlabeled examples fall back to label index 0.
    label = 0 if example.label is None else self.label_list.index(example.label)
    return build_sample(ids_list, positions=positions_list, masks=sep_list, label=label,
                        logit_mask=mask_list, target=target_list,
                        unique_id=example.guid)
527
528
529class WscPVP(PVP):

Callers

nothing calls this directly

Calls 11

build_input_from_idsFunction · 0.90
build_sampleFunction · 0.90
get_verbalization_idsFunction · 0.85
remove_final_puncMethod · 0.80
shortenableMethod · 0.80
lowercase_firstMethod · 0.80
appendMethod · 0.80
get_commandMethod · 0.80
truncateMethod · 0.80
encodeMethod · 0.45
EncodeAsIdsMethod · 0.45

Tested by

no test coverage detected