|
from typing import Dict, Any |
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
import wordcloud |
|
from pydantic import BaseModel, Field |
|
import numpy as np |
|
import PIL |
|
import plotly.express as px |
|
import pandas as pd |
|
import datasets |
|
|
|
|
|
class WordCloudExtractor(BaseModel):
    """Render word-cloud images from a text corpus using TF-IDF term weights.

    Fields:
        max_words: Passed to ``TfidfVectorizer`` as ``max_features``, i.e. the
            maximum number of terms that end up in the cloud.
        wordcloud_params: Keyword arguments forwarded to ``wordcloud.WordCloud``.
        tfidf_params: Extra keyword arguments forwarded to ``TfidfVectorizer``.
            Defaults to ``{"stop_words": "english"}``.
    """

    max_words: int = 50
    wordcloud_params: Dict[str, Any] = Field(default_factory=dict)
    tfidf_params: Dict[str, Any] = Field(
        default_factory=lambda: {"stop_words": "english"}
    )

    def extract_wordcloud_image(self, texts) -> PIL.Image.Image:
        """Build a word-cloud image for ``texts``.

        Args:
            texts: Iterable of text documents, handed to
                ``TfidfVectorizer.fit_transform``.

        Returns:
            The image produced by ``WordCloud.to_image()``.
        """
        frequencies = self._extract_frequencies(
            texts, self.max_words, tfidf_params=self.tfidf_params
        )
        cloud = wordcloud.WordCloud(**self.wordcloud_params)
        return cloud.generate_from_frequencies(frequencies).to_image()

    @classmethod
    def _extract_frequencies(
        cls, texts, max_words=100, tfidf_params: "Dict[str, Any] | None" = None
    ) -> Dict[str, float]:
        """Compute per-term weights for a corpus via TF-IDF vectorization.

        Args:
            texts: Iterable of text documents.
            max_words: Maximum number of terms to keep; passed to
                ``TfidfVectorizer`` as ``max_features``.
            tfidf_params: Extra keyword arguments for ``TfidfVectorizer``.
                ``None`` (the default) means no extra arguments.

        Returns:
            Mapping from each retained term to its TF-IDF score averaged over
            all documents, in the form ``WordCloud.generate_from_frequencies``
            accepts.
        """
        # A ``None`` default replaces the former shared mutable ``{}`` default.
        extra_params = {} if tfidf_params is None else tfidf_params

        vectorizer = TfidfVectorizer(max_features=max_words, **extra_params)
        tfidf_matrix = vectorizer.fit_transform(texts)
        terms = vectorizer.get_feature_names_out()

        # mean(axis=0) averages over documents and yields a 1 x n_terms matrix;
        # flatten it to a plain 1-D array so it zips with the term names.
        mean_scores = np.asarray(tfidf_matrix.mean(axis=0)).ravel()

        return dict(zip(terms, mean_scores))
