Spaces:
Configuration error
Configuration error
import json
import os

import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Sentence-embedding model shared by the whole module. It is instantiated once
# at import time and used to encode every user query.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def get_n_weighted_scores(embeddings, query, n, objective_weight, subjective_weight):
    """Rank entries by a weighted blend of objective and subjective similarity.

    The query is encoded with the module-level ``model`` and compared, via
    cosine similarity, against each entry's objective embedding and each of
    its subjective embeddings. An entry's score is the largest value of
    ``objective * objective_weight + subjective * subjective_weight`` over its
    subjective embeddings.

    Args:
        embeddings: Mapping of key -> dict holding ``'objective_embedding'``
            and ``'subjective_embeddings'``.
            NOTE(review): both are passed straight to ``cosine_similarity``,
            so they are presumably 2-D (one row for the objective embedding,
            one row per review for the subjective ones) -- confirm against
            the data file.
        query: Free-text query to encode.
        n: Maximum number of results to return.
        objective_weight: Multiplier applied to the objective similarity.
        subjective_weight: Multiplier applied to each subjective similarity.

    Returns:
        A list of at most ``n`` tuples ``(key, best_score, best_index)``,
        sorted by ``best_score`` in descending order, where ``best_index`` is
        the position of the subjective embedding that produced the score.
    """
    query_embedding = [model.encode(query)]
    weighted_scores = []
    for key, value in embeddings.items():
        objective_score = cosine_similarity(
            query_embedding, value['objective_embedding']
        ).item()
        subjective_scores = cosine_similarity(
            query_embedding, value['subjective_embeddings']
        )[0].tolist()

        # The objective contribution is the same for every review of an entry.
        objective_part = objective_score * objective_weight
        blended = [
            objective_part + (score * subjective_weight)
            for score in subjective_scores
        ]

        # Take the true maximum (first one wins on ties). Cosine similarity
        # can be negative, so seeding the search with 0 would clamp the score
        # of such entries to 0 and mis-rank them.
        best_index = max(range(len(blended)), key=blended.__getitem__)
        weighted_scores.append((key, blended[best_index], best_index))

    return sorted(weighted_scores, key=lambda item: item[1], reverse=True)[:n]
def filter_anime(embeddings, genres, themes, rating):
    """Return the entries of ``embeddings`` that satisfy all three filters.

    A filter is satisfied when the selection contains the literal ``'ALL'``
    or shares at least one value with the entry's corresponding field. An
    empty selection therefore matches nothing.

    Args:
        embeddings: Mapping of key -> entry dict with ``'genres'`` and
            ``'themes'`` (iterables of values) and ``'rating'`` (single value).
        genres: Selected genres.
        themes: Selected themes.
        rating: Selected ratings.

    Returns:
        A new dict with the matching entries in their original order; the
        input mapping is not modified.
    """
    genres = set(genres)
    themes = set(themes)
    rating = set(rating)

    def _passes(selected, available):
        # 'ALL' disables the filter; otherwise at least one value must overlap.
        return 'ALL' in selected or not selected.isdisjoint(available)

    return {
        key: anime
        for key, anime in embeddings.items()
        if _passes(genres, anime['genres'])
        and _passes(themes, anime['themes'])
        and _passes(rating, [anime['rating']])
    }
