MojoHz committed on
Commit
bf01e94
·
verified ·
1 Parent(s): ad2f0ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.feature_extraction.text import TfidfVectorizer
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ from transformers import pipeline, HfApi, HfFolder
6
+ import gradio as gr
7
+
8
# Hugging Face login - the access token is read from the HF_TOKEN environment
# variable (for example a Space secret) so that it never appears in the source.
# Nothing is saved when the variable is unset, which keeps any token that is
# already cached on the machine instead of overwriting it with a placeholder.
import os

_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    HfFolder.save_token(_hf_token)
10
+
11
+ # Load and preprocess the data
12
def preprocess_data(file_path):
    """Load the training-data CSV and normalise it for the retrieval pipeline.

    Column names are stripped of surrounding whitespace, every ``#`` is
    replaced by ``Count`` and every space by ``_``. Missing values are
    replaced by ``0``.

    Args:
        file_path: Path (or any other source accepted by ``pandas.read_csv``)
            of the CSV file to load.

    Returns:
        pandas.DataFrame: The loaded data with cleaned column names and
        missing values filled with ``0``.
    """
    data = pd.read_csv(file_path)

    # Clean column names. ``regex=False`` forces a literal replacement on every
    # pandas version (the default for ``regex`` changed in pandas 2.0).
    cleaned = data.columns.str.strip()
    cleaned = cleaned.str.replace("#", "Count", regex=False)
    cleaned = cleaned.str.replace(" ", "_", regex=False)
    data.columns = cleaned

    # Handle missing values (if any)
    return data.fillna(0)
23
+
24
+ # Convert data into a retrievable knowledge base
25
def create_knowledge_base(data):
    """Transform the data into a knowledge base suitable for retrieval.

    One sentence-like description is built per row from the player, session
    and load columns and stored in a new ``Knowledge_Text`` column.

    Args:
        data: DataFrame with the cleaned column names produced by
            ``preprocess_data``. It is modified in place: the
            ``Knowledge_Text`` column is added to it (or overwritten).

    Returns:
        pandas.DataFrame: The ``Player_ID`` and ``Knowledge_Text`` columns of
        ``data``. An empty input yields an empty frame with those two columns.

    Raises:
        KeyError: If any required column is missing; the message lists all of
            the missing names.
    """
    # Columns in the order they appear in the text template below.
    # NOTE(review): 'Main_Possition' is spelled this way on purpose here; it
    # presumably mirrors the header of the source CSV - confirm before renaming.
    fields = (
        'Player_Name', 'Main_Possition', 'Date', 'Session_Name',
        'Played_Time_(min)', 'Top_Speed_(km/h)', 'Dist._Covered_(m)',
        'Session_Intensity', 'RPE', 's-RPE',
    )

    # Fail early with one readable error instead of a per-row lookup failure.
    missing = [name for name in ('Player_ID', *fields) if name not in data.columns]
    if missing:
        raise KeyError(f"Missing required column(s): {', '.join(missing)}")

    template = (
        "Player: {}, Position: {}, "
        "Date: {}, Session: {}, "
        "Played Time: {} minutes, Top Speed: {} km/h, "
        "Distance Covered: {} meters, "
        "Intensity: {}, "
        "RPE: {}, s-RPE: {}"
    )

    # Build the text column-wise: unlike a row-wise ``DataFrame.apply`` this
    # also works for an empty frame and avoids creating one Series per row.
    data['Knowledge_Text'] = [
        template.format(*values)
        for values in zip(*(data[name] for name in fields))
    ]

    return data[['Player_ID', 'Knowledge_Text']]
