File size: 2,790 Bytes
cfd3735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Test LLM Bash functionality."""
import sys

import pytest

from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
from langchain.schema import OutputParserException
from tests.unit_tests.llms.fake_llm import FakeLLM

# Simulated LLM reply: a single ```bash block with one command, surrounded by
# unrelated prose that the parser must ignore.
_SAMPLE_CODE = """
Unrelated text
```bash
echo hello
```
Unrelated text
"""


# Simulated LLM reply: a single ```bash block holding two commands separated
# by a blank line (used to check that blank lines are dropped by the parser).
_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello

echo world
```
Unrelated text
"""


@pytest.fixture
def output_parser() -> BashOutputParser:
    """Supply a newly constructed ``BashOutputParser`` to tests that request it."""
    parser = BashOutputParser()
    return parser


@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_simple_question() -> None:
    """Test that the chain runs the LLM's bash block and returns its stdout.

    The fake LLM's canned reply (``expr 1 + 1``) does not actually answer the
    question; the test only checks that the bash code returned for the rendered
    prompt is executed and its output (``"2\\n"``) is returned by ``run``.
    """
    question = "Please write a bash script that prints 'Hello World' to the console."
    prompt = _PROMPT_TEMPLATE.format(question=question)
    # Map the fully rendered prompt to a canned bash reply for FakeLLM.
    queries = {prompt: "```bash\nexpr 1 + 1\n```"}
    fake_llm = FakeLLM(queries=queries)
    fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
    output = fake_llm_bash_chain.run(question)
    assert output == "2\n"


def test_get_code(output_parser: BashOutputParser) -> None:
    """Check that the parser returns only the non-blank lines of bash blocks."""
    single_block = output_parser.parse(_SAMPLE_CODE)
    # Every extracted entry must contain something other than whitespace.
    assert all(line.strip() for line in single_block)
    assert single_block == ["echo hello"]

    # Lines from consecutive bash blocks are concatenated; the blank line
    # inside the second sample is dropped.
    both_samples = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
    assert both_samples == ["echo hello", "echo hello", "echo world"]


def test_parsing_error() -> None:
    """Test that LLM output without a bash block raises OutputParserException."""
    question = "Please echo 'hello world' to the terminal."
    prompt = _PROMPT_TEMPLATE.format(question=question)
    # The canned reply fences its code as ``text`` rather than ``bash``, so
    # there is no bash block for the output parser to extract.
    queries = {
        prompt: """
```text
echo 'hello world'
```
"""
    }
    fake_llm = FakeLLM(queries=queries)
    fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
    with pytest.raises(OutputParserException):
        fake_llm_bash_chain.run(question)


def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
    """Check that only ``bash`` fenced blocks are extracted, skipping other languages."""
    llm_output = """
Unrelated text
```bash
echo hello
ls && pwd && ls
```

```python
print("hello")
```

```bash
echo goodbye
```
"""
    expected = ["echo hello", "ls && pwd && ls", "echo goodbye"]
    assert output_parser.parse(llm_output) == expected


def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
    """Check that triple backticks not followed by a newline are kept as plain text."""
    llm_output = """
Unrelated text
```bash
echo hello
echo "```bash is in this string```"
```
"""
    expected = ["echo hello", 'echo "```bash is in this string```"']
    assert output_parser.parse(llm_output) == expected