AMKhakbaz commited on
Commit
a39a86f
·
verified ·
1 Parent(s): f58f3cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -1
app.py CHANGED
@@ -6,7 +6,9 @@ import plotly.graph_objects as go
6
  from scipy.stats import norm, t
7
  from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
8
  import plotly.figure_factory as ff
9
- import openai
 
 
10
 
11
  def sorting(df):
12
  df.index = list(map(float, df.index))
@@ -597,6 +599,84 @@ def categorize_responses(df, api_key, prompt=None):
597
 
598
  return categorized_df
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  empty_col1, main_col, empty_col2 = st.columns([1.6, 2.8, 1.6])
601
 
602
  with main_col:
 
6
  from scipy.stats import norm, t
7
  from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
8
  import plotly.figure_factory as ff
9
+ import requests
10
+ import time
11
+
12
 
13
  def sorting(df):
14
  df.index = list(map(float, df.index))
 
599
 
600
  return categorized_df
601
 
602
+ def categorize_responses(initial_prompt: str,
603
+ dataframe: pd.DataFrame,
604
+ id_column: str,
605
+ text_column: str,
606
+ api_key: str,
607
+ max_retries: int = 3,
608
+ delay: float = 1.0) -> pd.DataFrame:
609
+ """
610
+ Categorizes survey responses using Deepseek API.
611
+
612
+ Args:
613
+ initial_prompt: Context/problem statement/employer concerns
614
+ dataframe: DataFrame containing responses
615
+ id_column: Name of column with unique IDs
616
+ text_column: Name of column with text responses
617
+ api_key: Deepseek API key
618
+ max_retries: Maximum API call retries
619
+ delay: Delay between retries in seconds
620
+
621
+ Returns:
622
+ DataFrame with added 'category' column
623
+ """
624
+
625
+ # Validate dataframe structure
626
+ if id_column not in dataframe.columns or text_column not in dataframe.columns:
627
+ raise ValueError("Dataframe must contain specified ID and text columns")
628
+
629
+ # Create API headers
630
+ headers = {
631
+ "Authorization": f"Bearer {api_key}",
632
+ "Content-Type": "application/json"
633
+ }
634
+
635
+ # Define the API endpoint (verify correct endpoint from Deepseek documentation)
636
+ api_url = "https://api.deepseek.com/v1/chat/completions"
637
+
638
+ def get_category(answer: str) -> str:
639
+ """Helper function to get category from API"""
640
+ messages = [
641
+ {
642
+ "role": "system",
643
+ "content": f"{initial_prompt}\n\nCategorize the following response into one of the appropriate categories."
644
+ },
645
+ {
646
+ "role": "user",
647
+ "content": answer
648
+ }
649
+ ]
650
+
651
+ payload = {
652
+ "model": "deepseek-chat", # Verify correct model name
653
+ "messages": messages,
654
+ "temperature": 0.2,
655
+ "max_tokens": 64
656
+ }
657
+
658
+ for attempt in range(max_retries):
659
+ try:
660
+ response = requests.post(api_url, headers=headers, json=payload)
661
+ response.raise_for_status()
662
+
663
+ # Parse response - adjust according to actual API response structure
664
+ result = response.json()['choices'][0]['message']['content'].strip()
665
+ return result
666
+
667
+ except Exception as e:
668
+ if attempt == max_retries - 1:
669
+ print(f"Failed after {max_retries} attempts: {str(e)}")
670
+ return "Error: Categorization failed"
671
+ time.sleep(delay * (attempt + 1))
672
+
673
+ return "Error: Max retries exceeded"
674
+
675
+ # Apply categorization to each response
676
+ dataframe['category'] = dataframe[text_column].apply(get_category)
677
+
678
+ return dataframe
679
+
680
  empty_col1, main_col, empty_col2 = st.columns([1.6, 2.8, 1.6])
681
 
682
  with main_col: