Spaces:
Running
Running
Upload 5 files
Browse files- gradio_smol.py +24 -0
- requirements.txt +5 -0
- sql_data.py +80 -0
- tools.py +43 -0
- utils.py +80 -0
gradio_smol.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from tools import FreightAgent, EXAMPLE_QUERIES
|
3 |
+
from utils import initialize_database
|
4 |
+
from smolagents import CodeAgent, HfApiModel, GradioUI
|
5 |
+
import os
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from sql_data import sql_query, get_schema, get_csv_as_dataframe
|
8 |
+
|
9 |
+
# Load environment variables
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
# Initialize the database if it doesn't exist
|
13 |
+
if not os.path.exists("freights.db"):
|
14 |
+
csv_url = "https://huggingface.co/datasets/sasu-SpidR/fretmaritime/resolve/main/freights.csv"
|
15 |
+
initialize_database(csv_url)
|
16 |
+
|
17 |
+
# Create the main agent
|
18 |
+
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
19 |
+
model = HfApiModel(model_id=model_id, token=os.environ["HF_API_KEY"])
|
20 |
+
|
21 |
+
agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
GradioUI(agent).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
sqlalchemy
|
3 |
+
smolagents
|
4 |
+
python-dotenv
|
5 |
+
gradio
|
sql_data.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy import create_engine, text
|
2 |
+
from smolagents import tool
|
3 |
+
import os
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
# Create database engine
|
7 |
+
engine = create_engine("sqlite:///freights.db")
|
8 |
+
|
9 |
+
|
10 |
+
@tool
|
11 |
+
def sql_query(query: str) -> str:
|
12 |
+
"""
|
13 |
+
Allows you to perform SQL queries on the freights table. Returns a string representation of the result.
|
14 |
+
The table is named 'freights'. Its description is as follows:
|
15 |
+
Columns:
|
16 |
+
- departure: DateTime (Date and time of departure)
|
17 |
+
- origin_port_locode: String (Origin port code)
|
18 |
+
- origin_port_name: String (Name of the origin port)
|
19 |
+
- destination_port: String (Destination port code)
|
20 |
+
- destination_port_name: String (Name of the destination port)
|
21 |
+
- dv20rate: Float (Rate for 20ft container in USD)
|
22 |
+
- dv40rate: Float (Rate for 40ft container in USD)
|
23 |
+
- currency: String (Currency of the rates)
|
24 |
+
- inserted_on: DateTime (Date when the rate was inserted)
|
25 |
+
|
26 |
+
Args:
|
27 |
+
query: The query to perform. This should be correct SQL.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
A string representation of the result of the query.
|
31 |
+
"""
|
32 |
+
try:
|
33 |
+
with engine.connect() as con:
|
34 |
+
result = con.execute(text(query))
|
35 |
+
rows = [dict(row._mapping) for row in result]
|
36 |
+
|
37 |
+
if not rows:
|
38 |
+
return "Aucun résultat trouvé."
|
39 |
+
|
40 |
+
# Convert to markdown table
|
41 |
+
headers = list(rows[0].keys())
|
42 |
+
table = "| " + " | ".join(headers) + " |\n"
|
43 |
+
table += "| " + " | ".join(["---" for _ in headers]) + " |\n"
|
44 |
+
|
45 |
+
for row in rows:
|
46 |
+
table += "| " + " | ".join(str(row[h]) for h in headers) + " |\n"
|
47 |
+
|
48 |
+
return table
|
49 |
+
|
50 |
+
except Exception as e:
|
51 |
+
return f"Error executing query: {str(e)}"
|
52 |
+
|
53 |
+
|
54 |
+
@tool
|
55 |
+
def get_schema() -> str:
|
56 |
+
"""
|
57 |
+
Returns the schema of the freights table.
|
58 |
+
"""
|
59 |
+
return """
|
60 |
+
Table: freights
|
61 |
+
Columns:
|
62 |
+
- departure: DateTime (Date and time of departure)
|
63 |
+
- origin_port_locode: String (Origin port code)
|
64 |
+
- origin_port_name: String (Name of the origin port)
|
65 |
+
- destination_port: String (Destination port code)
|
66 |
+
- destination_port_name: String (Name of the destination port)
|
67 |
+
- dv20rate: Float (Rate for 20ft container in USD)
|
68 |
+
- dv40rate: Float (Rate for 40ft container in USD)
|
69 |
+
- currency: String (Currency of the rates)
|
70 |
+
- inserted_on: DateTime (Date when the rate was inserted)
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
@tool
|
75 |
+
def get_csv_as_dataframe() -> str:
|
76 |
+
"""
|
77 |
+
Returns a string representation of the freights table as a CSV file.
|
78 |
+
"""
|
79 |
+
df = pd.read_sql_table("freights", engine)
|
80 |
+
return df.to_csv(index=False)
|
tools.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from smolagents import CodeAgent, HfApiModel
|
2 |
+
from sql_data import sql_query, get_schema
|
3 |
+
from sqlalchemy import create_engine, inspect, text
|
4 |
+
import os
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from typing import Dict, List, Any
|
7 |
+
import json
|
8 |
+
|
9 |
+
# Load environment variables
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
# Example queries that the agent can handle
|
13 |
+
EXAMPLE_QUERIES = [
|
14 |
+
"Quels sont les tarifs moyens des conteneurs 20ft et 40ft entre tous les ports ?",
|
15 |
+
"Quels sont les ports d'origine les plus fréquents ?",
|
16 |
+
"Montre-moi les routes avec des tarifs élevés pour les conteneurs 40ft",
|
17 |
+
"Quelle est l'évolution des prix au fil du temps pour la route Surabaya vers Nansha ?",
|
18 |
+
"Quelles sont les destinations disponibles depuis Shanghai ?",
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
class FreightAgent:
|
23 |
+
def __init__(self):
|
24 |
+
self.setup_agent()
|
25 |
+
|
26 |
+
def setup_agent(self) -> None:
|
27 |
+
"""
|
28 |
+
Initialize the CodeAgent with SQL tools.
|
29 |
+
|
30 |
+
Create a CodeAgent with two tools: `sql_query` and `get_schema`.
|
31 |
+
`sql_query` allows to perform SQL queries on the freights table.
|
32 |
+
`get_schema` returns the schema of the freights table.
|
33 |
+
"""
|
34 |
+
self.agent = CodeAgent(
|
35 |
+
tools=[sql_query, get_schema],
|
36 |
+
model=HfApiModel("meta-llama/Llama-3.1-8B-Instruct"),
|
37 |
+
)
|
38 |
+
|
39 |
+
def query(self, question: str) -> str:
|
40 |
+
"""
|
41 |
+
Ask a question about the freight data in natural language
|
42 |
+
"""
|
43 |
+
return self.agent.run(question)
|
utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import sqlite3
|
3 |
+
import requests
|
4 |
+
from typing import List, Dict, Any
|
5 |
+
import os
|
6 |
+
from sqlalchemy import create_engine, Column, Float, String, Integer, DateTime
|
7 |
+
from sqlalchemy.ext.declarative import declarative_base
|
8 |
+
from sqlalchemy.orm import sessionmaker
|
9 |
+
|
10 |
+
# Create base class for declarative models
|
11 |
+
Base = declarative_base()
|
12 |
+
|
13 |
+
|
14 |
+
class Freight(Base):
|
15 |
+
"""SQLAlchemy model for freight data"""
|
16 |
+
|
17 |
+
__tablename__ = "freights"
|
18 |
+
|
19 |
+
id = Column(Integer, primary_key=True)
|
20 |
+
departure = Column(DateTime)
|
21 |
+
origin_port_locode = Column(String)
|
22 |
+
origin_port_name = Column(String)
|
23 |
+
destination_port = Column(String)
|
24 |
+
destination_port_name = Column(String)
|
25 |
+
dv20rate = Column(Float)
|
26 |
+
dv40rate = Column(Float)
|
27 |
+
currency = Column(String)
|
28 |
+
inserted_on = Column(DateTime)
|
29 |
+
|
30 |
+
|
31 |
+
def download_csv(url: str, local_path: str = "freights.csv") -> str:
|
32 |
+
"""
|
33 |
+
Download CSV file from Hugging Face and save it locally
|
34 |
+
"""
|
35 |
+
response = requests.get(url)
|
36 |
+
with open(local_path, "wb") as f:
|
37 |
+
f.write(response.content)
|
38 |
+
return local_path
|
39 |
+
|
40 |
+
|
41 |
+
def create_database(db_name: str = "freights.db") -> None:
|
42 |
+
"""
|
43 |
+
Create SQLite database and necessary tables
|
44 |
+
"""
|
45 |
+
engine = create_engine(f"sqlite:///{db_name}")
|
46 |
+
Base.metadata.create_all(engine)
|
47 |
+
|
48 |
+
|
49 |
+
def load_csv_to_db(csv_path: str, db_name: str = "freights.db") -> None:
|
50 |
+
"""
|
51 |
+
Load CSV data into SQLite database
|
52 |
+
"""
|
53 |
+
# Read CSV
|
54 |
+
df = pd.read_csv(csv_path, parse_dates=["departure", "inserted_on"])
|
55 |
+
|
56 |
+
# Connect to database
|
57 |
+
engine = create_engine(f"sqlite:///{db_name}")
|
58 |
+
|
59 |
+
# Save to database
|
60 |
+
df.to_sql("freights", engine, if_exists="replace", index=False)
|
61 |
+
|
62 |
+
|
63 |
+
def initialize_database(csv_url: str) -> None:
|
64 |
+
"""
|
65 |
+
Initialize the database by downloading CSV and loading data.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
csv_url: URL of the CSV file to download and load.
|
69 |
+
"""
|
70 |
+
# Download CSV
|
71 |
+
csv_path = download_csv(csv_url)
|
72 |
+
|
73 |
+
# Create and load database
|
74 |
+
create_database()
|
75 |
+
load_csv_to_db(csv_path)
|
76 |
+
print("Database initialized.")
|
77 |
+
|
78 |
+
# Clean up CSV file
|
79 |
+
if os.path.exists(csv_path):
|
80 |
+
os.remove(csv_path)
|