(args, query_template='{}')
| 73 | return seq |
| 74 | |
| 75 | def get_context(args, query_template='{}'): |
| 76 | tokenizer = get_tokenizer() |
| 77 | terminate_runs = 0 |
| 78 | img_size = 256 if args.generation_task != 'low-level super-resolution' else 128 |
| 79 | ml = max(args.max_position_embeddings, args.max_position_embeddings_finetune) |
| 80 | output_path = args.output_path |
| 81 | |
| 82 | if args.input_source == 'interactive': |
| 83 | assert not args.with_id, '--with-id is only used with file inputs.' |
| 84 | if args.generation_task == 'post-selection': |
| 85 | raise ValueError('post-selection only takes file inputs!') |
| 86 | while True: |
| 87 | raw_text = input("\nPlease Input Query (stop to exit) >>> ") |
| 88 | if not raw_text: |
| 89 | print('Query should not be empty!') |
| 90 | continue |
| 91 | if raw_text == "stop": |
| 92 | return |
| 93 | try: |
| 94 | seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template) |
| 95 | except (ValueError, FileNotFoundError) as e: |
| 96 | print(e) |
| 97 | continue |
| 98 | if len(seq) > ml: |
| 99 | print("\nSeq length", len(seq), |
| 100 | f"\nPlease give smaller context than {ml}!") |
| 101 | continue |
| 102 | yield (raw_text, seq, output_path) |
| 103 | else: |
| 104 | with open(args.input_source, 'r') as fin: |
| 105 | inputs = fin.readlines() |
| 106 | for line_no, raw_text in enumerate(inputs): |
| 107 | if line_no % dist.get_world_size() != dist.get_rank(): |
| 108 | continue |
| 109 | rk = dist.get_rank() |
| 110 | print(f'Working on No. {line_no} on {rk}... ') |
| 111 | raw_text = raw_text.strip() |
| 112 | if len(raw_text) == 0: |
| 113 | continue |
| 114 | if args.with_id: # with id |
| 115 | parts = raw_text.split('\t') |
| 116 | output_path = os.path.join(args.output_path, parts[0]) |
| 117 | raw_text = '\t'.join(parts[1:]) |
| 118 | |
| 119 | if args.generation_task == 'post-selection': |
| 120 | parts = raw_text.split('\t') |
| 121 | seqs = [] |
| 122 | for part in parts[1:]: |
| 123 | try: |
| 124 | seq_single = _parse_and_to_tensor('\t'.join([part, parts[0]]), img_size=img_size, query_template=query_template) |
| 125 | seqs.append(seq_single) |
| 126 | except (ValueError, FileNotFoundError) as e: |
| 127 | print(e) |
| 128 | continue |
| 129 | seq = torch.stack(seqs) |
| 130 | else: |
| 131 | try: |
| 132 | seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template) |
no test coverage detected