def get_recommendation(query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight):
    """Build the component updates for the recommendation result slots.

    The module-level ``embeddings`` are filtered with ``filter_anime``, ranked
    with ``get_n_weighted_scores``, and each hit is rendered as an image, a
    synopsis textbox and a textbox with its best-matching review. Remaining
    slots are filled with hidden placeholders.

    Args:
        query: Free-text description of what the user is looking for.
        number_of_recommendations: Requested number of results.
        genres: Selected genres (``'ALL'`` disables the filter).
        themes: Selected themes (``'ALL'`` disables the filter).
        rating: Selected ratings (``'ALL'`` disables the filter).
        objective_weight: Weight of the objective similarity.
        subjective_weight: Weight of the subjective (review) similarity.

    Returns:
        A flat list of ``3 * max_slots`` Gradio components, three per slot
        (image, synopsis, review), visible slots first.
    """
    # Must match the number of result rows created in the UI (10 rows of
    # image + two textboxes).
    max_slots = 10

    # Slider values may arrive as floats; slicing and range() need an int.
    requested = max(0, min(int(number_of_recommendations), max_slots))

    filtered_anime = filter_anime(embeddings, genres, themes, rating)
    weighted_scores = get_n_weighted_scores(
        filtered_anime,
        query,
        requested,
        float(objective_weight),
        float(subjective_weight),
    )

    results = []
    for idx, (key, _score, review_index) in enumerate(weighted_scores, start=1):
        data = embeddings[key]
        # Fall back to the Japanese title when no English title is available.
        name = data['english'] or data['japanese']
        description = data['description']
        review = data['reviews'][review_index]['text']
        image = data['image']
        results.append(gr.Image(label=f"Recommendation {idx}: {name}", value=image, height=435, width=500, visible=True))
        results.append(gr.Textbox(label="Synopsis", value=description, max_lines=7, visible=True))
        results.append(gr.Textbox(label="Most Relevant User Review", value=review, max_lines=7, visible=True))

    # Pad by the number of results actually produced, not the number
    # requested: filtering can leave fewer matches than requested, and the
    # event handler must still return one value per output component.
    for _ in range(max_slots - len(weighted_scores)):
        results.append(gr.Image(visible=False))
        results.append(gr.Textbox(visible=False))
        results.append(gr.Textbox(visible=False))
    return results
if __name__ == '__main__':
    # Script entry point: load the precomputed data, build the UI, launch it.
    # NOTE(review): block nesting below was reconstructed from a flattened
    # copy of the source -- confirm the layout against the running app.

    # Read explicitly as UTF-8: the data contains Japanese titles, and the
    # platform default encoding is not guaranteed to handle them.
    with open('./embeddings/data.json', encoding='utf-8') as f:
        data = json.load(f)
    embeddings = data['embeddings']
    filters = data['filters']

    with gr.Blocks(theme=gr.themes.Soft(primary_hue='red')) as demo:
        # Header row; the empty second column keeps the title half-width.
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    '''
                    # Welcome to the Nuanced Recommendation System!
                    ### This system **combines** both objective (synopsis, episode count, themes) and subjective (user reviews) data, in order to recommend the most appropriate anime. Feel free to refine using the **optional** filters below!
                    '''
                )
            with gr.Column():
                pass

        with gr.Row():
            # Left column: query, filters, weights and example inputs.
            with gr.Column() as input_col:
                query = gr.Textbox(label="What are you looking for?")
                number_of_recommendations = gr.Slider(label="# of Recommendations", minimum=1, maximum=10, value=3, step=1)
                genres = gr.Dropdown(label='Genres', multiselect=True, choices=filters['genres'], value=['ALL'])
                themes = gr.Dropdown(label='Themes', multiselect=True, choices=filters['themes'], value=['ALL'])
                rating = gr.Dropdown(label='Rating', multiselect=True, choices=filters['rating'], value=['ALL'])
                objective_weight = gr.Slider(label="Objective Weight", minimum=0, maximum=1, value=.5, step=.1)
                subjective_weight = gr.Slider(label="Subjective Weight", minimum=0, maximum=1, value=.5, step=.1)
                submit_btn = gr.Button("Submit")
                examples = gr.Examples(
                    examples=[
                        ['A sci-fi anime set in a future where AI and robots have become self-aware', 3, ['Action', 'Sci-Fi', 'Fantasy'], ['ALL'], ['PG-13 - Teens 13 or older'], .8, .2],
                        ['An anime where a group of students form a band, and the story focuses on their personal growth and struggles with adulthood', 5, ['ALL'], ['Music'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .3, .7],
                        ['An anime where the main character starts as a villain but slowly redeems themselves', 3, ['Suspense', 'Action'], ['ALL'], ['PG-13 - Teens 13 or older', 'R - 17+ (violence & profanity)'], .2, .8],
                    ],
                    inputs=[query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
                )

            # Right column: ten hidden result slots (image, synopsis, review)
            # that get_recommendation fills in and makes visible.
            outputs = []
            with gr.Column():
                for _ in range(10):
                    with gr.Row():
                        with gr.Column():
                            outputs.append(gr.Image(height=435, width=500, visible=False))
                        with gr.Column():
                            outputs.append(gr.Textbox(max_lines=7, visible=False))
                            outputs.append(gr.Textbox(max_lines=7, visible=False))

        submit_btn.click(
            get_recommendation,
            [query, number_of_recommendations, genres, themes, rating, objective_weight, subjective_weight],
            outputs
        )

    demo.launch()