# Author: rajat5ranjan
# Commit: b7f6cd9 ("update", verified)
# File size: 3.12 kB
import streamlit as st
import os
import getpass
from langchain import PromptTemplate
from langchain import hub
from langchain.docstore.document import Document
from langchain.document_loaders import WebBaseLoader
from langchain.schema import StrOutputParser
from langchain.schema.prompt_template import format_document
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Chroma
import google.generativeai as genai
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
# Read the API key up front so a missing key produces a readable message in
# the app instead of a bare KeyError traceback.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    st.error("The GOOGLE_API_KEY environment variable is not set.")
    st.stop()

FINANCE_URL = "https://www.google.com/finance/?hl=en"
CHROMA_DIR = "./chroma_db"  # Directory the Chroma collection is persisted to


@st.cache_resource
def _build_vectorstore():
    """Fetch the Google Finance page, embed it, and persist it to Chroma.

    Streamlit re-executes this whole script on every widget interaction, so
    the work is cached once per server process. Without the cache, every
    rerun would re-download the page, repeat the embedding calls and index
    the page into the persisted collection again.

    Returns:
        A ``(docs, embeddings, vectorstore)`` tuple: the loaded documents,
        the Gemini embedding model, and the Chroma store holding them.
    """
    page_docs = WebBaseLoader(FINANCE_URL).load()
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001", google_api_key=GOOGLE_API_KEY
    )
    # Stable ids let a re-index (e.g. after a server restart) overwrite the
    # stored page instead of adding another copy under a fresh generated id.
    ids = [f"{FINANCE_URL}#{i}" for i in range(len(page_docs))]
    store = Chroma.from_documents(
        documents=page_docs,  # Data
        embedding=embeddings,  # Embedding model
        ids=ids,
        persist_directory=CHROMA_DIR,  # Directory to save data
    )
    return page_docs, embeddings, store


docs, gemini_embeddings, vectorstore = _build_vectorstore()
# Same persisted collection as `vectorstore`; reuse the handle rather than
# opening the directory a second time.
vectorstore_disk = vectorstore
# k=1: only the single most similar document is used as context.
retriever = vectorstore_disk.as_retriever(search_kwargs={"k": 1})
# Chat model used to answer questions. The API key is passed explicitly.
# The model name defaults to "gemini-pro" and can be overridden with the
# GEMINI_MODEL environment variable (useful if that model is not available).
llm = ChatGoogleGenerativeAI(
    model=os.environ.get("GEMINI_MODEL", "gemini-pro"),
    google_api_key=GOOGLE_API_KEY,
)
llm_prompt_template = """You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {question} \nContext: {context} \nAnswer:"""
# Template variables: {question} receives the user's input unchanged and
# {context} receives the retrieved document text.
llm_prompt = PromptTemplate.from_template(llm_prompt_template)
def format_docs(docs):
    """Join the ``page_content`` of retrieved documents into one string.

    Each document's text is separated from the next by a blank line
    (``"\\n\\n"``). An empty sequence yields an empty string.
    """
    return "\n\n".join(doc.page_content for doc in docs)
# Retrieval-augmented pipeline. The input string is used both as the
# retrieval query (its matches are joined by format_docs into {context}) and,
# unchanged, as {question}. The filled prompt is sent to the LLM and the
# reply is parsed into a plain string.
_rag_inputs = {
    "context": retriever | format_docs,
    "question": RunnablePassthrough(),
}
rag_chain = _rag_inputs | llm_prompt | llm | StrOutputParser()
prompt = st.text_input("Enter Prompt", "What is the best stocks for the next few weeks")
# Streamlit reruns the script on every edit of the input box; only call the
# model when there is an actual question, not for an empty/whitespace box.
if prompt.strip():
    with st.spinner("Generating answer..."):
        res = rag_chain.invoke(prompt)
    st.write(res)
else:
    st.info("Enter a question above to get an answer.")
# If there is no environment variable set for the API key, you can pass the API
# key to the parameter `google_api_key` of the `GoogleGenerativeAIEmbeddings`
# function: `google_api_key = "key"`.
# gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# # Save to disk
# vectorstore = Chroma.from_documents(
# documents=docs, # Data
# embedding=gemini_embeddings, # Embedding model
# persist_directory="./chroma_db" # Directory to save data
# )
# vectorstore_disk = Chroma(
# persist_directory="./chroma_db", # Directory of db
# embedding_function=gemini_embeddings # Embedding model