"""Flask service exposing a /predict endpoint for the fake-review Graphormer model.

Supports two modes via the request JSON: "train" (runs the training
pipelines) and "test" (single-review inference against the latest
checkpoint pulled from the Hugging Face Hub).
"""

# Standard library
import json
from datetime import datetime
from pathlib import Path

# Third-party
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify
from huggingface_hub import hf_hub_download
from loguru import logger
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report, roc_curve,
)
from sklearn.model_selection import train_test_split
from torch_geometric.data import HeteroData

# Project-local
from preprocessing_test import Preprocessor
from src.model import *  # NOTE(review): wildcard import — supplies HeteroGraphormer and create_temporal_edge_features; consider importing them explicitly
from main import start_pipelines

app = Flask(__name__)

# Fallback values substituted for any /predict field that is missing, empty,
# or fails type validation. Keys mirror `expected_types` one-to-one.
default_values = {
    'review_id': 'KU_O5udG6zpxOg-VcAEodg',
    'user_id': 'mh_-eMZ6K5RLWhZyISBhwA',
    'business_id': 'XQfwVwDr-v0ZS3_CbbE5Xw',
    'review_stars': 0,
    'review_useful': 0,
    'review_funny': 0,
    'review_cool': 0,
    'review_text': 'It was a moderate experience',
    'review_date': 1531001351000,  # epoch milliseconds
    'business_name': 'Coffe at LA',  # NOTE(review): 'Coffe' looks like a typo; kept as-is since it is a runtime default
    'address': '1460 LA',
    'city': 'LA',
    'state': 'CA',
    'postal_code': '00000',
    'latitude': 0.0,
    'longitude': 0.0,
    'business_stars': 0.0,
    'business_review_count': 0,
    'is_open': 0,
    # attributes/hours are JSON strings that must parse to dicts.
    'attributes': '{}',
    'categories': 'Restaurants',
    'hours': '{"Monday": "7:0-20:0", "Tuesday": "7:0-20:0", "Wednesday": "7:0-20:0", "Thursday": "7:0-20:0", "Friday": "7:0-21:0", "Saturday": "7:0-21:0", "Sunday": "7:0-21:0"}',
    'user_name': 'default_user',
    'user_review_count': 0,
    'yelping_since': '2023-01-01 00:00:00',
    'user_useful': 0,
    'user_funny': 0,
    'user_cool': 0,
    'elite': '2024,2025',
    'friends': '',
    'fans': 0,
    'average_stars': 0.0,
    'compliment_hot': 0,
    'compliment_more': 0,
    'compliment_profile': 0,
    'compliment_cute': 0,
    'compliment_list': 0,
    'compliment_note': 0,
    'compliment_plain': 0,
    'compliment_cool': 0,
    'compliment_funny': 0,
    'compliment_writer': 0,
    'compliment_photos': 0,
    'checkin_date': '2023-01-01 00:00:00',
    'tip_compliment_count': 0.0,
    'tip_count': 0.0,
}
# Expected types for validation. attributes/hours are special-cased: they
# arrive as JSON strings that must parse to dicts (see _coerce_input_row).
expected_types = {
    'review_id': str,
    'user_id': str,
    'business_id': str,
    'review_stars': int,
    'review_useful': int,
    'review_funny': int,
    'review_cool': int,
    'review_text': str,
    'review_date': int,
    'business_name': str,
    'address': str,
    'city': str,
    'state': str,
    'postal_code': str,
    'latitude': float,
    'longitude': float,
    'business_stars': float,
    'business_review_count': int,
    'is_open': int,
    'attributes': dict,  # Assuming string representation of dict
    'categories': str,
    'hours': dict,  # Assuming string representation of dict
    'user_name': str,
    'user_review_count': int,
    'yelping_since': str,
    'user_useful': int,
    'user_funny': int,
    'user_cool': int,
    'elite': str,
    'friends': str,
    'fans': int,
    'average_stars': float,
    'compliment_hot': int,
    'compliment_more': int,
    'compliment_profile': int,
    'compliment_cute': int,
    'compliment_list': int,
    'compliment_note': int,
    'compliment_plain': int,
    'compliment_cool': int,
    'compliment_funny': int,
    'compliment_writer': int,
    'compliment_photos': int,
    'checkin_date': str,
    'tip_compliment_count': float,
    'tip_count': float,
}


