()
| 12 | not _tensorflow_hub_available, reason="tensorflow_hub is not installed" |
| 13 | ) |
def test_train_tiny():
    """Run a tiny `textattack train` job and validate what it reports.

    The command trains `distilbert-base-uncased` on `rotten_tomatoes` with
    the `textfooler` attack for a single epoch, then this test checks that:

    * the process exits with return code 0 and captured both streams,
    * stderr names a training-args file that exists on disk,
    * stderr reports a train accuracy and an eval accuracy, each above 60.
    """
    command = "textattack train --model distilbert-base-uncased --attack textfooler --dataset rotten_tomatoes --model-max-length 64 --num-epochs 1 --num-clean-epochs 0 --num-train-adv-examples 2"

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None
    assert result.returncode == 0

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # The messages matched below are searched for in stderr, not stdout.
    train_args_json_path = re.findall(
        r"Wrote original training args to (\S+)\.", stderr
    )
    # Two separate asserts so a failure shows whether the log line was
    # missing or the reported file does not exist.
    assert train_args_json_path
    assert os.path.exists(train_args_json_path[0])

    train_acc = re.findall(r"Train accuracy: (\S+)", stderr)
    assert train_acc
    train_acc = float(train_acc[0][:-1])  # [:-1] removes percent sign
    assert train_acc > 60

    eval_acc = re.findall(r"Eval accuracy: (\S+)", stderr)
    assert eval_acc
    eval_acc = float(eval_acc[0][:-1])  # [:-1] removes percent sign
    # Check the eval figure itself; the original re-asserted `train_acc`
    # here, which left the parsed eval accuracy unverified.
    assert eval_acc > 60
nothing calls this directly
no test coverage detected