38
+
39
+ # Create a similarity-based retrieval function
40
def query_knowledge_base(knowledge_base, query, vectorizer, top_k=5):
    """Query the knowledge base using cosine similarity.

    Args:
        knowledge_base: DataFrame with a ``Knowledge_Text`` column.
        query: The user's question as a string.
        vectorizer: An already fitted vectorizer exposing ``transform``; it is
            applied to both the query and every knowledge text.
        top_k: Maximum number of rows to return (default 5, the value that
            was previously fixed).

    Returns:
        tuple: ``(rows, scores)`` where ``rows`` holds the ``top_k`` most
        similar knowledge-base rows, best match first, and ``scores`` is the
        NumPy array of their cosine similarities in the same order.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")

    # Nothing to rank: return empty results instead of letting the similarity
    # computation fail on a matrix without rows.
    if len(knowledge_base) == 0:
        return knowledge_base.iloc[:0], np.array([], dtype=float)

    query_vec = vectorizer.transform([query])
    knowledge_vec = vectorizer.transform(knowledge_base['Knowledge_Text'])

    # Compute cosine similarities (one score per knowledge-base row)
    similarities = cosine_similarity(query_vec, knowledge_vec).flatten()

    # Retrieve the most relevant rows: ascending argsort, reversed, truncated
    top_indices = np.argsort(similarities)[::-1][:top_k]
    return knowledge_base.iloc[top_indices], similarities[top_indices]
51
+
52
+ # Main pipeline with LLM integration and prompt engineering
53
# Text-generation pipelines keyed by model name, so that a model is loaded
# once per process instead of once per query.
_LLM_CACHE = {}


def _get_llm(model_name):
    """Return the cached text-generation pipeline for ``model_name``, creating it on first use."""
    if model_name not in _LLM_CACHE:
        _LLM_CACHE[model_name] = pipeline("text-generation", model=model_name)
    return _LLM_CACHE[model_name]


def _extract_answer(generated_text, prompt):
    """Return only the newly generated answer, without the echoed prompt."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    # Fallback for outputs that do not echo the prompt verbatim: keep the text
    # after the first "Answer:" marker, as before.
    return generated_text.split("Answer:", 1)[-1].strip()


def main_pipeline(file_path, user_query, model_name="meta-llama/Llama-3.2-1B-Instruct"):
    """End-to-end pipeline for the RAG system with Llama3.2 and prompt engineering.

    Steps: load and clean the CSV, build the text knowledge base, fit a TF-IDF
    vectorizer on it, retrieve the rows most similar to the question and let
    the language model answer based on those rows.

    Args:
        file_path: Path of the CSV file with the training data.
        user_query: The user's question.
        model_name: Model identifier passed to the text-generation pipeline.
            Defaults to the model that was previously hard-coded.

    Returns:
        str: The generated answer with the prompt removed and surrounding
        whitespace stripped.
    """
    # Preprocess data
    data = preprocess_data(file_path)
    knowledge_base = create_knowledge_base(data)

    # Create TF-IDF Vectorizer
    vectorizer = TfidfVectorizer()
    vectorizer.fit(knowledge_base['Knowledge_Text'])

    # Query the knowledge base (the similarity scores are not needed here)
    results, _scores = query_knowledge_base(knowledge_base, user_query, vectorizer)

    # Format retrieved knowledge for LLM input
    retrieved_text = "\n".join(results['Knowledge_Text'].tolist())

    # Use Llama3.2 for question answering with prompt engineering
    llm = _get_llm(model_name)
    prompt = (
        f"You are an expert sports analyst. Based on the following training data, provide a detailed and insightful answer to the user's question. "
        f"Always include relevant numerical data in your response. Limit your response to a maximum of 200 words.\n\n"
        f"Training Data:\n{retrieved_text}\n\n"
        f"User Question: {user_query}\n\nAnswer:"
    )
    response = llm(prompt, max_new_tokens=200, num_return_sequences=1)

    # Extract the answer part only. Removing the prompt by prefix stays correct
    # even when the retrieved data or the question itself contains "Answer:".
    return _extract_answer(response[0]['generated_text'], prompt)
82
+
83
+ # Gradio interface
84
def query_interface(file_path, user_query):
    """Gradio callback: answer ``user_query`` using the uploaded CSV file.

    Args:
        file_path: Value delivered by the file upload component. Either an
            object exposing the path of the uploaded file as ``.name``, a
            plain path string, or ``None`` when nothing was uploaded.
        user_query: The question typed by the user.

    Returns:
        str: The generated answer, a short hint when the file or the question
        is missing, or the error message if the pipeline raised.
    """
    if file_path is None:
        return "Please upload a CSV file."
    if not user_query or not user_query.strip():
        return "Please enter a question."

    # Accept both a file-like object with ``.name`` and a plain path string.
    csv_path = getattr(file_path, "name", file_path)

    try:
        return main_pipeline(csv_path, user_query)
    except Exception as e:
        # UI boundary: show the error text in the output box instead of crashing.
        return str(e)
90
+
91
+ # Launch Gradio app
92
# Widgets of the web UI: CSV upload, question box and answer box.
file_input = gr.File(label="Upload CSV File")
text_input = gr.Textbox(
    label="Ask a Question",
    lines=2,
    placeholder="Enter your query here...",
)
output = gr.Textbox(label="Answer")

# Wire the widgets to the query callback.
interface = gr.Interface(
    title="RAG Training Data Query System",
    description="Upload a CSV file containing training data and ask detailed questions about it.",
    fn=query_interface,
    inputs=[file_input, text_input],
    outputs=output,
)

# Start the web server for the app.
interface.launch()