def _coerce_input_row(data):
    """Validate and coerce the raw request JSON into one well-typed row.

    Each column in `expected_types` is read from `data`; missing, empty, or
    invalid values fall back to `default_values`. Returns a tuple
    (row, warnings) where `warnings` lists every field that was replaced.
    """
    row = {}
    warnings = []
    for col, expected_type in expected_types.items():
        value = data.get(col, default_values[col])
        try:
            if value is None or value == "":
                row[col] = default_values[col]
            elif col in ('attributes', 'hours'):
                # Must be a JSON string encoding a dict; kept as the raw
                # string because the Preprocessor parses it itself.
                if not isinstance(value, str):
                    raise ValueError
                if not isinstance(json.loads(value), dict):
                    raise ValueError
                row[col] = value
            else:
                row[col] = expected_type(value)
        except (ValueError, TypeError, json.JSONDecodeError):
            row[col] = default_values[col]
            warnings.append(
                f"Invalid input for '{col}' (expected {expected_type.__name__}), "
                f"using default value: {default_values[col]}"
            )
    # Defensive: serialize any dict defaults before DataFrame construction.
    for col in ('attributes', 'hours'):
        if isinstance(row[col], dict):
            row[col] = json.dumps(row[col])
    return row, warnings


def _load_model(device):
    """Download the latest checkpoint from the HF Hub and return an eval-mode model."""
    REPO_ID = "Askhedi/graphformermodel"
    MODEL_FILENAME = "model_GraphformerModel_latest.pth"
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    model = HeteroGraphormer(hidden_dim=64, output_dim=1, edge_dim=4).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects from a downloaded
    # file; prefer weights_only=True if the installed torch supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


def _build_inference_graph(processed_df, device):
    """Build a single-review HeteroData graph plus the model's auxiliary inputs.

    Returns (graph, spatial_encoding, centrality_encoding, node_type_map,
    time_since_dict, edge_features_dict) for one user -> review -> business
    triple taken from the first (only) row of `processed_df`.
    """
    graph = HeteroData()
    # Feature layout assumed to be [user(14) | business(8) | review(16)]
    # after dropping the id columns — TODO confirm against Preprocessor.
    features = torch.tensor(
        processed_df.drop(columns=['user_id', 'review_id', 'business_id']).values,
        dtype=torch.float, device=device)
    time_since_user = torch.tensor(
        processed_df['time_since_last_review_user'].values,
        dtype=torch.float, device=device)
    time_since_business = torch.tensor(
        processed_df['time_since_last_review_business'].values,
        dtype=torch.float, device=device)

    user_indices = torch.tensor([0], dtype=torch.long, device=device)
    business_indices = torch.tensor([0], dtype=torch.long, device=device)
    review_indices = torch.tensor([0], dtype=torch.long, device=device)

    user_feats = torch.zeros(1, 14, device=device)
    business_feats = torch.zeros(1, 8, device=device)
    review_feats = torch.zeros(1, 16, device=device)
    user_feats[0] = features[0, :14]
    business_feats[0] = features[0, 14:22]
    review_feats[0] = features[0, 22:38]
    graph['user'].x = user_feats
    graph['business'].x = business_feats
    graph['review'].x = review_feats
    graph['user', 'writes', 'review'].edge_index = torch.stack(
        [user_indices, review_indices], dim=0)
    graph['review', 'about', 'business'].edge_index = torch.stack(
        [review_indices, business_indices], dim=0)

    # Structural encodings on the 3-node skeleton: user(0) -> review(2) -> business(1).
    G = nx.DiGraph()
    node_type_map = {0: 'user', 1: 'business', 2: 'review'}
    G.add_nodes_from([0, 1, 2])
    G.add_edge(0, 2)  # user -> review
    G.add_edge(2, 1)  # review -> business
    num_nodes = 3
    spatial_encoding = torch.full((num_nodes, num_nodes), float('inf'), device=device)
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                spatial_encoding[i, j] = 0
            elif nx.has_path(G, i, j):
                spatial_encoding[i, j] = nx.shortest_path_length(G, i, j)
    centrality_encoding = torch.tensor(
        [G.degree(i) for i in range(num_nodes)],
        dtype=torch.float, device=device).view(-1, 1)

    edge_features_dict = {}
    user_writes_edge = graph['user', 'writes', 'review'].edge_index
    review_about_edge = graph['review', 'about', 'business'].edge_index
    # NOTE(review): both node-id arguments below use user_writes_edge[0];
    # user_writes_edge[1] may have been intended for the second one. Harmless
    # for this single-node graph (every index is 0) — confirm against
    # create_temporal_edge_features's signature before changing.
    edge_features_dict[('user', 'writes', 'review')] = create_temporal_edge_features(
        time_since_user[user_writes_edge[0]],
        time_since_user[user_writes_edge[1]],
        user_indices[user_writes_edge[0]],
        user_indices[user_writes_edge[0]]
    )
    edge_features_dict[('review', 'about', 'business')] = create_temporal_edge_features(
        time_since_business[review_about_edge[0]],
        time_since_business[review_about_edge[1]],
        torch.zeros_like(review_about_edge[0]),
        torch.zeros_like(review_about_edge[0])
    )
    time_since_dict = {
        'user': torch.tensor([time_since_user[0]], dtype=torch.float, device=device)
    }
    return (graph, spatial_encoding, centrality_encoding, node_type_map,
            time_since_dict, edge_features_dict)


