kashyaparun commited on
Commit
d3f8a88
·
verified ·
1 Parent(s): 7400c9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -44
app.py CHANGED
@@ -29,6 +29,9 @@ from pydantic import BaseModel, Field
29
  import litellm
30
  from langchain.tools import Tool
31
 
 
 
 
32
  # Configure logging
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
@@ -69,22 +72,13 @@ st.write("---")
69
  # Sidebar for API key configuration
70
  with st.sidebar:
71
  st.title("⚙️ Configuration")
72
- api_key_source = st.radio("Select API Key Provider:",
73
- ["Google (Gemini)", "OpenAI"],
74
- help="Choose which AI provider to use")
75
 
76
- if api_key_source == "Google (Gemini)":
77
- api_key = st.text_input("Enter your Gemini API Key", type="password",
78
- help="Required for the AI model to function")
79
- if api_key:
80
- os.environ["GEMINI_API_KEY"] = api_key
81
- os.environ["GOOGLE_API_KEY"] = api_key
82
- else:
83
- api_key = st.text_input("Enter your OpenAI API Key", type="password",
84
- help="Required for the AI model to function")
85
- if api_key:
86
- os.environ["OPENAI_API_KEY"] = "sk-proj-iQ8piK0xt54XBKEI4nDzCg9CZE7a13xZqCaN1B78zqZTyhBwrXOCjfMjNWG0w1gprhdj6_"
87
-
88
  st.divider()
89
 
90
  # Reset button
@@ -93,7 +87,6 @@ with st.sidebar:
93
  del st.session_state[key]
94
  st.rerun()
95
 
96
-
97
  #---------------------------- Utility Functions ----------------------------#
98
 
99
  def extract_text_from_pdf(file):
@@ -398,6 +391,8 @@ class CaseBreakdownCrew:
398
  self.api_key = api_key
399
 
400
  def create_metadata_agent(self):
 
 
401
  return Agent(
402
  role="Metadata Analyzer",
403
  goal="Extract title and author information from document content",
@@ -409,6 +404,7 @@ class CaseBreakdownCrew:
409
  )
410
 
411
  def create_content_generator_agent(self):
 
412
  return Agent(
413
  role="Case Study Content Generator",
414
  goal="Generate comprehensive case analysis content based on section requirements",
@@ -420,6 +416,7 @@ class CaseBreakdownCrew:
420
  )
421
 
422
  def create_content_reviewer_agent(self):
 
423
  return Agent(
424
  role="Content Quality Reviewer",
425
  goal="Evaluate and score content for quality, relevance, and depth",
@@ -505,6 +502,7 @@ class CaseBreakdownCrew:
505
  agents=[self.create_metadata_agent()],
506
  tasks=[metadata_task],
507
  process=Process.sequential,
 
508
  verbose=False
509
  )
510
  result = crew.kickoff()
@@ -607,16 +605,9 @@ def create_teaching_plan_crew(file_paths, llm_provider="gemini"):
607
  tracker.set_placeholder(st.empty())
608
 
609
  # Initialize LLM based on provider
610
- if llm_provider == "gemini":
611
- my_llm = LLM(
612
- model='gemini/gemini-2.0-flash',
613
- api_key=os.environ.get("GEMINI_API_KEY")
614
- )
615
- else:
616
- my_llm = LLM(
617
- model='gpt-4-turbo',
618
- api_key=os.environ.get("OPENAI_API_KEY")
619
- )
620
 
621
  # Define agents with callbacks for UI updates
622
  pdf_analyzer = Agent(
@@ -708,24 +699,15 @@ def create_teaching_plan_crew(file_paths, llm_provider="gemini"):
708
  #---------------------------- Board Plan Generator ----------------------------#
709
 
710
  class BoardPlanAnalyzer:
711
- def __init__(self, llm_provider="gemini"):
712
- if llm_provider == "gemini":
713
- api_key = os.environ.get('GEMINI_API_KEY')
714
- self.model = "gemini/gemini-2.0-flash"
715
- else:
716
- api_key = os.environ.get('OPENAI_API_KEY')
717
- self.model = "gpt-4-turbo"
718
-
719
  if not api_key:
720
- raise ValueError(f"{llm_provider.capitalize()} API key not found")
721
-
722
- if llm_provider == "gemini":
723
- os.environ['GEMINI_API_KEY'] = api_key
724
- else:
725
- os.environ['OPENAI_API_KEY'] = api_key
726
 
727
  litellm.set_verbose = True
728
-
729
  # Create agents
730
  self.create_agents()
731
 
@@ -745,6 +727,7 @@ class BoardPlanAnalyzer:
745
  description="Extracts text content from PDF files"
746
  )],
747
  allow_delegation=False,
 
748
  verbose=True
749
  )
750
 
@@ -761,6 +744,7 @@ class BoardPlanAnalyzer:
761
  description="Analyzes case study and creates structured board plan"
762
  )],
763
  allow_delegation=False,
 
764
  verbose=True
765
  )
766
 
@@ -823,7 +807,7 @@ class BoardPlanAnalyzer:
823
 
824
  try:
825
  response = litellm.completion(
826
- model=self.model,
827
  messages=messages,
828
  response_format={"type": "json_object"}
829
  )
