seq: [2, 3, 5, ..., -1 (to be generated), -N (N beams), -1]; context_length: the number of leading non-negative tokens in seq (the prefix before the first -1).
(
model,
seq,
args,
mems=None,
invalid_slices=[],
**kwargs)
| 62 | return tokens, attention_mask, position_ids |
| 63 | |
| 64 | def filling_sequence( |
| 65 | model, |
| 66 | seq, |
| 67 | args, |
| 68 | mems=None, |
| 69 | invalid_slices=[], |
| 70 | **kwargs): |
| 71 | ''' |
| 72 | seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1] |
| 73 | context_length: first non(-1)s |
| 74 | ''' |
| 75 | tokenizer = get_tokenizer() |
| 76 | device = seq.device |
| 77 | assert len(seq.shape) == 1 |
| 78 | out_seq_length = len(seq) |
| 79 | # building the initial tokens, attention_mask, and position_ids |
| 80 | context_length = 0 |
| 81 | offset = 100000 |
| 82 | |
| 83 | invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)] |
| 84 | |
| 85 | while seq[context_length] >= 0: |
| 86 | # change what to generate |
| 87 | if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]: |
| 88 | invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] |
| 89 | elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]: |
| 90 | invalid_slices = [ |
| 91 | slice(0, tokenizer.img_tokenizer.num_tokens), |
| 92 | slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)] |
| 93 | |
| 94 | if seq[context_length] == tokenizer['[ROI2]']: |
| 95 | offset = context_length |
| 96 | context_length += 1 |
| 97 | tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args) |
| 98 | |
| 99 | counter = context_length - 1 # == len(tokens) - 1 |
| 100 | index = 0 # len(mems) |
| 101 | if mems is None: |
| 102 | mems = [] |
| 103 | score = [0] # sum log likelihood for beams |
| 104 | |
| 105 | if args.is_sparse == 2: |
| 106 | tokenizer = get_tokenizer() |
| 107 | img_txt_sep = tokenizer.img_tokenizer.num_tokens |
| 108 | img_indices_bool = (tokens < img_txt_sep) |
| 109 | txt_indices_bool = (~img_indices_bool) |
| 110 | elif args.is_sparse == 0: |
| 111 | txt_indices_bool = img_indices_bool = None |
| 112 | else: |
| 113 | raise ValueError('set is_sparse==2 for inference.') |
| 114 | |
| 115 | while counter < (out_seq_length - 1): |
| 116 | # Now, we want to generate seq[counter + 1] |
| 117 | # token[:, index: counter+1] are just added. |
| 118 | |
| 119 | if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]: |
| 120 | invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] |
| 121 | elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]: |
no test coverage detected