from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Allow requests from any origin. Note that browsers will not send
# credentials to a wildcard origin, so allow_credentials=True only takes
# effect for explicitly listed origins.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def greet_json():
    return {"Hello": "World!"}


#--------------------------------------------------------------------------------------------------------------------

import os
import gdown


def download_from_drive(file_id, output_file, output_dir="./downloads"):
    """Download a Google Drive file into output_dir and return its local path."""
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        gdown.download(url, output_path, quiet=False)
        print(f"File downloaded successfully to: {output_path}")
    except Exception as e:
        print(f"Error downloading file: {e}")
    return output_path


# Checkpoint used as the CNN model (loaded below as downloads/file.h5)
file_path = download_from_drive("1zhisRgRi2qBFX73VFhzh-Ho93MORQqVa", "file.h5")


#--------------------------------------------------------------------------------------------------------------------

file_id = "1wIaycDFGTF3e0PpAHKk-GLnxk4cMehOU" 
output_dir = "./downloads"  
output_file = "file2.h5"  

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

output_path = os.path.join(output_dir, output_file)

url = f"https://drive.google.com/uc?id={file_id}"

try:
    gdown.download(url, output_path, quiet=False)
    print(f"File downloaded successfully to: {output_path}")
except Exception as e:
    print(f"Error downloading file: {e}")

output_file = "file2.h5"  
file_path = os.path.join(output_dir, output_file)


if os.path.exists(file_path):
    print(f"The file '{output_file}' exists at '{file_path}'.")
else:
    print(f"The file '{output_file}' does not exist at '{file_path}'.")

#--------------------------------------------------------------------------------------------------------------------
import numpy as np

# Cache/config locations must be set (and exist) before the libraries that
# read them are imported (tensorflow, matplotlib, transformers, torch).
os.environ["TORCH_HOME"] = "/tmp/torch_cache"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
os.environ["FONTCONFIG_PATH"] = "/tmp/fontconfig"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.makedirs("/tmp/matplotlib_config", exist_ok=True)
os.makedirs("/tmp/fontconfig", exist_ok=True)
os.makedirs("/tmp/huggingface_cache", exist_ok=True)

import tensorflow as tf
import librosa
import librosa.display  # needed for librosa.display.specshow below
import matplotlib.pyplot as plt
# import gradio as gr

from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from transformers import pipeline

class UnifiedDeepfakeDetector:
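    """Ensemble of four audio deepfake detectors: a frozen-VGG16 spectrogram
    classifier, the two downloaded .h5 models (dense MFCC and CNN MFCC), and
    the MelodyMachine Hugging Face pipeline."""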
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.vgg_model = self.build_vgg16_model()
        self.dense_model = tf.keras.models.load_model('downloads/file2.h5')
        self.cnn_model = tf.keras.models.load_model('downloads/file.h5')
        self.melody_machine = pipeline(model="MelodyMachine/Deepfake-audio-detection-V2")

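    # VGG16 base frozen on ImageNet weights with a new sigmoid head; note the
    # head is randomly initialised here rather than loaded from a checkpoint.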
    def build_vgg16_model(self):
        base_model = VGG16(weights='imagenet', include_top=False, input_shape=self.input_shape)
        for layer in base_model.layers:
            layer.trainable = False

        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(512, activation='relu')(x)
        x = Dropout(0.5)(x)
        x = Dense(256, activation='relu')(x)
        x = Dropout(0.3)(x)
        output = Dense(1, activation='sigmoid')(x)

        model = Model(inputs=base_model.input, outputs=output)
        model.compile(optimizer=Adam(learning_rate=0.0001),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
        return model

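    # Dual-purpose: with plot=True, save and return a spectrogram PNG path;
    # otherwise return a 224x224x3 array preprocessed for VGG16 inference.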
    def audio_to_spectrogram(self, file_path, plot=False):
        try:
            audio, sr = librosa.load(file_path, duration=5.0, sr=22050)
            spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=224, fmax=8000)
            spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)

            if plot:
                plt.figure(figsize=(12, 6))
                librosa.display.specshow(spectrogram_db, y_axis='mel', x_axis='time', cmap='viridis')
                plt.colorbar(format='%+2.0f dB')
                plt.title('Mel Spectrogram Analysis')
                plot_path = 'spectrogram_plot.png'
                plt.savefig(plot_path, dpi=300, bbox_inches='tight')
                plt.close()
                return plot_path

            # Min-max normalise and stack to three channels for the VGG16 input;
            # the epsilon guards against constant (e.g. silent) spectrograms.
            spectrogram_norm = (spectrogram_db - spectrogram_db.min()) / (spectrogram_db.max() - spectrogram_db.min() + 1e-10)
            spectrogram_rgb = np.stack([spectrogram_norm] * 3, axis=-1)
            spectrogram_resized = tf.image.resize(spectrogram_rgb, (224, 224))
            return preprocess_input(spectrogram_resized * 255)

        except Exception as e:
            print(f"Spectrogram error: {e}")
            return None

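    # Run the selected model(s) on one clip; per-model scores are collected
    # in `results`, but only the list of REAL/FAKE votes `r` is returned.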
    def analyze_audio_rf(self, audio_path, model_choice="all"):
        results = {}
        plots = {}
        r = []
        audio_features = {}

        try:
            # Load audio and extract basic features
            audio, sr = librosa.load(audio_path, res_type="kaiser_fast")
            audio_features = {
                "sample_rate": sr,
                "duration": librosa.get_duration(y=audio, sr=sr),
                "rms_energy": float(np.mean(librosa.feature.rms(y=audio))),
                "zero_crossing_rate": float(np.mean(librosa.feature.zero_crossing_rate(y=audio)))
            }

            # VGG16 Analysis
            if model_choice in ["VGG16", "all"]:
                spec = self.audio_to_spectrogram(audio_path)
                if spec is not None:
                    pred = self.vgg_model.predict(np.expand_dims(spec, axis=0))[0][0]
                    results["VGG16"] = {
                        "prediction": "FAKE" if pred > 0.5 else "REAL",
                        "confidence": float(pred if pred > 0.5 else 1 - pred),
                        "raw_score": float(pred)
                    }
                    plots["spectrogram"] = self.audio_to_spectrogram(audio_path, plot=True)
                    r.append("FAKE" if pred > 0.5 else "REAL")

            # Dense Model Analysis
            if model_choice in ["Dense", "all"]:
                mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
                mfcc_scaled = np.mean(mfcc.T, axis=0).reshape(1, -1)
                pred = self.dense_model.predict(mfcc_scaled)
                results["Dense"] = {
                    "prediction": "FAKE" if np.argmax(pred[0]) == 0 else "REAL",
                    "confidence": float(np.max(pred[0])),
                    "raw_scores": pred[0].tolist()
                }
                r.append("FAKE" if np.argmax(pred[0]) == 0 else "REAL")

            # CNN Model Analysis
            if model_choice in ["CNN", "all"]:
                mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
                # shape the 40 MFCC means into a single (40, 1, 1) sample for the CNN
                mfcc_scaled = np.mean(mfcc.T, axis=0).reshape(1, 40, 1, 1)
                pred = self.cnn_model.predict(mfcc_scaled)
                results["CNN"] = {
                    "prediction": "FAKE" if np.argmax(pred[0]) == 0 else "REAL",
                    "confidence": float(np.max(pred[0])),
                    "raw_scores": pred[0].tolist()
                }
                r.append("FAKE" if np.argmax(pred[0]) == 0 else "REAL")

            # Melody Machine Analysis
            if model_choice in ["MelodyMachine", "all"]:
                result = self.melody_machine(audio_path)
                best_pred = max(result, key=lambda x: x['score'])
                results["MelodyMachine"] = {
                    "prediction": best_pred['label'].upper(),
                    "confidence": float(best_pred['score']),
                    "all_predictions": result
                }
                r.append(best_pred['label'].upper())

            return r

        except Exception as e:
            print(f"Analysis error: {e}")
            # Return an empty vote list so callers can still build a tally.
            return []

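# Usage sketch (hypothetical clip path):
#   detector = UnifiedDeepfakeDetector()
#   votes = detector.analyze_audio_rf("clip.wav", model_choice="all")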
#--------------------------------------------------------------------------------------------------------------------

import torch
import torch.nn.functional as F
import torchaudio
from scipy.stats import median_abs_deviation


# Wav2Vec2 backbone for the entropy-based check below; the weights are
# downloaded into TORCH_HOME on first use.
from torchaudio.pipelines import WAV2VEC2_BASE
bundle = WAV2VEC2_BASE

model = bundle.get_model()
model.eval()  # inference only
print("Model downloaded successfully!")


def extract_features(file_path):
    """Pool every Wav2Vec2 layer over time into one normalised feature vector."""
    if os.path.exists(file_path):
        print(f"File successfully written: {file_path}")
    else:
        print("File writing failed.")
    waveform, sample_rate = torchaudio.load(file_path)
    # Wav2Vec2 expects the bundle's sample rate (16 kHz for WAV2VEC2_BASE)
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=bundle.sample_rate)(waveform)

    with torch.inference_mode():
        features, _ = model.extract_features(waveform)

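    # Each WAV2VEC2_BASE transformer layer yields (batch, time, 768) features;
    # average-pool each over time, then concatenate and z-score the result.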
    pooled_features = []
    for f in features:
        if f.dim() == 3:
            f = f.permute(0, 2, 1)
            pooled_f = F.adaptive_avg_pool1d(f[0].unsqueeze(0), 1).squeeze(0)
            pooled_features.append(pooled_f)

    final_features = torch.cat(pooled_features, dim=0).numpy()
    final_features = (final_features - np.mean(final_features)) / (np.std(final_features) + 1e-10)

    return final_features

def additional_features(features):
    """Return the median absolute deviation and an entropy-style score."""
    mad = median_abs_deviation(features)
    # Clip to positive values so the log is defined; the z-scored features are
    # not a probability distribution, so this entropy is a heuristic score.
    features_clipped = np.clip(features, 1e-10, None)
    entropy = -np.sum(features_clipped * np.log(features_clipped))
    return mad, entropy

def classify_audio(features):
    # Entropy above the empirical threshold of 150 is treated as genuine
    # (REAL) audio; the same threshold drives the vote in /upload_audio.
    _, entropy = additional_features(features)
    print(entropy)

    if entropy > 150:
        return True, entropy
    else:
        return False, entropy
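
# Minimal usage sketch (assuming a local 16 kHz WAV named sample.wav):
#   feats = extract_features("sample.wav")
#   is_real, entropy = classify_audio(feats)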

#--------------------------------------------------------------------------------------------------------------------
from fastapi import File, UploadFile
from fastapi.responses import JSONResponse
import shutil
import subprocess


os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["FONTCONFIG_PATH"] = "/tmp/fontconfig"
os.environ["TF_ENABLE_ONEDNN_OPTS"]="0"
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

os.makedirs("/tmp/matplotlib", exist_ok=True)
os.makedirs("/tmp/fontconfig", exist_ok=True)
os.makedirs("/tmp/huggingface_cache", exist_ok=True)

SAVE_DIR = './audio'
os.makedirs(SAVE_DIR, exist_ok=True)

# ffmpeg is required by reencode_audio below; this assumes a Debian-based
# environment with root access (e.g. a container image or a Space).
os.system('apt-get update && apt-get install -y ffmpeg')


def reencode_audio(input_path, output_path):
    """Re-encode to 16 kHz mono 16-bit PCM WAV, the format the models expect."""
    command = [
        '/usr/bin/ffmpeg', '-i', input_path,
        '-acodec', 'pcm_s16le', '-ar', '16000', '-ac', '1',
        output_path,
    ]
    subprocess.run(command, check=True)
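
# e.g. (hypothetical paths): reencode_audio("./audio/raw.wav", "./audio/raw_16k.wav")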

#--------------------------------------------------------------------------------------------------------------------
from collections import Counter
from datetime import datetime
import base64

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
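    """Entropy-only endpoint: re-encode the upload, extract Wav2Vec2 features,
    and return the entropy score with a boolean prediction (True is treated
    as genuine audio)."""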
    print(f"Received file: {file.filename}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    wav_filename = os.path.join(SAVE_DIR, f"{timestamp}.wav")
    reencoded_filename = os.path.join(SAVE_DIR, f"{timestamp}_reencoded.wav")

    # os.makedirs(SAVE_DIR, exist_ok=True)
    with open(wav_filename, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    reencode_audio(wav_filename, reencoded_filename)
    os.remove(wav_filename)
    print(f"File successfully re-encoded as: {reencoded_filename}")

    # Sanity check: confirm the re-encoded file is decodable before running
    # feature extraction.
    try:
        audio, sr = librosa.load(reencoded_filename, sr=None)
        print("Loaded successfully with librosa")
    except Exception as e:
        print(f"Error loading re-encoded file: {e}")
    new_features = extract_features(reencoded_filename)
    prediction, entropy = classify_audio(new_features)
    os.remove(reencoded_filename)
    return JSONResponse(content={
        "prediction": bool(prediction),
        "entropy": float(entropy),
    })
    

@app.post("/upload_audio")
async def upload_file(file: UploadFile = File(...)):
    print(f"Received file: {file.filename}")

    original_filename = file.filename.rsplit('.', 1)[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    wav_filename = os.path.join(SAVE_DIR, f"{timestamp}.wav")
    reencoded_filename = os.path.join(SAVE_DIR, f"{timestamp}_reencoded.wav")

    # os.makedirs(SAVE_DIR, exist_ok=True)
    with open(wav_filename, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    reencode_audio(wav_filename, reencoded_filename)
    
    os.remove(wav_filename)
    print(f"File successfully re-encoded as: {reencoded_filename}")

    # Sanity check: confirm the re-encoded file is decodable before running
    # feature extraction.
    try:
        audio, sr = librosa.load(reencoded_filename, sr=None)
        print("Loaded successfully with librosa")
    except Exception as e:
        print(f"Error loading re-encoded file: {e}")
    new_features = extract_features(reencoded_filename)
    # NOTE: this reloads all four models on every request; constructing the
    # detector once at startup would avoid the repeated cost.
    detector = UnifiedDeepfakeDetector()
    print(reencoded_filename)
    result = detector.analyze_audio_rf(reencoded_filename, model_choice="all")
    prediction, entropy = classify_audio(new_features)
    with open(reencoded_filename, "rb") as audio_file:
        audio_data = audio_file.read()
    # Add the entropy check as one more vote, then take the majority label
    # across all detectors.
    result = list(result)
    result.append("FAKE" if float(entropy) < 150 else "REAL")
    print(result)
    r_normalized = [x.upper() for x in result if x is not None]
    counter = Counter(r_normalized)

    most_common_element, _ = counter.most_common(1)[0]

    print(f"The most frequent element is: {most_common_element}") 
    

    audio_base64 = base64.b64encode(audio_data).decode('utf-8')
    print(f"Audio Data Length: {len(audio_data)}")

    os.remove(reencoded_filename)
    return JSONResponse(content={
        "filename": file.filename,
        "prediction": most_common_element.upper(),
        "entropy": float(entropy),
        "audio": audio_base64,
        "content_type": "audio/wav"
    })
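
# Client sketch (assumes the app runs locally on port 8000 and a local
# clip.wav exists):
#   curl -X POST -F "file=@clip.wav" http://localhost:8000/upload
#   curl -X POST -F "file=@clip.wav" http://localhost:8000/upload_audio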