MCPcopy
hub / github.com/QData/TextAttack / test_train_tiny

Function test_train_tiny

tests/test_command_line/test_train.py:14–42  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

12 not _tensorflow_hub_available, reason="tensorflow_hub is not installed"
13)
def test_train_tiny():
    """Smoke-test ``textattack train`` end to end on a tiny adversarial run.

    Runs the CLI with distilbert-base-uncased on rotten_tomatoes for one
    epoch (no clean epochs, 2 TextFooler adversarial training examples,
    max length 64). It then checks that:

    * the command exited with return code 0;
    * it reported writing a training-args file that exists on disk;
    * both the reported train accuracy and eval accuracy exceed 60%.
    """
    command = "textattack train --model distilbert-base-uncased --attack textfooler --dataset rotten_tomatoes --model-max-length 64 --num-epochs 1 --num-clean-epochs 0 --num-train-adv-examples 2"

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None
    assert result.returncode == 0

    stdout = result.stdout.decode().strip()
    print("stdout =>", stdout)
    stderr = result.stderr.decode().strip()
    print("stderr =>", stderr)

    # All progress lines checked below are parsed from stderr, where the
    # trainer's log output is expected to go.
    train_args_json_path = re.findall(
        r"Wrote original training args to (\S+)\.", stderr
    )
    # Separate asserts so a failure shows which condition was violated.
    assert train_args_json_path
    assert os.path.exists(train_args_json_path[0])

    def _parse_accuracy(label):
        """Return the first "<label> accuracy: NN.NN%" value from stderr as a float."""
        matches = re.findall(rf"{label} accuracy: (\S+)", stderr)
        assert matches, f"no '{label} accuracy' line found in stderr"
        # Strip a trailing percent sign (if present) before converting.
        return float(matches[0].rstrip("%"))

    train_acc = _parse_accuracy("Train")
    assert train_acc > 60

    eval_acc = _parse_accuracy("Eval")
    # Validate the eval metric itself; checking train_acc again would let a
    # poor eval accuracy pass unnoticed.
    assert eval_acc > 60

Callers

nothing calls this directly

Calls 3

strip — Method · 0.80
decode — Method · 0.45

Tested by

no test coverage detected