@app.route('/predict', methods=['POST'])
def predict():
    """POST /predict — run training pipelines or single-review inference.

    Request JSON:
      train      - truthy (True/'true'/'True'): run start_pipelines(train_size).
      test       - truthy: run inference on the review fields in the payload.
      train_size - float fraction for training (default 0.1).

    Returns JSON with either a training confirmation, a prediction
    ('Fake'/'Not Fake' plus probability and validation warnings), or an error.
    """
    try:
        if not request.json:
            return jsonify({'error': 'Request must contain JSON data'}), 400
        data = request.json

        train = data.get('train', False)
        test = data.get('test', False)
        train_size = float(data.get('train_size', 0.1))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if train in (True, 'true', 'True'):
            start_pipelines(train_size=train_size)
            logger.info("PIPELINES FINISHED SUCCESSFULLY")
            return jsonify({
                'message': 'Training pipelines executed successfully',
                'train_size': train_size
            }), 200

        # BUG FIX: the original check was `test in (True, 'test', 'True')`,
        # rejecting the string 'true' even though the train flag accepts it.
        # 'true' is now accepted; 'test' is kept for backward compatibility.
        elif test in (True, 'true', 'True', 'test'):
            model = _load_model(device)

            row, warnings = _coerce_input_row(data)
            input_df = pd.DataFrame([row])
            processed_df = Preprocessor(input_df).run_pipeline()
            logger.info(f"PREPROCESSING COMPLETED VALUES ARE {processed_df}")

            (graph, spatial_encoding, centrality_encoding, node_type_map,
             time_since_dict, edge_features_dict) = _build_inference_graph(
                processed_df, device)

            with torch.no_grad():
                out = model(graph, spatial_encoding, centrality_encoding,
                            node_type_map, time_since_dict, edge_features_dict)
            prob = out.squeeze().item()
            pred_label = 1 if prob > 0.5 else 0

            return jsonify({
                'warnings': warnings,
                'prediction': 'Fake' if pred_label == 1 else 'Not Fake',
                'probability': float(prob)
            }), 200

        else:
            return jsonify({
                'error': 'Either "train" or "test" must be set to true'
            }), 400

    except Exception as e:
        # Top-level service boundary: log the traceback, report to the client.
        logger.exception("predict() failed")
        return jsonify({'error': str(e)}), 500