salim4n committed on
Commit
c577353
·
verified ·
1 Parent(s): adf4865

Upload 5 files

Browse files
Files changed (5) hide show
  1. gradio_smol.py +24 -0
  2. requirements.txt +5 -0
  3. sql_data.py +80 -0
  4. tools.py +43 -0
  5. utils.py +80 -0
gradio_smol.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from tools import FreightAgent, EXAMPLE_QUERIES
3
+ from utils import initialize_database
4
+ from smolagents import CodeAgent, HfApiModel, GradioUI
5
+ import os
6
+ from dotenv import load_dotenv
7
+ from sql_data import sql_query, get_schema, get_csv_as_dataframe
8
+
9
# Read variables from a .env file (when present) into the process environment.
load_dotenv()

# First start only: build freights.db from the published CSV. Later runs
# find the file on disk and skip this step.
if not os.path.exists("freights.db"):
    csv_url = "https://huggingface.co/datasets/sasu-SpidR/fretmaritime/resolve/main/freights.csv"
    initialize_database(csv_url)

# Language model backing the agent. HF_API_KEY is mandatory: a missing
# variable raises KeyError right here, at import time.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
hf_token = os.environ["HF_API_KEY"]
model = HfApiModel(model_id=model_id, token=hf_token)

# The agent gets all three freight-data tools defined in sql_data.py.
freight_tools = [sql_query, get_schema, get_csv_as_dataframe]
agent = CodeAgent(tools=freight_tools, model=model)

if __name__ == "__main__":
    # Only start the web UI when executed as a script, not when imported.
    GradioUI(agent).launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pandas
2
+ sqlalchemy
3
+ smolagents
4
+ python-dotenv
5
+ gradio
sql_data.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine, text
2
+ from smolagents import tool
3
+ import os
4
+ import pandas as pd
5
+
6
# Single SQLAlchemy engine shared by every tool in this module. The URL
# points at the SQLite file freights.db, resolved relative to the current
# working directory.
DATABASE_URL = "sqlite:///freights.db"
engine = create_engine(DATABASE_URL)
8
+
9
+
10
@tool
def sql_query(query: str) -> str:
    """
    Allows you to perform SQL queries on the freights table. Returns a string representation of the result.
    The table is named 'freights'. Its description is as follows:
        Columns:
        - departure: DateTime (Date and time of departure)
        - origin_port_locode: String (Origin port code)
        - origin_port_name: String (Name of the origin port)
        - destination_port: String (Destination port code)
        - destination_port_name: String (Name of the destination port)
        - dv20rate: Float (Rate for 20ft container in USD)
        - dv40rate: Float (Rate for 40ft container in USD)
        - currency: String (Currency of the rates)
        - inserted_on: DateTime (Date when the rate was inserted)

    Args:
        query: The query to perform. This should be correct SQL.

    Returns:
        A string representation of the result of the query.
    """
    # NOTE(review): the docstring wording above is deliberately unchanged;
    # presumably the @tool decorator exposes it to the model as the tool
    # description -- confirm before rewording it.

    def _cell(value) -> str:
        # A raw "|" or line break inside a value would split the Markdown
        # row into extra columns/rows, so neutralise both.
        return (
            str(value)
            .replace("|", "\\|")
            .replace("\r\n", " ")
            .replace("\n", " ")
            .replace("\r", " ")
        )

    try:
        with engine.connect() as con:
            result = con.execute(text(query))
            # Materialise inside the connection scope; the cursor is closed
            # as soon as the `with` block exits.
            rows = [dict(row._mapping) for row in result]

        if not rows:
            return "Aucun résultat trouvé."

        # Column order is taken from the first row; every row of one result
        # set shares the same keys.
        headers = list(rows[0])
        lines = [
            "| " + " | ".join(_cell(h) for h in headers) + " |",
            "| " + " | ".join("---" for _ in headers) + " |",
        ]
        lines.extend(
            "| " + " | ".join(_cell(row[h]) for h in headers) + " |"
            for row in rows
        )
        # One join instead of repeated `+=`; every line, including the last,
        # ends with a newline exactly as before.
        return "\n".join(lines) + "\n"

    except Exception as e:
        # Tool boundary: hand the failure back to the caller as text rather
        # than raising, so the agent can read the message and retry.
        return f"Error executing query: {str(e)}"
52
+
53
+
54
@tool
def get_schema() -> str:
    """
    Returns the schema of the freights table.
    """
    # Hand-written, static description: nothing here is read back from the
    # database, so it has to be kept in sync with the table manually.
    schema_description = """
    Table: freights
    Columns:
    - departure: DateTime (Date and time of departure)
    - origin_port_locode: String (Origin port code)
    - origin_port_name: String (Name of the origin port)
    - destination_port: String (Destination port code)
    - destination_port_name: String (Name of the destination port)
    - dv20rate: Float (Rate for 20ft container in USD)
    - dv40rate: Float (Rate for 40ft container in USD)
    - currency: String (Currency of the rates)
    - inserted_on: DateTime (Date when the rate was inserted)
    """
    return schema_description
72
+
73
+
74
@tool
def get_csv_as_dataframe() -> str:
    """
    Returns a string representation of the freights table as a CSV file.
    """
    # Despite the function name, the value handed back is CSV text (without
    # the DataFrame index column), not a DataFrame object.
    return pd.read_sql_table("freights", engine).to_csv(index=False)
tools.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, HfApiModel
2
+ from sql_data import sql_query, get_schema
3
+ from sqlalchemy import create_engine, inspect, text
4
+ import os
5
+ from dotenv import load_dotenv
6
+ from typing import Dict, List, Any
7
+ import json
8
+
9
# Populate the process environment from a .env file, if one is present
load_dotenv()

