# Spaces: Running
import logging
import os
import secrets
from uuid import uuid4

import joblib
import librosa
import librosa.display
import numpy as np
import opensmile
import pandas as pd
import soundfile as sf
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from pydantic import BaseModel
from pymongo import MongoClient
# Load environment variables (dotenv) before any setting is read below.
load_dotenv()

# MongoDB connection settings
MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")
DB_NAME = "quran_db"
COLLECTION_NAME = "tafsir"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain_index"

# NOTE(review): when MONGODB_ATLAS_CLUSTER_URI is unset this is called with
# None; nothing here validates the setting — confirm that is intended.
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

# Embedding model for queries, and the vector store bound to the
# "<database>.<collection>" namespace through the named search index.
embeddings = SentenceTransformerEmbeddings(model_name="BAAI/bge-m3")
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    f"{DB_NAME}.{COLLECTION_NAME}",
    embeddings,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
# FastAPI application setup
app = FastAPI()

# NOTE(review): every origin, method and header is accepted while credentials
# are also enabled — confirm this permissive CORS policy is intended.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
def index_file(filepath):
    """Build a lookup from normalised block text to its 1-based block number.

    The file is read as UTF-8 and split into blocks on blank lines (two
    consecutive newlines). Inside each block, newlines are replaced by single
    spaces and surrounding whitespace is stripped.

    Blocks that are empty after normalisation (for example the one produced
    by a trailing blank line) still consume a block number, so numbering
    stays aligned with the file, but they are not stored: an empty key could
    only ever produce a spurious match.

    If the same normalised text occurs more than once, the last occurrence
    wins.

    Args:
        filepath: Path of the text file to index.

    Returns:
        dict mapping normalised block text to block number.
    """
    with open(filepath, 'r', encoding='utf-8') as file:
        blocks = file.read().split("\n\n")

    index = {}
    # Block numbers start at 1 for human readability.
    for block_number, block in enumerate(blocks, 1):
        formatted_block = block.replace('\n', ' ').strip()
        if formatted_block:
            index[formatted_block] = block_number
    return index
def get_text_by_block_number(filepath, block_numbers):
    """Return the normalised text of the requested blocks of a file.

    Blocks are separated by blank lines (two consecutive newlines) and are
    numbered from 1. Inside each returned block, newlines are replaced by
    single spaces and surrounding whitespace is stripped.

    Args:
        filepath: Path of the UTF-8 text file to read.
        block_numbers: Iterable of 1-based block numbers to fetch. Duplicates
            are collapsed; numbers past the end of the file are ignored.

    Returns:
        list of block texts in file order (not in the order requested).
    """
    # A set gives O(1) membership tests and makes the early-exit count
    # correct even when the caller passes duplicate numbers.
    wanted = set(block_numbers)
    if not wanted:
        # Nothing requested: skip reading the file altogether.
        return []

    with open(filepath, 'r', encoding='utf-8') as file:
        blocks = file.read().split("\n\n")

    blocks_text = []
    for block_number, block in enumerate(blocks, 1):
        if block_number not in wanted:
            continue
        blocks_text.append(block.replace('\n', ' ').strip())
        if len(blocks_text) == len(wanted):
            # Every requested block has been collected.
            break
    return blocks_text
# Existing API endpoints
# NOTE(review): no route decorator is visible on this coroutine here, so as
# written it is not registered on `app` — confirm the intended path/method.
async def read_root():
    """Return the static welcome payload."""
    greeting = {"message": "Welcome to our app"}
    return greeting
# Request body model for the question-answering POST handler
class Item(BaseModel):
    """Payload carrying the question to search for."""

    question: str
# Shared secret that clients must present; read once at import time.
EXPECTED_TOKEN = os.getenv("API_TOKEN")


