Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload 5 files
Browse files- gradio_smol.py +24 -0
- requirements.txt +5 -0
- sql_data.py +80 -0
- tools.py +43 -0
- utils.py +80 -0
    	
        gradio_smol.py
    ADDED
    
    | @@ -0,0 +1,24 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from tools import FreightAgent, EXAMPLE_QUERIES
         | 
| 3 | 
            +
            from utils import initialize_database
         | 
| 4 | 
            +
            from smolagents import CodeAgent, HfApiModel, GradioUI
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from dotenv import load_dotenv
         | 
| 7 | 
            +
            from sql_data import sql_query, get_schema, get_csv_as_dataframe
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Load environment variables
         | 
| 10 | 
            +
            load_dotenv()
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Initialize the database if it doesn't exist
         | 
| 13 | 
            +
            if not os.path.exists("freights.db"):
         | 
| 14 | 
            +
                csv_url = "https://huggingface.co/datasets/sasu-SpidR/fretmaritime/resolve/main/freights.csv"
         | 
| 15 | 
            +
                initialize_database(csv_url)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Create the main agent
         | 
| 18 | 
            +
            model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
         | 
| 19 | 
            +
            model = HfApiModel(model_id=model_id, token=os.environ["HF_API_KEY"])
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            if __name__ == "__main__":
         | 
| 24 | 
            +
                GradioUI(agent).launch()
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            pandas
         | 
| 2 | 
            +
            sqlalchemy
         | 
| 3 | 
            +
            smolagents
         | 
| 4 | 
            +
            python-dotenv
         | 
| 5 | 
            +
            gradio
         | 
    	
        sql_data.py
    ADDED
    
    | @@ -0,0 +1,80 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from sqlalchemy import create_engine, text
         | 
| 2 | 
            +
            from smolagents import tool
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import pandas as pd
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Create database engine
         | 
| 7 | 
            +
            engine = create_engine("sqlite:///freights.db")
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            @tool
         | 
| 11 | 
            +
            def sql_query(query: str) -> str:
         | 
| 12 | 
            +
                """
         | 
| 13 | 
            +
                Allows you to perform SQL queries on the freights table. Returns a string representation of the result.
         | 
| 14 | 
            +
                The table is named 'freights'. Its description is as follows:
         | 
| 15 | 
            +
                    Columns:
         | 
| 16 | 
            +
                    - departure: DateTime (Date and time of departure)
         | 
| 17 | 
            +
                    - origin_port_locode: String (Origin port code)
         | 
| 18 | 
            +
                    - origin_port_name: String (Name of the origin port)
         | 
| 19 | 
            +
                    - destination_port: String (Destination port code)
         | 
| 20 | 
            +
                    - destination_port_name: String (Name of the destination port)
         | 
| 21 | 
            +
                    - dv20rate: Float (Rate for 20ft container in USD)
         | 
| 22 | 
            +
                    - dv40rate: Float (Rate for 40ft container in USD)
         | 
| 23 | 
            +
                    - currency: String (Currency of the rates)
         | 
| 24 | 
            +
                    - inserted_on: DateTime (Date when the rate was inserted)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    query: The query to perform. This should be correct SQL.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Returns:
         | 
| 30 | 
            +
                    A string representation of the result of the query.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                try:
         | 
| 33 | 
            +
                    with engine.connect() as con:
         | 
| 34 | 
            +
                        result = con.execute(text(query))
         | 
| 35 | 
            +
                        rows = [dict(row._mapping) for row in result]
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                        if not rows:
         | 
| 38 | 
            +
                            return "Aucun résultat trouvé."
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                        # Convert to markdown table
         | 
| 41 | 
            +
                        headers = list(rows[0].keys())
         | 
| 42 | 
            +
                        table = "| " + " | ".join(headers) + " |\n"
         | 
| 43 | 
            +
                        table += "| " + " | ".join(["---" for _ in headers]) + " |\n"
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        for row in rows:
         | 
| 46 | 
            +
                            table += "| " + " | ".join(str(row[h]) for h in headers) + " |\n"
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                        return table
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                except Exception as e:
         | 
| 51 | 
            +
                    return f"Error executing query: {str(e)}"
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            @tool
         | 
| 55 | 
            +
            def get_schema() -> str:
         | 
| 56 | 
            +
                """
         | 
| 57 | 
            +
                Returns the schema of the freights table.
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
                return """
         | 
| 60 | 
            +
                Table: freights
         | 
| 61 | 
            +
                Columns:
         | 
| 62 | 
            +
                - departure: DateTime (Date and time of departure)
         | 
| 63 | 
            +
                - origin_port_locode: String (Origin port code)
         | 
| 64 | 
            +
                - origin_port_name: String (Name of the origin port)
         | 
| 65 | 
            +
                - destination_port: String (Destination port code)
         | 
| 66 | 
            +
                - destination_port_name: String (Name of the destination port)
         | 
| 67 | 
            +
                - dv20rate: Float (Rate for 20ft container in USD)
         | 
| 68 | 
            +
                - dv40rate: Float (Rate for 40ft container in USD)
         | 
| 69 | 
            +
                - currency: String (Currency of the rates)
         | 
| 70 | 
            +
                - inserted_on: DateTime (Date when the rate was inserted)
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            @tool
         | 
| 75 | 
            +
            def get_csv_as_dataframe() -> str:
         | 
| 76 | 
            +
                """
         | 
| 77 | 
            +
                Returns a string representation of the freights table as a CSV file.
         | 