@@ -1098,7 +1082,7 @@ if st.session_state.uploaded_files:
1098
  progress_bar = progress_placeholder.progress(0)
1099
 
1100
  # Select LLM provider
1101
- llm_provider = "gemini" if api_key_source == "Google (Gemini)" else "openai"
1102
 
1103
  # Update progress
1104
  progress_bar.progress(10)
@@ -1189,7 +1173,7 @@ if st.session_state.uploaded_files:
1189
  if st.button("Generate Board Plan", key="board_plan_button"):
1190
  try:
1191
  # Select LLM provider
1192
- llm_provider = "gemini" if api_key_source == "Google (Gemini)" else "openai"
1193
 
1194
  # Initialize the board plan analyzer
1195
  analyzer = BoardPlanAnalyzer(llm_provider=llm_provider)
 
29
  import litellm
30
  from langchain.tools import Tool
31
 
32
+ LLM._get_litellm_model_name = lambda self, model_name: f"gemini/{model_name}" if not "/" in model_name else model_name
33
+ os.environ["LITELLM_MODEL_DEFAULT_PROVIDER"] = "gemini"
34
+
35
  # Configure logging
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
 
72
  # Sidebar for API key configuration
73
  with st.sidebar:
74
  st.title("⚙️ Configuration")
 
 
 
75
 
76
+ api_key = st.text_input("Enter your Gemini API Key", type="password",
77
+ help="Required for the AI model to function")
78
+ if api_key:
79
+ os.environ["GEMINI_API_KEY"] = api_key
80
+ os.environ["GOOGLE_API_KEY"] = api_key
81
+
 
 
 
 
 
 
82
  st.divider()
83
 
84
  # Reset button
 
87
  del st.session_state[key]
88
  st.rerun()
89
 
 
90
  #---------------------------- Utility Functions ----------------------------#
91
 
92
  def extract_text_from_pdf(file):
 
391
  self.api_key = api_key
392
 
393
  def create_metadata_agent(self):
394
+ self.api_key = api_key
395
+ self.llm = LLM(model='gemini/gemini-2.0-flash', api_key=self.api_key) # Create a Gemini LLM instance
396
  return Agent(
397
  role="Metadata Analyzer",
398
  goal="Extract title and author information from document content",
 
404
  )
405
 
406
  def create_content_generator_agent(self):
407
+ llm = LLM(model='gemini/gemini-2.0-flash', api_key=self.api_key)
408
  return Agent(
409
  role="Case Study Content Generator",
410
  goal="Generate comprehensive case analysis content based on section requirements",
 
416
  )
417
 
418
  def create_content_reviewer_agent(self):
419
+ llm = LLM(model='gemini/gemini-2.0-flash', api_key=self.api_key)
420
  return Agent(
421
  role="Content Quality Reviewer",
422
  goal="Evaluate and score content for quality, relevance, and depth",
 
502
  agents=[self.create_metadata_agent()],
503
  tasks=[metadata_task],
504
  process=Process.sequential,
505
+ llm=LLM(model='gemini/gemini-2.0-flash', api_key=self.api_key),
506
  verbose=False
507
  )
508
  result = crew.kickoff()
 
605
  tracker.set_placeholder(st.empty())
606
 
607
  # Initialize LLM based on provider
608
+ my_llm = LLM(model='gemini/gemini-2.0-flash',
609
+ api_key=os.environ.get("GEMINI_API_KEY")
610
+ )
 
 
 
 
 
 
 
611
 
612
  # Define agents with callbacks for UI updates
613
  pdf_analyzer = Agent(
 
699
  #---------------------------- Board Plan Generator ----------------------------#
700
 
701
  class BoardPlanAnalyzer:
702
+ def __init__(self):
703
+ api_key = os.environ.get('GEMINI_API_KEY')
 
 
 
 
 
 
704
  if not api_key:
705
+ raise ValueError("Gemini API key not found")
706
+ # Create an LLM instance configured for Gemini
707
+ self.llm = LLM(model='gemini/gemini-2.0-flash', api_key=api_key)
 
 
 
708
 
709
  litellm.set_verbose = True
710
+
711
  # Create agents
712
  self.create_agents()
713
 
 
727
  description="Extracts text content from PDF files"
728
  )],
729
  allow_delegation=False,
730
+ llm=self.llm,
731
  verbose=True
732
  )
733
 
 
744
  description="Analyzes case study and creates structured board plan"
745
  )],
746
  allow_delegation=False,
747
+ llm=self.llm,
748
  verbose=True
749
  )
750
 
 
807
 
808
  try:
809
  response = litellm.completion(
810
+ model=self.llm,
811
  messages=messages,
812
  response_format={"type": "json_object"}
813
  )
 
1082
  progress_bar = progress_placeholder.progress(0)
1083
 
1084
  # Select LLM provider
1085
+ llm_provider = "gemini"
1086
 
1087
  # Update progress
1088
  progress_bar.progress(10)
 
1173
  if st.button("Generate Board Plan", key="board_plan_button"):
1174
  try:
1175
  # Select LLM provider
1176
+ llm_provider = "gemini"
1177
 
1178
  # Initialize the board plan analyzer
1179
  analyzer = BoardPlanAnalyzer(llm_provider=llm_provider)