def verify_token(authorization: str = Header(None)):
    """Dependency that checks the Authorization header for the Bearer token.

    Args:
        authorization: Raw value of the ``Authorization`` request header, or
            None when the header is absent.

    Returns:
        The presented token, so handlers that declare
        ``token: str = Depends(verify_token)`` receive an actual string.

    Raises:
        HTTPException: 401 when the header is missing, does not use the
            Bearer scheme, or carries a token other than ``EXPECTED_TOKEN``.
    """
    prefix = "Bearer "

    # The authentication scheme name is case-insensitive, so compare the
    # prefix without regard to case.
    if not authorization or authorization[:len(prefix)].lower() != prefix.lower():
        raise HTTPException(status_code=401, detail="Unauthorized: Missing or invalid token")

    token = authorization[len(prefix):]

    # Fail closed when API_TOKEN is not configured. Otherwise compare in
    # constant time (on bytes, since compare_digest rejects non-ASCII str)
    # so response timing does not reveal how much of the token matched.
    if EXPECTED_TOKEN is None or not secrets.compare_digest(
        token.encode("utf-8"), EXPECTED_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized: Incorrect token")

    return token
def _find_tafsir_blocks(question):
    """Blocking lookup: vector search, then map the hits back to file blocks.

    The top 3 similarity hits are normalised (newlines to spaces, stripped)
    and looked up in the index of 'app/quran_tafseer_formatted.txt'; the
    block numbers found are then read from 'app/quran_tafseer.txt'. Hits
    whose text is not present in the index are dropped.
    """
    matching_docs = vector_search.similarity_search(question, k=3)
    clean_answers = [doc.page_content.replace("\n", " ").strip() for doc in matching_docs]

    answers_index = index_file('app/quran_tafseer_formatted.txt')
    block_numbers = [answers_index[answer] for answer in clean_answers if answer in answers_index]

    return get_text_by_block_number('app/quran_tafseer.txt', block_numbers)


# New API endpoint to get an answer using the chain
# NOTE(review): no route decorator is visible on this coroutine here, so as
# written it is not registered on `app` — confirm the intended path/method.
async def get_answer(item: Item, token: str = Depends(verify_token)):
    """Answer a question with the matching tafsir blocks.

    Args:
        item: Request body holding the question.
        token: Result of the ``verify_token`` dependency; unused here, it
            exists so the request is rejected before this body runs.

    Returns:
        ``{"result_text": [...]}`` with the retrieved block texts.

    Raises:
        HTTPException: 500 carrying the error text when the lookup fails.
    """
    try:
        # The search and the file reads are synchronous; run them in the
        # worker thread pool so they do not stall the event loop.
        result_text = await run_in_threadpool(_find_tafsir_blocks, item.question)
    except HTTPException:
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Tafsir lookup failed")
        # If there's an error, return a 500 error with the error's details
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result_text": result_text}
# MLP classifier and its fitted preprocessing objects.
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code — these files must only ever come from a trusted source.
mlp_model, mlp_pca, mlp_scaler, mlp_label_encoder = (
    joblib.load(artifact_path)
    for artifact_path in (
        'app/mlp_model.pkl',
        'app/pca.pkl',
        'app/scaler.pkl',
        'app/label_encoder.pkl',
    )
)
def preprocess_audio(path, save_dir):
    """Clean an audio file and overwrite it in place.

    Pipeline, in order:
      1. load the file with librosa's default settings;
      2. keep only the non-silent intervals (top_db=20) and join them;
      3. trim leading/trailing silence (top_db=20);
      4. apply pre-emphasis twice;
      5. trim again with librosa's default threshold;
      6. apply pre-emphasis a third time;
      7. peak-normalise and write the result back to ``path`` at the loaded
         sample rate.

    The sequence of signal operations is kept exactly as it was; only work
    whose result was never used (STFT, mel spectrogram, file-name splitting)
    has been removed. The file-name splitting also raised IndexError for a
    path without a dot.

    Args:
        path: Audio file to read and then overwrite.
        save_dir: Unused; kept so existing callers keep working.

    Returns:
        The string 'success'.
    """
    y, sr = librosa.load(path)

    # Remove silence inside the recording: keep the non-silent intervals only.
    intervals = librosa.effects.split(y, top_db=20)
    y_no_gaps = np.concatenate([y[start:end] for start, end in intervals])
    y_trimmed, _ = librosa.effects.trim(y_no_gaps, top_db=20)

    # NOTE(review): all three filtering steps are the same pre-emphasis
    # filter, applied repeatedly. Presumably the saved model was trained on
    # audio processed this way, so the sequence is left untouched — confirm
    # before simplifying.
    y_emphasized_once = librosa.effects.preemphasis(y_trimmed)
    y_emphasized_twice = librosa.effects.preemphasis(y_emphasized_once)

    # Second silence trim, this time with librosa's default threshold.
    y_silence_removed, _ = librosa.effects.trim(y_emphasized_twice)
    y_emphasized_thrice = librosa.effects.preemphasis(y_silence_removed)

    # Normalize the audio signal
    y_normalized = librosa.util.normalize(y_emphasized_thrice)

    # Overwrite the input file with the processed signal.
    sf.write(path, y_normalized, sr)
    return 'success'
# Shared openSMILE extractor: ComParE 2016 feature set, functionals level.
_SMILE_OPTIONS = {
    "feature_set": opensmile.FeatureSet.ComParE_2016,
    "feature_level": opensmile.FeatureLevel.Functionals,
}
smile = opensmile.Smile(**_SMILE_OPTIONS)
def extract_features(file_path):
    """Extract openSMILE features for one audio file as a one-row DataFrame.

    The output of ``smile.process_file`` is squeezed and then wrapped in a
    list so the DataFrame constructor turns it back into a single row.

    Args:
        file_path: Path of the audio file to analyse.

    Returns:
        pandas.DataFrame built from the squeezed feature values.
    """
    feature_row = smile.process_file(file_path).squeeze()
    print("New shape of features:", feature_row.shape)
    return pd.DataFrame([feature_row])
# Accepted upload MIME types and the extension given to the scratch file.
_AUDIO_EXTENSIONS = {
    "audio/mpeg": ".mp3",
    "audio/mp3": ".mp3",
    "audio/wav": ".wav",
    "audio/x-wav": ".wav",
    "audio/wave": ".wav",
}


def _classify_audio_file(temp_filename):
    """Blocking pipeline: preprocess, extract features, scale, PCA, predict.

    Returns the list of predicted class labels for the file.
    """
    preprocess_audio(temp_filename, 'app')
    features = extract_features(temp_filename)
    logging.getLogger(__name__).debug("Extracted Features: %s", features)
    features = mlp_scaler.transform(features)
    features = mlp_pca.transform(features)
    results = mlp_model.predict(features)
    # The model outputs encoded class indices; map them back to labels.
    return [mlp_label_encoder.classes_[i] for i in results]


# NOTE(review): no route decorator is visible on this coroutine here, so as
# written it is not registered on `app` — confirm the intended path/method.
async def handle_audio(file: UploadFile = File(...)):
    """Classify an uploaded MP3 or WAV file with the MLP model.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"message": ..., "prediction": [...]}`` with the decoded labels.

    Raises:
        HTTPException: 400 for an unsupported content type (previously this
            was swallowed by the generic handler and surfaced as a 500);
            500 carrying the error text when processing fails.
    """
    # Validate before the try block so the 400 reaches the client unchanged.
    file_extension = _AUDIO_EXTENSIONS.get(file.content_type)
    if file_extension is None:
        raise HTTPException(status_code=400, detail="Invalid file type. Supported types: MP3, WAV.")

    temp_filename = f"app/{uuid4().hex}{file_extension}"
    try:
        contents = await file.read()
        with open(temp_filename, "wb") as f:
            f.write(contents)
        # The pipeline is CPU-bound and synchronous; run it in the worker
        # thread pool so it does not stall the event loop.
        decoded_predictions = await run_in_threadpool(_classify_audio_file, temp_filename)
    except HTTPException:
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Audio classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the scratch file, on failures too.
        try:
            os.remove(temp_filename)
        except FileNotFoundError:
            pass

    return {"message": "File processed successfully", "prediction": decoded_predictions}
# if __name__ == "__main__":
#     uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=False)