| 78 | 
            +
                """
         | 
| 79 | 
            +
                df = pd.read_sql_table("freights", engine)
         | 
| 80 | 
            +
                return df.to_csv(index=False)
         | 
    	
        tools.py
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from smolagents import CodeAgent, HfApiModel
         | 
| 2 | 
            +
            from sql_data import sql_query, get_schema
         | 
| 3 | 
            +
            from sqlalchemy import create_engine, inspect, text
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from dotenv import load_dotenv
         | 
| 6 | 
            +
            from typing import Dict, List, Any
         | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Load environment variables
         | 
| 10 | 
            +
            load_dotenv()
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Example queries that the agent can handle
         | 
| 13 | 
            +
            EXAMPLE_QUERIES = [
         | 
| 14 | 
            +
                "Quels sont les tarifs moyens des conteneurs 20ft et 40ft entre tous les ports ?",
         | 
| 15 | 
            +
                "Quels sont les ports d'origine les plus fréquents ?",
         | 
| 16 | 
            +
                "Montre-moi les routes avec des tarifs élevés pour les conteneurs 40ft",
         | 
| 17 | 
            +
                "Quelle est l'évolution des prix au fil du temps pour la route Surabaya vers Nansha ?",
         | 
| 18 | 
            +
                "Quelles sont les destinations disponibles depuis Shanghai ?",
         | 
| 19 | 
            +
            ]
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class FreightAgent:
         | 
| 23 | 
            +
                def __init__(self):
         | 
| 24 | 
            +
                    self.setup_agent()
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def setup_agent(self) -> None:
         | 
| 27 | 
            +
                    """
         | 
| 28 | 
            +
                    Initialize the CodeAgent with SQL tools.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    Create a CodeAgent with two tools: `sql_query` and `get_schema`.
         | 
| 31 | 
            +
                    `sql_query` allows to perform SQL queries on the freights table.
         | 
| 32 | 
            +
                    `get_schema` returns the schema of the freights table.
         | 
| 33 | 
            +
                    """
         | 
| 34 | 
            +
                    self.agent = CodeAgent(
         | 
| 35 | 
            +
                        tools=[sql_query, get_schema],
         | 
| 36 | 
            +
                        model=HfApiModel("meta-llama/Llama-3.1-8B-Instruct"),
         | 
| 37 | 
            +
                    )
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def query(self, question: str) -> str:
         | 
| 40 | 
            +
                    """
         | 
| 41 | 
            +
                    Ask a question about the freight data in natural language
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    return self.agent.run(question)
         | 
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,80 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import pandas as pd
         | 
| 2 | 
            +
            import sqlite3
         | 
| 3 | 
            +
            import requests
         | 
| 4 | 
            +
            from typing import List, Dict, Any
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from sqlalchemy import create_engine, Column, Float, String, Integer, DateTime
         | 
| 7 | 
            +
            from sqlalchemy.ext.declarative import declarative_base
         | 
| 8 | 
            +
            from sqlalchemy.orm import sessionmaker
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Create base class for declarative models
         | 
| 11 | 
            +
            Base = declarative_base()
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class Freight(Base):
         | 
| 15 | 
            +
                """SQLAlchemy model for freight data"""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                __tablename__ = "freights"
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                id = Column(Integer, primary_key=True)
         | 
| 20 | 
            +
                departure = Column(DateTime)
         | 
| 21 | 
            +
                origin_port_locode = Column(String)
         | 
| 22 | 
            +
                origin_port_name = Column(String)
         | 
| 23 | 
            +
                destination_port = Column(String)
         | 
| 24 | 
            +
                destination_port_name = Column(String)
         | 
| 25 | 
            +
                dv20rate = Column(Float)
         | 
| 26 | 
            +
                dv40rate = Column(Float)
         | 
| 27 | 
            +
                currency = Column(String)
         | 
| 28 | 
            +
                inserted_on = Column(DateTime)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def download_csv(url: str, local_path: str = "freights.csv") -> str:
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                Download CSV file from Hugging Face and save it locally
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                response = requests.get(url)
         | 
| 36 | 
            +
                with open(local_path, "wb") as f:
         | 
| 37 | 
            +
                    f.write(response.content)
         | 
| 38 | 
            +
                return local_path
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def create_database(db_name: str = "freights.db") -> None:
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
                Create SQLite database and necessary tables
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
                engine = create_engine(f"sqlite:///{db_name}")
         | 
| 46 | 
            +
                Base.metadata.create_all(engine)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def load_csv_to_db(csv_path: str, db_name: str = "freights.db") -> None:
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
                Load CSV data into SQLite database
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
                # Read CSV
         | 
| 54 | 
            +
                df = pd.read_csv(csv_path, parse_dates=["departure", "inserted_on"])
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                # Connect to database
         | 
| 57 | 
            +
                engine = create_engine(f"sqlite:///{db_name}")
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                # Save to database
         | 
| 60 | 
            +
                df.to_sql("freights", engine, if_exists="replace", index=False)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def initialize_database(csv_url: str) -> None:
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                Initialize the database by downloading CSV and loading data.
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                Args:
         | 
| 68 | 
            +
                    csv_url: URL of the CSV file to download and load.
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                # Download CSV
         | 
| 71 | 
            +
                csv_path = download_csv(csv_url)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                # Create and load database
         | 
| 74 | 
            +
                create_database()
         | 
| 75 | 
            +
                load_csv_to_db(csv_path)
         | 
| 76 | 
            +
                print("Database initialized.")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                # Clean up CSV file
         | 
| 79 | 
            +
                if os.path.exists(csv_path):
         | 
| 80 | 
            +
                    os.remove(csv_path)
         | 

