Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Imports | |
| import asyncio | |
| import os | |
| import openai | |
| import wandb | |
| from typing import List, Optional | |
| # from pydantic import BaseModel, Field | |
| # from langchain.prompts import ChatPromptTemplate | |
| # from langchain.pydantic_v1 import BaseModel | |
| # from langchain.utils.openai_functions import convert_pydantic_to_openai_function | |
| from llama_index.tools import FunctionTool | |
| from llama_index.vector_stores.types import ( | |
| VectorStoreInfo, | |
| MetadataInfo, | |
| ExactMatchFilter, | |
| MetadataFilters, | |
| ) | |
| from llama_index.agent import OpenAIAgent | |
| from llama_index.retrievers import VectorIndexRetriever | |
| from llama_index.query_engine import RetrieverQueryEngine | |
| from typing import List, Tuple, Any | |
| from pydantic import BaseModel, Field | |
| from llama_index import load_index_from_storage | |
| from llama_index import set_global_handler | |
| import llama_index | |
| from llama_index.embeddings import OpenAIEmbedding | |
| from llama_index import ServiceContext | |
| from llama_index.llms import OpenAI | |
| from llama_index import GPTVectorStoreIndex | |
# Send LlamaIndex tracing to Weights & Biases, then read back the handler
# object that llama_index exposes as its global handler.
set_global_handler(
    "wandb",
    run_args={"project": "final-project-v1"},
)
wandb_callback = llama_index.global_handler

# Start a W&B run and download the persisted index artifact; artifact_dir is
# whatever location artifact.download() reports.
run = wandb.init()
artifact = run.use_artifact(
    "chrisalexiuk/llamaindex-demo-v1/wiki-index:v0",
    type="storage_context",
)
artifact_dir = artifact.download()
from dotenv import load_dotenv

# Load a local .env file (if present) before reading the key; indexing
# os.environ raises KeyError when OPENAI_API_KEY is still missing.
load_dotenv()
openai.api_key = os.environ["OPENAI_API_KEY"]

# Number of results the retriever is asked for.
top_k = 3
# Describes what the vector store holds and which metadata fields exist.
# Every field is declared with type "str"; only name and description vary,
# so the entries are generated from (name, description) pairs.
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name=field_name, type="str", description=field_description)
        for field_name, field_description in (
            ("title", "Title of the earnings call"),
            ("period", "Period of the earnings call"),
            ("ticker", "Ticker of the company"),
            ("year", "Year of the earnings call"),
            ("quarter", "Quarter of the earnings call"),
            ("path", "Path to the earnings call"),
        )
    ],
)
class AutoRetrieveModel(BaseModel):
    # Argument schema with three required fields: a free-text query plus two
    # lists that together describe metadata filters (names and values).
    #
    # NOTE: documented with comments rather than a class docstring on purpose;
    # pydantic may copy a docstring into the generated schema, which would
    # change the schema this model produces.

    # The question to run against the index.
    query: str = Field(
        ...,
        description="natural language query string",
    )

    # Names of the metadata fields to filter on.
    filter_key_list: List[str] = Field(
        ...,
        description="List of metadata filter field names",
    )

    # Values for the fields named in filter_key_list.
    filter_value_list: List[str] = Field(
        ...,
        description="List of metadata filter field values (corresponding to names specified in filter_key_list)",
    )
# Chunk size handed to the service context below.
chunk_size = 500

# Embedding model and LLM that the service context bundles together.
embed_model = OpenAIEmbedding()
llm = OpenAI(model="gpt-4", temperature=0)

service_context = ServiceContext.from_defaults(
    embed_model=embed_model,
    chunk_size=chunk_size,
    llm=llm,
)

# Module-level index built from an empty document list.
index = GPTVectorStoreIndex.from_documents(
    [],
    service_context=service_context,
)
# Main function to extract information
async def extract_information():
    """Build an OpenAI agent that can query the stored index through one tool.

    Loads a storage context via the W&B callback handler, loads the index
    from it, wraps a metadata-filtered retrieval function as a
    ``FunctionTool`` and registers that tool on a new ``OpenAIAgent``.

    Returns:
        The agent created by ``OpenAIAgent.from_tools`` with the single
        "earnings-transcripts" tool attached.
    """
    # NOTE(review): ``artifact_url`` receives a path inside the directory that
    # was already downloaded at import time -- confirm that
    # ``load_storage_context`` accepts a local path here rather than a W&B
    # artifact name.
    storage_context = wandb_callback.load_storage_context(
        artifact_url=artifact_dir + "/index_store.json"
    )
    # Local index loaded from storage; the nested function below closes over it.
    loaded_index = load_index_from_storage(
        storage_context, service_context=service_context
    )

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ):
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.

        """
        # Fall back to a placeholder when the caller supplies an empty query.
        query = query or "Query"

        # Pair each filter name with its value; zip() stops at the shorter
        # list, so unmatched trailing entries are ignored.
        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # The retriever's result-count parameter is ``similarity_top_k``. The
        # previous ``top_k=`` keyword was absorbed by ``**kwargs`` and ignored,
        # so the module-level ``top_k`` setting never took effect.
        retriever = VectorIndexRetriever(
            loaded_index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=top_k,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        response = query_engine.query(query)
        return str(response)

    # Expose the retrieval function as a tool whose arguments follow
    # AutoRetrieveModel.
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description="Earnings Bot",
        fn_schema=AutoRetrieveModel,
    )
    # The agent's model must support tool/function calling.
    agent = OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
    )
    return agent
# Example usage, kept commented out so that importing this module does not
# start a chat. extract_information() is a coroutine function, so it must be
# awaited before the returned agent can be used.
#
# async def extract_information_async(message: str):
#     agent = await extract_information()
#     return str(agent.chat(message))
#
# async def main():
#     text = "Who is the CEO of MSFT."
#     res = await extract_information_async(text)
#     print(res)
#
# if __name__ == "__main__":
#     asyncio.run(main())