# github_search_visualizations/text_visualization.py
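"""Visualization utilities for the GitHub search demo: TF-IDF based word
clouds and Plotly scatterplots of precomputed UMAP embeddings."""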
from typing import Dict, Any, Optional
from sklearn.feature_extraction.text import TfidfVectorizer
import wordcloud
from pydantic import BaseModel, Field
import numpy as np
from PIL import Image
import plotly.express as px
import pandas as pd
import datasets
class WordCloudExtractor(BaseModel):
max_words: int = 50
wordcloud_params: Dict[str, Any] = Field(default_factory=dict)
tfidf_params: Dict[str, Any] = Field(
default_factory=lambda: {"stop_words": "english"}
)
    def extract_wordcloud_image(self, texts) -> Image.Image:
frequencies = self._extract_frequencies(
texts, self.max_words, tfidf_params=self.tfidf_params
)
wc = wordcloud.WordCloud(**self.wordcloud_params).generate_from_frequencies(
frequencies
)
return wc.to_image()
    @classmethod
    def _extract_frequencies(
        cls, texts, max_words=100, tfidf_params: Optional[dict] = None
    ) -> Dict[str, float]:
        """
        Extract word frequencies from a corpus using TF-IDF vectorization
        and generate word cloud frequencies.

        Args:
            texts: List of text documents
            max_words: Maximum number of words to include
            tfidf_params: Extra keyword arguments passed to TfidfVectorizer

        Returns:
            Dictionary of word frequencies suitable for WordCloud
        """
        # Guard against the mutable-default pitfall
        tfidf_params = tfidf_params if tfidf_params is not None else {}
        # Initialize the TF-IDF vectorizer, capped at max_words terms
        tfidf = TfidfVectorizer(max_features=max_words, **tfidf_params)
# Fit and transform the texts
tfidf_matrix = tfidf.fit_transform(texts)
# Get feature names (words)
feature_names = tfidf.get_feature_names_out()
# Calculate mean TF-IDF scores across documents
mean_tfidf = np.array(tfidf_matrix.mean(axis=0)).flatten()
# Create frequency dictionary
frequencies = dict(zip(feature_names, mean_tfidf))
return frequencies
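

# Example usage of WordCloudExtractor (illustrative sketch; the texts and
# output path below are made up):
#
#     extractor = WordCloudExtractor(
#         max_words=30, wordcloud_params={"width": 800, "height": 400}
#     )
#     image = extractor.extract_wordcloud_image(
#         ["deep learning for image classification", "graph neural networks"]
#     )
#     image.save("wordcloud.png")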
class EmbeddingVisualizer(BaseModel):
display_df: pd.DataFrame
plot_kwargs: Dict[str, Any] = Field(
default_factory=lambda: dict(
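            # Axis ranges tuned to this project's precomputed UMAP layout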
range_x=(3, 16.5),
range_y=(-3, 11),
width=1200,
height=800,
x="x",
y="y",
template="plotly_white",
)
)
def make_embedding_plots(
self, color_col=None, hover_data=["name"], filter_df_fn=None
):
"""
plots Plotly scatterplot of UMAP embeddings
"""
display_df = self.display_df
if filter_df_fn is not None:
display_df = filter_df_fn(display_df)
display_df = display_df.sort_values("representation", ascending=False)
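        # Each subset pairs the "task" points with one family of repository
        # representations, so tasks appear on every plot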
readme_df = display_df[
display_df["representation"].isin(
["readme", "code2doc_generated_readme", "task"]
)
]
raw_df = display_df[
display_df["representation"].isin(
["dependency_signature", "selected_code", "task"]
)
]
dependency_df = display_df[
display_df["representation"].isin(
[
"repository_signature",
"dependency_signature",
"generated_tasks",
"task",
]
)
]
plots = [
self._make_task_and_repos_scatterplot(df, hover_data, color_col)
for df in [readme_df, raw_df, dependency_df]
]
return dict(
zip(
[
"READMEs",
"Basic representations",
"Dependency graph based representations",
],
plots,
)
)
    def _make_task_and_repos_scatterplot(self, df, hover_data, color_col):
        # Copy to avoid mutating a view of display_df (SettingWithCopyWarning)
        df = df.copy()
        # Mark tasks with larger markers and a distinct symbol from repositories
        df["size"] = df["is_task"].apply(lambda x: 0.25 if x else 0.1)
        df["symbol"] = df["is_task"].apply(int)
        combined_fig = px.scatter(
            df,
            hover_name="name",
            hover_data=hover_data,
            color=color_col,
            size="size",
            symbol="symbol",
            color_discrete_sequence=px.colors.qualitative.Set1,
            opacity=0.5,
            **self.plot_kwargs,
        )
        # Reverse trace order so the task points are drawn on top
        combined_fig.data = combined_fig.data[::-1]
        return combined_fig
    def make_task_area_scatterplot(self, n_areas=6):
        display_df = self.display_df
        # All rows share representation == "task", so no further sorting is needed
        displayed_tasks_df = display_df[display_df["representation"] == "task"]
        # Papers with Code metadata maps each task name to a research area
        pwc_tasks_df = datasets.load_dataset(
            "lambdaofgod/pwc_github_search", data_files="paperswithcode_tasks.csv"
        )["train"].to_pandas()
displayed_tasks_df = displayed_tasks_df.merge(
pwc_tasks_df,
left_on="name",
right_on="task",
)
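        # Keep only tasks that fall into the n_areas most frequent areas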
displayed_tasks_df = displayed_tasks_df[
displayed_tasks_df["area"].isin(
displayed_tasks_df["area"].value_counts().head(n_areas).index
)
]
tasks_fig = px.scatter(
displayed_tasks_df,
color="area",
hover_data=["name"],
opacity=0.7,
**self.plot_kwargs,
)
        print(f"Displaying {len(displayed_tasks_df)} tasks")
return tasks_fig
class Config:
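        # pandas.DataFrame is not a type pydantic can validate natively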
arbitrary_types_allowed = True
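

# Example usage of EmbeddingVisualizer (illustrative sketch; assumes a
# precomputed embedding table with the "x", "y", "name", "representation"
# and "is_task" columns used above, loaded from a hypothetical file):
#
#     display_df = pd.read_parquet("umap_embeddings.parquet")
#     visualizer = EmbeddingVisualizer(display_df=display_df)
#     plots = visualizer.make_embedding_plots(color_col="representation")
#     plots["READMEs"].show()
#     visualizer.make_task_area_scatterplot(n_areas=6).show()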