Spaces:
Running
Running
Commit
·
abad9af
1
Parent(s):
683d306
Add Test File
Browse files- agent.py +1 -1
- test_agent.py +59 -0
agent.py
CHANGED
@@ -78,7 +78,7 @@ def create_assistant_tools(cfg):
|
|
78 |
[ask_vehicles, ask_policies]
|
79 |
)
|
80 |
|
81 |
-
def initialize_agent(_cfg, update_func):
|
82 |
electric_vehicle_bot_instructions = """
|
83 |
- You are a helpful research assistant, with expertise in electric vehicles, in conversation with a user.
|
84 |
- Before answering any user query, get sample data from each table in the database, so that you can understand NULL and unique values for each column.
|
|
|
78 |
[ask_vehicles, ask_policies]
|
79 |
)
|
80 |
|
81 |
+
def initialize_agent(_cfg, update_func=None):
|
82 |
electric_vehicle_bot_instructions = """
|
83 |
- You are a helpful research assistant, with expertise in electric vehicles, in conversation with a user.
|
84 |
- Before answering any user query, get sample data from each table in the database, so that you can understand NULL and unique values for each column.
|
test_agent.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import os
|
3 |
+
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from vectara_agent.agent import Agent
|
6 |
+
|
7 |
+
import sqlite3
|
8 |
+
from datasets import load_dataset
|
9 |
+
|
10 |
+
from app import initialize_agent, setup_db
|
11 |
+
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
load_dotenv(override=True)
|
14 |
+
|
15 |
+
class TestAgentResponses(unittest.TestCase):
|
16 |
+
|
17 |
+
def test_responses(self):
|
18 |
+
|
19 |
+
cfg = OmegaConf.create({
|
20 |
+
'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
|
21 |
+
'corpus_ids': str(os.environ['VECTARA_CORPUS_IDS']).split(','),
|
22 |
+
'api_keys': str(os.environ['VECTARA_API_KEYS']).split(','),
|
23 |
+
'examples': os.environ.get('QUERY_EXAMPLES', None)
|
24 |
+
})
|
25 |
+
|
26 |
+
setup_db()
|
27 |
+
|
28 |
+
agent = initialize_agent(_cfg=cfg)
|
29 |
+
self.assertIsInstance(agent, Agent)
|
30 |
+
|
31 |
+
# Knows types of electric vehicles
|
32 |
+
type_output = agent.chat('What are the different types of electric vehicles? Only provide the name of each type, nothing else.').lower()
|
33 |
+
ev_types = ['battery', 'hybrid', 'plug-in hybrid', 'fuel cell']
|
34 |
+
|
35 |
+
for ev_type in ev_types:
|
36 |
+
self.assertIn(ev_type, type_output)
|
37 |
+
|
38 |
+
|
39 |
+
# Questions about car models - ev query tool
|
40 |
+
self.assertIn('mach-e', agent.chat('Which EV is made by Mustang? Provide the name of the model only.').lower())
|
41 |
+
self.assertIn('fuel cell', agent.chat('What EV type is the Toyota Mirai? Just give the type name.').lower())
|
42 |
+
|
43 |
+
|
44 |
+
# Incentive query tool tests
|
45 |
+
self.assertIn('no', agent.chat('Does the U.S. Department of Defense offer incentives for purchasing electric vehicles? Only say "yes" or "no".').lower())
|
46 |
+
self.assertIn('2035', agent.chat('At what year must all new passenger vehicles be zero emission vehicles in California? Give the year only.').lower())
|
47 |
+
|
48 |
+
|
49 |
+
# Database Tool questions
|
50 |
+
self.assertIn('king', agent.chat('Which county in the state of Washington had the highest number of EV registrations in 2023? Provide the name only.').lower())
|
51 |
+
self.assertIn('seattle', agent.chat('Which city in the state of Washington had the highest number of EV registrations in 2023? Provide the name only.').lower())
|
52 |
+
self.assertIn('tesla model y', agent.chat('What car was the most popular in Seattle in 2023? Provide the make and model only').lower())
|
53 |
+
|
54 |
+
# Misc. questions
|
55 |
+
self.assertIn('tesla', agent.chat('Which company developed a standard charging port for electric cars in California? Only provide the company name, nothing else').lower())
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
setup_db()
|
59 |
+
unittest.main()
|