import os
import torch
import numpy as np
import pandas as pd
from sentence_transformers import util, SentenceTransformer
import redis
import json
from typing import Dict, List
import google.generativeai as genai
from flask import Flask, request, jsonify, Response
import requests
from io import StringIO
from openai import OpenAI
# Initialize Flask app
app = Flask(__name__)

# Redis configuration
r = redis.Redis(
    host='redis-12878.c1.ap-southeast-1-1.ec2.redns.redis-cloud.com',
    port=12878,
    db=0,
    password="qKl6znBvULaveJhkjIjMr7RCwluJjjbH",
    decode_responses=True
)

# Device configuration - use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# OpenAI client reads OPENAI_API_KEY from the environment
client = OpenAI()
# Load CSV from Google Drive
def load_csv_from_drive():
    file_id = "1x3tPRumTK3i7zpymeiPIjVztmt_GGr5V"
    url = f"https://drive.google.com/uc?id={file_id}"
    response = requests.get(url)
    response.raise_for_status()  # fail fast on a bad download
    csv_content = StringIO(response.text)
    df = pd.read_csv(csv_content)[['text', 'embeddings']]

    # Parse the stringified embedding lists back into float32 arrays
    df["embeddings"] = df["embeddings"].apply(
        lambda x: np.fromstring(x.strip("[]"), sep=",", dtype=np.float32)
    )
    return df
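# Note: for files above Google Drive's size limit, the bare uc?id= URL can
# return an HTML confirmation page instead of the CSV. An optional workaround
# (a sketch, not part of the original code) is the gdown package, which
# handles the confirmation token:
#
#   import gdown
#   gdown.download(f"https://drive.google.com/uc?id={file_id}", "data.csv", quiet=True)
#   df = pd.read_csv("data.csv")[['text', 'embeddings']]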
# Load data and initialize models with GPU support
text_chunks_and_embedding_df = load_csv_from_drive()
pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")
embeddings = torch.tensor(
    np.vstack(text_chunks_and_embedding_df["embeddings"].values),
    dtype=torch.float32
).to(device)
# Initialize embedding model with GPU support
embedding_model = SentenceTransformer(
    model_name_or_path="keepitreal/vietnamese-sbert",
    device=device
)
def store_conversation(conversation_id: str, q: str, a: str) -> None:
    """Push a Q/A pair onto the conversation list, keeping only the two most recent."""
    conversation_element = {
        'q': q,
        'a': a,
    }
    conversation_json = json.dumps(conversation_element)
    r.lpush(f'conversation_{conversation_id}', conversation_json)
    current_length = r.llen(f'conversation_{conversation_id}')
    if current_length > 2:
        r.rpop(f'conversation_{conversation_id}')
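# Design note: the LPUSH/LLEN/RPOP sequence above takes three round trips and
# can race under concurrent writers. A minimal alternative sketch (not the
# original code) pipelines LPUSH with LTRIM to cap the list in one shot:
def store_conversation_trimmed(conversation_id: str, q: str, a: str) -> None:
    pipe = r.pipeline()
    pipe.lpush(f'conversation_{conversation_id}', json.dumps({'q': q, 'a': a}))
    pipe.ltrim(f'conversation_{conversation_id}', 0, 1)  # keep only the 2 newest entries
    pipe.execute()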
def retrieve_conversation(conversation_id):
    # LPUSH stores newest-first, so reverse to get chronological (oldest-first)
    # order when building chat history.
    conversation = r.lrange(f'conversation_{conversation_id}', 0, -1)
    return [json.loads(c) for c in reversed(conversation)]
def combine_vectors_method2(vector_weight_pairs):
    """Combine vectors as a weighted sum, with the weights L2-normalized first."""
    weight_norm = np.sqrt(sum(weight**2 for _, weight in vector_weight_pairs))
    combined_vector = np.zeros_like(vector_weight_pairs[0][0])
    for vector, weight in vector_weight_pairs:
        normalized_weight = weight / weight_norm
        combined_vector += vector * normalized_weight
    return combined_vector
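# Worked example: with weights [1.0, 0.5] (the values used in
# get_weighted_query below), the weight norm is sqrt(1.0^2 + 0.5^2) ≈ 1.118,
# so the current question contributes ≈ 0.894 of its vector and the
# conversation context ≈ 0.447:
#
#   v = combine_vectors_method2([(np.ones(3), 1.0), (np.ones(3), 0.5)])
#   # array([1.3416, 1.3416, 1.3416])  i.e. 0.894 + 0.447 per component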
def get_weighted_query(current_question: str, parsed_conversation: List[Dict]) -> np.ndarray:
    # encode() runs on whatever device the model was loaded with, so no explicit
    # CUDA context is needed (torch.cuda.device(device) would fail on CPU-only hosts).
    current_vector = embedding_model.encode(current_question, convert_to_tensor=True)
    weighted_parts = [(current_vector.cpu().numpy(), 1.0)]

    if parsed_conversation:
        context_string = " ".join(
            f"{chat['q']} {chat['a']}" for chat in parsed_conversation
        )
        context_vector = embedding_model.encode(context_string, convert_to_tensor=True)
        similarity = util.pytorch_cos_sim(current_vector, context_vector)[0][0].item()
        # Weight the history higher when it is topically close to the new question
        weight = 1.0 if similarity > 0.4 else 0.5
        weighted_parts.append((context_vector.cpu().numpy(), weight))

    weighted_query_vector = combine_vectors_method2(weighted_parts)
    weighted_query_vector = torch.from_numpy(weighted_query_vector).to(device, dtype=torch.float32)

    # L2-normalize the combined query vector
    norm = torch.norm(weighted_query_vector)
    weighted_query_vector = weighted_query_vector / norm if norm > 0 else weighted_query_vector
    return weighted_query_vector.cpu().numpy()
def retrieve_relevant_resources(query_vector, embeddings, similarity_threshold=0.5, n_resources_to_return=10):
    query_embedding = torch.from_numpy(query_vector).to(device, dtype=torch.float32)
    if len(query_embedding.shape) == 1:
        query_embedding = query_embedding.unsqueeze(0)

    # Zero-pad the query if its dimensionality is smaller than the corpus embeddings
    if embeddings.shape[1] != query_embedding.shape[1]:
        query_embedding = torch.nn.functional.pad(
            query_embedding,
            (0, embeddings.shape[1] - query_embedding.shape[1])
        )

    # Normalize tensors on GPU
    query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1)
    embeddings_normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)

    # Cosine similarity via matmul on the normalized vectors
    cosine_scores = torch.matmul(query_embedding, embeddings_normalized.t())[0]

    mask = cosine_scores >= similarity_threshold
    filtered_scores = cosine_scores[mask]
    # as_tuple=True keeps a 1-D index tensor even when only one score passes;
    # a bare .squeeze() would collapse it to 0-D and break the indexing below
    filtered_indices = mask.nonzero(as_tuple=True)[0]

    if len(filtered_scores) == 0:
        return torch.tensor([], device=device), torch.tensor([], device=device)

    k = min(n_resources_to_return, len(filtered_scores))
    scores, indices = torch.topk(filtered_scores, k=k)
    final_indices = filtered_indices[indices]
    return scores, final_indices
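# Usage sketch (hypothetical values, not part of the original flow): retrieve
# the top matches for an encoded question against the corpus loaded above.
#
#   qv = embedding_model.encode("example question")
#   scores, idxs = retrieve_relevant_resources(qv, embeddings, similarity_threshold=0.4)
#   top_chunks = [pages_and_chunks[i]["text"] for i in idxs.cpu().tolist()]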
def hyde(query, conversation_id, cid):
    prompt = """
    [Your existing prompt text here]
    """
    messages = [
        {
            "role": "system",
            "content": prompt,
        }
    ]

    # Replay prior turns (oldest first) so the model sees the conversation in order
    history = retrieve_conversation(conversation_id)
    for c in history:
        messages.append({
            "role": "user",
            "content": c["q"]
        })
        messages.append({
            "role": "assistant",
            "content": c["a"]
        })

    # If an IPFS CID is supplied, attach the pinned image for gpt-4o's vision input
    if cid:
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://magenta-known-swan-641.mypinata.cloud/ipfs/" + cid,
                    }
                },
            ],
        })
    else:
        messages.append({
            "role": "user",
            "content": query
        })

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=messages
    )
    return completion.choices[0].message.content
def prompt_formatter(mode, query: str, context_items: List[Dict], history: List[Dict] = None, isFirst=False) -> str:
    # [Your existing prompt_formatter implementation]
    # Must return the fully assembled prompt string that is passed to Gemini below.
    pass
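# A minimal placeholder sketch of what prompt_formatter must produce (the real
# implementation is elided above; this stand-in only stitches the retrieved
# context and history around the query so the pipeline runs end to end):
def prompt_formatter_sketch(mode, query: str, context_items: List[Dict], history: List[Dict] = None, isFirst=False) -> str:
    context = "\n".join(f"- {item['text']}" for item in context_items)
    past = "\n".join(f"Q: {h['q']}\nA: {h['a']}" for h in (history or []))
    return f"Context:\n{context}\n\nHistory:\n{past}\n\nQuestion: {query}\nAnswer:"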
def ask_with_history_v3(query: str, conversation_id: str, isFirst, cid, mode):
    parsed_conversation = retrieve_conversation(conversation_id)
    weighted_query_vector = get_weighted_query(query, parsed_conversation)
    threshold = 0.4
    scores, indices = retrieve_relevant_resources(
        query_vector=weighted_query_vector,
        similarity_threshold=threshold,
        embeddings=embeddings
    )

    # Move results to CPU for processing
    filtered_pairs = [
        (score.cpu().item(), idx.cpu().item())
        for score, idx in zip(scores, indices)
        if score.cpu().item() >= threshold
    ]
    if filtered_pairs:
        filtered_scores, filtered_indices = zip(*filtered_pairs)
        context_items = [pages_and_chunks[i] for i in filtered_indices]
        for i, item in enumerate(context_items):
            item["score"] = filtered_scores[i]
    else:
        context_items = []

    prompt = prompt_formatter(mode, query=query, context_items=context_items, history=parsed_conversation, isFirst=isFirst)

    genai.configure(api_key="AIzaSyDluIEKEhT1Dw2zx7SHEdmKipwBcYOmFQw")
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(prompt, stream=True)
    for chunk in response:
        yield chunk.text

    # response.text is only available once the stream above is fully consumed.
    # The Vietnamese marker translates to "I'll help you with another question,
    # okay?" - a deflection phrase whose turns are not stored as history.
    if mode == "2" or ("Mình sẽ hỗ trợ bạn câu khác nhé?" in response.text):
        return
    store_conversation(conversation_id, query, response.text)
# API endpoints (route paths are assumed; the decorators were missing in the original)
@app.route('/')
def home():
    return "Hello World"

@app.route('/ping')
def ping():
    return jsonify("Service is running")

@app.route('/generate', methods=['POST'])
def generate_response():
    query = request.json['query']
    conversation_id = request.json['conversation_id']
    isFirst = request.json['is_first'] == "true"
    cid = request.json['cid']
    mode = request.json['mode']

    hyde_query = hyde(query, conversation_id, cid)
    # A trailing period appears to mark a complete direct answer from the
    # HyDE step, which is returned as-is without running retrieval.
    if hyde_query[-1] == '.':
        return Response(hyde_query, mimetype='text/plain')

    def generate():
        for token in ask_with_history_v3(hyde_query, conversation_id, isFirst, cid, mode):
            yield token

    return Response(generate(), mimetype='text/plain')
if __name__ == '__main__':
    # Data and models are initialized at import time above
    app.run(host="0.0.0.0", port=7860)
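# Example request (assuming the /generate route above; the field names come
# from generate_response):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"query": "example question", "conversation_id": "abc", "is_first": "true", "cid": "", "mode": "1"}'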