|
|
|
|
|
class EmbeddingVisualizer(BaseModel):
    """Plotly scatterplots of 2-D embeddings stored in a DataFrame.

    Fields:
        display_df: Frame to plot. The methods below read the columns
            ``representation``, ``is_task`` and ``name``, plus whatever
            columns ``plot_kwargs`` refers to (``x`` and ``y`` by default).
        plot_kwargs: Keyword arguments forwarded to every ``px.scatter`` call
            (axis ranges, figure size, coordinate columns, template).
    """

    display_df: pd.DataFrame
    plot_kwargs: Dict[str, Any] = Field(
        default_factory=lambda: dict(
            range_x=(3, 16.5),
            range_y=(-3, 11),
            width=1200,
            height=800,
            x="x",
            y="y",
            template="plotly_white",
        )
    )

    def make_embedding_plots(
        self, color_col=None, hover_data=("name",), filter_df_fn=None
    ):
        """Plot Plotly scatterplots of UMAP embeddings, one per representation group.

        Args:
            color_col: Column used for point colour (``px.scatter``'s ``color``).
            hover_data: Columns shown on hover (``px.scatter``'s ``hover_data``).
                Defaults to the ``name`` column.
            filter_df_fn: Optional callable applied to ``display_df`` before
                plotting; it must return the DataFrame to plot.

        Returns:
            Dict mapping a group title to its Plotly figure, in the order
            "READMEs", "Basic representations",
            "Dependency graph based representations".
        """
        # The default is an immutable tuple (no shared mutable default);
        # hand plotly a fresh list, as the original default did.
        if isinstance(hover_data, tuple):
            hover_data = list(hover_data)

        display_df = self.display_df
        if filter_df_fn is not None:
            display_df = filter_df_fn(display_df)
        display_df = display_df.sort_values("representation", ascending=False)

        # Every group also contains the "task" rows, so tasks appear in each plot.
        representation_groups = {
            "READMEs": ["readme", "code2doc_generated_readme", "task"],
            "Basic representations": [
                "dependency_signature",
                "selected_code",
                "task",
            ],
            "Dependency graph based representations": [
                "repository_signature",
                "dependency_signature",
                "generated_tasks",
                "task",
            ],
        }

        return {
            title: self._make_task_and_repos_scatterplot(
                display_df[display_df["representation"].isin(representations)],
                hover_data,
                color_col,
            )
            for title, representations in representation_groups.items()
        }

    def _make_task_and_repos_scatterplot(self, df, hover_data, color_col):
        """Scatter-plot ``df`` with ``size`` and ``symbol`` helper columns added.

        Args:
            df: Frame to plot; must contain ``is_task`` and ``name``.
            hover_data: Forwarded to ``px.scatter`` as ``hover_data``.
            color_col: Forwarded to ``px.scatter`` as ``color``.

        Returns:
            The Plotly figure, with its trace order reversed.
        """
        # assign() returns a new frame, so the caller's frame (typically a
        # boolean-indexed slice) is not mutated and pandas emits no
        # SettingWithCopyWarning. The helper columns are not passed to
        # px.scatter here, but stay available to plot_kwargs / hover_data.
        df = df.assign(
            size=df["is_task"].apply(lambda x: 0.25 if x else 0.1),
            symbol=df["is_task"].apply(int),
        )

        combined_fig = px.scatter(
            df,
            hover_name="name",
            hover_data=hover_data,
            color=color_col,
            color_discrete_sequence=px.colors.qualitative.Set1,
            opacity=0.5,
            **self.plot_kwargs,
        )
        # Reverse the order of the figure's traces.
        combined_fig.data = combined_fig.data[::-1]

        return combined_fig

    def make_task_area_scatterplot(self, n_areas=6):
        """Scatter-plot task rows coloured by area, for the most frequent areas.

        Loads ``paperswithcode_tasks.csv`` from the
        ``lambdaofgod/pwc_github_search`` dataset via ``datasets.load_dataset``
        and joins it to the task rows (``name`` against the dataset's ``task``
        column). Also prints the number of plotted rows.

        Args:
            n_areas: Number of most frequent ``area`` values to keep.

        Returns:
            The Plotly figure.
        """
        display_df = self.display_df
        displayed_tasks_df = display_df[
            display_df["representation"] == "task"
        ].sort_values("representation")

        pwc_tasks_df = datasets.load_dataset(
            "lambdaofgod/pwc_github_search", data_files="paperswithcode_tasks.csv"
        )["train"].to_pandas()

        displayed_tasks_df = displayed_tasks_df.merge(
            pwc_tasks_df,
            left_on="name",
            right_on="task",
        )

        # Keep only rows whose area is among the n_areas most common ones.
        top_areas = displayed_tasks_df["area"].value_counts().head(n_areas).index
        displayed_tasks_df = displayed_tasks_df[
            displayed_tasks_df["area"].isin(top_areas)
        ]

        tasks_fig = px.scatter(
            displayed_tasks_df,
            color="area",
            hover_data=["name"],
            opacity=0.7,
            **self.plot_kwargs,
        )
        print("N DISPLAYED TASKS", len(displayed_tasks_df))
        return tasks_fig

    class Config:
        # display_df is a pandas DataFrame, which pydantic cannot validate
        # natively.
        arbitrary_types_allowed = True
|
|