Spaces:
Sleeping
Sleeping
| # MIT License | |
| # | |
| # Copyright (c) 2024 dataforgood | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # Standard imports | |
| import logging | |
| import uuid | |
| import pandas as pd | |
| # External imports | |
| from IPython.display import display | |
| from langchain.prompts import PromptTemplate | |
| from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from langchain_openai import ChatOpenAI | |
| from country_by_country.utils import constants | |
class LLMCleaner:
    """Table cleaner turning extracted tables into per-country financial data.

    Two LLM chains are run over every extracted table: the first one lists the
    countries present in the table, the second one extracts the financial
    figures of those countries. The results obtained for all the tables are
    concatenated into a single dataframe.
    """

    # Columns of the cleaned table. They mirror the fields of the `Country`
    # model declared in `_extract_financial_data` and are only used to give the
    # empty result the same schema as a regular one: keep both in sync.
    _OUTPUT_COLUMNS = (
        "jur_name",
        "total_revenues",
        "profit_before_tax",
        "tax_paid",
        "tax_accrued",
        "employees",
        "stated_capital",
        "accumulated_earnings",
        "tangible_assets",
    )

    # Upper bound on the number of concurrent LLM requests within one batch.
    _MAX_CONCURRENCY = 4

    def __init__(self, **kwargs: dict) -> None:
        """
        Builds a table cleaner, by extracting clean data from tables
        extracted during table extraction stage.

        The kwargs given to the constructor are stored as they are and
        reported under the "params" key of the asset returned by the cleaner.
        The "openai_model" kwarg is mandatory: it is the name of the model
        handed over to ChatOpenAI.

        Raises:
            KeyError: if "openai_model" is not part of the kwargs.
        """
        self.kwargs = kwargs
        self.type = "llm_cleaner"
        self.openai_model = self.kwargs["openai_model"]

    def __call__(self, asset: dict) -> dict:
        """
        Cleans the tables of an extraction asset.

        Args:
            asset: output of the table extraction stage. The keys "type",
                "params" (both only logged) and "tables" are read. Every item
                of "tables" must provide a `to_html()` method
                (presumably pandas dataframes, to be confirmed by the caller).

        Returns:
            A new asset with the keys "id" (random UUID), "type", "params"
            (the constructor kwargs) and "table" (dataframe with one row per
            country found in each table).
        """
        logging.info("\nKicking off cleaning stage...")
        logging.info("Cleaning type: %s, with params: %s", self.type, self.kwargs)
        logging.info(
            "Input extraction type: %s, with params: %s",
            asset["type"],
            asset["params"],
        )

        # Extract tables from previous stage
        tables = asset["tables"]
        logging.info("Pulling %s tables from extraction stage", len(tables))

        # Without any table, pd.concat would raise "No objects to concatenate"
        # once both chains have run on nothing: return an empty table right
        # away, without instantiating the LLM client.
        if len(tables) == 0:
            logging.warning("No table to clean, returning an empty table")
            return self._build_asset(pd.DataFrame(columns=list(self._OUTPUT_COLUMNS)))

        # Convert tables to html to add to LLM prompt
        html_tables = [table.to_html() for table in tables]

        # Define our LLM model (temperature 0 for outputs as stable as possible)
        model = ChatOpenAI(temperature=0, model=self.openai_model)

        logging.info("Starting chain 1/2: extracting country names from tables")
        country_lists = self._extract_country_lists(model, html_tables)

        logging.info("Starting chain 2/2: extracting financial data from tables")
        df = self._extract_financial_data(model, html_tables, country_lists)

        # Display
        display(df)

        return self._build_asset(df)

    def _extract_country_lists(
        self,
        model: "ChatOpenAI",
        html_tables: list[str],
    ) -> list[list[str]]:
        """Chain 1/2: returns, for each html table, the list of countries it contains."""

        # Output should have this model (a list of country names)
        class CountryNames(BaseModel):
            country_names: list[str] = Field(
                description="Exhaustive list of countries with financial data in the table",
                enum=constants.COUNTRIES,
            )

        # Output should be a JSON with above schema
        parser = JsonOutputParser(pydantic_object=CountryNames)

        # Prompt includes one extracted table and some JSON output formatting instructions
        prompt = PromptTemplate(
            template="Extract an exhaustive list of countries from the following table "
            + "in html format:\n{table}\n{format_instructions}",
            input_variables=["table"],
            partial_variables={
                "format_instructions": parser.get_format_instructions(),
            },
        )

        # One request per table, each table being the only input of the prompt
        chain = {"table": lambda x: x} | prompt | model | parser
        responses = chain.batch(
            html_tables,
            {"max_concurrency": self._MAX_CONCURRENCY},
        )

        return [resp["country_names"] for resp in responses]

    def _extract_financial_data(
        self,
        model: "ChatOpenAI",
        html_tables: list[str],
        country_lists: list[list[str]],
    ) -> pd.DataFrame:
        """Chain 2/2: returns the financial data of the listed countries, all tables merged."""

        # Define country data model
        class Country(BaseModel):
            """Financial data about a country"""

            jur_name: str = Field(..., description="Name of the country")
            total_revenues: float | None = Field(None, description="Total revenues")
            profit_before_tax: float | None = Field(
                None,
                description="Amount of profit (or loss) before tax",
            )
            tax_paid: float | None = Field(None, description="Income tax paid")
            tax_accrued: float | None = Field(None, description="Accrued tax")
            employees: float | None = Field(None, description="Number of employees")
            stated_capital: float | None = Field(None, description="Stated capital")
            accumulated_earnings: float | None = Field(
                None,
                description="Accumulated earnings",
            )
            tangible_assets: float | None = Field(
                None,
                description="Tangible assets other than cash and cash equivalent",
            )

        # Output should have this model (a list of country objects)
        class Countries(BaseModel):
            """Extracting financial data for each country"""

            countries: list[Country]

        # Only used for the formatting instructions injected in the prompt:
        # the parsing itself is done by the structured output of the model.
        parser = PydanticOutputParser(pydantic_object=Countries)

        # Prompt includes one extracted table and some JSON output formatting instructions
        template = (
            "You are an assistant tasked with extracting financial "
            "data about {country_list} from the following table in html format:\n\n"
            "{table}\n\n"
            "{format_instructions}\n"
        )
        prompt = PromptTemplate.from_template(
            template,
            partial_variables={
                "format_instructions": parser.get_format_instructions(),
            },
        )

        # Each input is a (html table, country list) pair
        chain = (
            {"table": lambda x: x[0], "country_list": lambda x: x[1]}
            | prompt
            | model.with_structured_output(Countries)
        )

        # strict=True: both lists come from the same tables, a length mismatch is a bug
        responses = chain.batch(
            list(zip(html_tables, country_lists, strict=True)),
            {"max_concurrency": self._MAX_CONCURRENCY},
        )

        # Merge the tables into one dataframe
        return pd.concat(
            [pd.json_normalize(resp.dict()["countries"]) for resp in responses],
        ).reset_index(drop=True)

    def _build_asset(self, df: pd.DataFrame) -> dict:
        """Wraps the cleaned table into the asset returned by the cleaning stage."""
        return {
            "id": uuid.uuid4(),
            "type": self.type,
            "params": self.kwargs,
            "table": df,
        }