Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,9 @@ import plotly.graph_objects as go
|
|
6 |
from scipy.stats import norm, t
|
7 |
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
|
8 |
import plotly.figure_factory as ff
|
9 |
-
import
|
|
|
|
|
10 |
|
11 |
def sorting(df):
|
12 |
df.index = list(map(float, df.index))
|
@@ -597,6 +599,84 @@ def categorize_responses(df, api_key, prompt=None):
|
|
597 |
|
598 |
return categorized_df
|
599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
600 |
empty_col1, main_col, empty_col2 = st.columns([1.6, 2.8, 1.6])
|
601 |
|
602 |
with main_col:
|
|
|
6 |
from scipy.stats import norm, t
|
7 |
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
|
8 |
import plotly.figure_factory as ff
|
9 |
+
import requests
|
10 |
+
import time
|
11 |
+
|
12 |
|
13 |
def sorting(df):
|
14 |
df.index = list(map(float, df.index))
|
|
|
599 |
|
600 |
return categorized_df
|
601 |
|
602 |
+
def categorize_responses(initial_prompt: str,
|
603 |
+
dataframe: pd.DataFrame,
|
604 |
+
id_column: str,
|
605 |
+
text_column: str,
|
606 |
+
api_key: str,
|
607 |
+
max_retries: int = 3,
|
608 |
+
delay: float = 1.0) -> pd.DataFrame:
|
609 |
+
"""
|
610 |
+
Categorizes survey responses using Deepseek API.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
initial_prompt: Context/problem statement/employer concerns
|
614 |
+
dataframe: DataFrame containing responses
|
615 |
+
id_column: Name of column with unique IDs
|
616 |
+
text_column: Name of column with text responses
|
617 |
+
api_key: Deepseek API key
|
618 |
+
max_retries: Maximum API call retries
|
619 |
+
delay: Delay between retries in seconds
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
DataFrame with added 'category' column
|
623 |
+
"""
|
624 |
+
|
625 |
+
# Validate dataframe structure
|
626 |
+
if id_column not in dataframe.columns or text_column not in dataframe.columns:
|
627 |
+
raise ValueError("Dataframe must contain specified ID and text columns")
|
628 |
+
|
629 |
+
# Create API headers
|
630 |
+
headers = {
|
631 |
+
"Authorization": f"Bearer {api_key}",
|
632 |
+
"Content-Type": "application/json"
|
633 |
+
}
|
634 |
+
|
635 |
+
# Define the API endpoint (verify correct endpoint from Deepseek documentation)
|
636 |
+
api_url = "https://api.deepseek.com/v1/chat/completions"
|
637 |
+
|
638 |
+
def get_category(answer: str) -> str:
|
639 |
+
"""Helper function to get category from API"""
|
640 |
+
messages = [
|
641 |
+
{
|
642 |
+
"role": "system",
|
643 |
+
"content": f"{initial_prompt}\n\nCategorize the following response into one of the appropriate categories."
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"role": "user",
|
647 |
+
"content": answer
|
648 |
+
}
|
649 |
+
]
|
650 |
+
|
651 |
+
payload = {
|
652 |
+
"model": "deepseek-chat", # Verify correct model name
|
653 |
+
"messages": messages,
|
654 |
+
"temperature": 0.2,
|
655 |
+
"max_tokens": 64
|
656 |
+
}
|
657 |
+
|
658 |
+
for attempt in range(max_retries):
|
659 |
+
try:
|
660 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
661 |
+
response.raise_for_status()
|
662 |
+
|
663 |
+
# Parse response - adjust according to actual API response structure
|
664 |
+
result = response.json()['choices'][0]['message']['content'].strip()
|
665 |
+
return result
|
666 |
+
|
667 |
+
except Exception as e:
|
668 |
+
if attempt == max_retries - 1:
|
669 |
+
print(f"Failed after {max_retries} attempts: {str(e)}")
|
670 |
+
return "Error: Categorization failed"
|
671 |
+
time.sleep(delay * (attempt + 1))
|
672 |
+
|
673 |
+
return "Error: Max retries exceeded"
|
674 |
+
|
675 |
+
# Apply categorization to each response
|
676 |
+
dataframe['category'] = dataframe[text_column].apply(get_category)
|
677 |
+
|
678 |
+
return dataframe
|
679 |
+
|
680 |
empty_col1, main_col, empty_col2 = st.columns([1.6, 2.8, 1.6])
|
681 |
|
682 |
with main_col:
|