# Sample natural-language questions (written in French) that the agent is
# expected to be able to answer
EXAMPLE_QUERIES: List[str] = [
    "Quels sont les tarifs moyens des conteneurs 20ft et 40ft entre tous les ports ?",
    "Quels sont les ports d'origine les plus fréquents ?",
    "Montre-moi les routes avec des tarifs élevés pour les conteneurs 40ft",
    "Quelle est l'évolution des prix au fil du temps pour la route Surabaya vers Nansha ?",
    "Quelles sont les destinations disponibles depuis Shanghai ?",
]
20
+
21
+
22
class FreightAgent:
    """
    Answer natural-language questions about the freight data.

    Wraps a smolagents ``CodeAgent`` equipped with the ``sql_query`` and
    ``get_schema`` tools.
    """

    # Model used when the caller does not ask for a specific one.
    DEFAULT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

    # Class-level fallback so ``setup_agent`` also works on instances whose
    # ``__init__`` was bypassed or overridden.
    model_id = DEFAULT_MODEL_ID

    def __init__(self, model_id: str = DEFAULT_MODEL_ID):
        """
        Build the agent.

        Args:
            model_id: Hugging Face model identifier handed to ``HfApiModel``.
                Defaults to the previously hard-coded model.
        """
        self.model_id = model_id
        self.setup_agent()

    def setup_agent(self) -> None:
        """
        Initialize the CodeAgent with SQL tools.

        Create a CodeAgent with two tools: `sql_query` and `get_schema`.
        `sql_query` allows to perform SQL queries on the freights table.
        `get_schema` returns the schema of the freights table.
        """
        model = HfApiModel(
            model_id=self.model_id,
            # Same credential gradio_smol.py passes. When HF_API_KEY is not
            # set this is None, which is HfApiModel's own default, so the
            # previous behaviour is kept in that case.
            token=os.environ.get("HF_API_KEY"),
        )
        self.agent = CodeAgent(tools=[sql_query, get_schema], model=model)

    def query(self, question: str) -> str:
        """
        Ask a question about the freight data in natural language.

        Args:
            question: The question, forwarded unchanged to the agent.

        Returns:
            Whatever ``self.agent.run`` returns for the question.
        """
        return self.agent.run(question)
utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import sqlite3
3
+ import requests
4
+ from typing import List, Dict, Any
5
+ import os
6
+ from sqlalchemy import create_engine, Column, Float, String, Integer, DateTime
7
+ from sqlalchemy.ext.declarative import declarative_base
8
+ from sqlalchemy.orm import sessionmaker
9
+
10
# Declarative base that the ORM models of this module inherit from
Base = declarative_base()


class Freight(Base):
    """ORM mapping for the ``freights`` table."""

    __tablename__ = "freights"

    # Integer primary key
    id = Column(Integer, primary_key=True)

    # Date and time of departure
    departure = Column(DateTime)

    # Origin port: code and name
    origin_port_locode = Column(String)
    origin_port_name = Column(String)

    # Destination port: code and name
    destination_port = Column(String)
    destination_port_name = Column(String)

    # Rates for 20ft and 40ft containers, and the currency of the rates
    dv20rate = Column(Float)
    dv40rate = Column(Float)
    currency = Column(String)

    # Date when the rate was inserted
    inserted_on = Column(DateTime)
29
+
30
+
31
def download_csv(
    url: str, local_path: str = "freights.csv", timeout: float = 30.0
) -> str:
    """
    Download a CSV file over HTTP and save it locally.

    Args:
        url: Address of the CSV file.
        local_path: Where to write the downloaded bytes.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``local_path``, for convenient chaining.

    Raises:
        requests.RequestException: On a network failure, a timeout, or an
            HTTP error status (4xx/5xx).
    """
    # Without a timeout, requests waits indefinitely on an unresponsive host.
    response = requests.get(url, timeout=timeout)
    # Fail loudly rather than saving an HTML error page as if it were CSV.
    response.raise_for_status()

    with open(local_path, "wb") as f:
        f.write(response.content)
    return local_path
39
+
40
+
41
def create_database(db_name: str = "freights.db") -> None:
    """
    Create the tables declared on ``Base`` in the SQLite file ``db_name``.

    Args:
        db_name: Path of the SQLite database file.
    """
    db_url = f"sqlite:///{db_name}"
    Base.metadata.create_all(create_engine(db_url))
47
+
48
+
49
def load_csv_to_db(csv_path: str, db_name: str = "freights.db") -> None:
    """
    Import the rows of a CSV file into the ``freights`` table.

    Args:
        csv_path: Path of the CSV file to read.
        db_name: Path of the SQLite database file to write to.
    """
    # The two timestamp columns are parsed into datetimes while reading.
    date_columns = ["departure", "inserted_on"]
    frame = pd.read_csv(csv_path, parse_dates=date_columns)

    target = create_engine(f"sqlite:///{db_name}")

    # if_exists="replace" drops an existing "freights" table before
    # inserting, so the table is rebuilt from the DataFrame's columns.
    frame.to_sql("freights", target, if_exists="replace", index=False)
61
+
62
+
63
def initialize_database(csv_url: str) -> None:
    """
    Initialize the database by downloading CSV and loading data.

    The downloaded CSV is a temporary artifact: it is deleted afterwards,
    whether or not loading succeeded.

    Args:
        csv_url: URL of the CSV file to download and load.
    """
    # Download CSV
    csv_path = download_csv(csv_url)

    try:
        # Create and load database
        create_database()
        load_csv_to_db(csv_path)
        print("Database initialized.")
    finally:
        # Clean up the CSV file even when creation or loading raised, so a
        # failed run does not leave the download behind.
        if os.path.exists(csv_path):
            os.remove(csv_path)