dnzblgn committed on
Commit
1ec8bd7
·
verified ·
1 Parent(s): b247b6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -11
app.py CHANGED
@@ -10,7 +10,80 @@ from langchain.chains import ConversationalRetrievalChain
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.llms import HuggingFaceEndpoint
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ✅ Use a strong sentence embedding model
15
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
16
 
@@ -103,7 +176,7 @@ def validate_query_semantically(query, retrieved_docs):
103
 
104
  print(f"πŸ” Semantic Similarity Score: {similarity_score}")
105
 
106
- return similarity_score >= 0.4 # πŸ”₯ Stricter threshold to ensure correctness
107
 
108
 
109
  def handle_query(query, history, retriever, qa_chain, embeddings):
@@ -163,24 +236,64 @@ def initialize_chatbot(vector_db):
163
  return retriever, qa_chain, embeddings
164
 
165
 
 
 
 
 
 
 
 
166
  def demo():
167
- """ βœ… Starts the chatbot application using Gradio. """
 
 
 
 
168
  retriever, qa_chain, embeddings = initialize_chatbot(create_db(load_documents()))
169
-
170
  with gr.Blocks() as app:
171
- gr.Markdown("### πŸ€– **Fastener Agent** πŸ“š")
172
- chatbot = gr.Chatbot()
173
- query_input = gr.Textbox(label="Ask me a question")
174
- query_btn = gr.Button("Ask")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  def user_query_handler(query, history):
177
  return handle_query(query, history, retriever, qa_chain, embeddings)
178
 
179
- query_btn.click(user_query_handler, inputs=[query_input, chatbot], outputs=[chatbot, query_input])
180
- query_input.submit(user_query_handler, inputs=[query_input, chatbot], outputs=[chatbot, query_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  app.launch()
183
 
184
-
185
  if __name__ == "__main__":
186
  demo()
 
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.llms import HuggingFaceEndpoint
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
+ import torch
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ from torchvision.models import resnet50, ResNet50_Weights
17
+ from torchvision import transforms, models
18
+
19
+
20
class GeometryImageClassifier:
    """Label an image with the geometry class of its nearest reference image.

    A ResNet50 whose final ``fc`` layer is replaced by ``torch.nn.Identity`` is
    used purely as a feature extractor. An input image receives the label of
    the reference image whose embedding has the highest cosine similarity to
    the input's embedding.

    ``initialize_reference_embeddings`` must be called once before
    ``find_closest_geometry`` / ``process_image`` are used.
    """

    def __init__(self):
        # Load ResNet50 but only use it for feature extraction.
        self.model = models.resnet50(weights='DEFAULT')
        # Remove the final classification layer so the model outputs features.
        self.model.fc = torch.nn.Identity()
        self.model.eval()

        # NOTE(review): mean/std look like the usual ImageNet statistics that
        # match the pretrained weights - confirm if the weights ever change.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Reference images (file name -> embedding + manual label). Embeddings
        # stay None until initialize_reference_embeddings() fills them in.
        self.reference_embeddings = {
            "flat.png": {
                "embedding": None,
                "label": "Flat or Sheet-Based",
            },
            "cylindrical.png": {
                "embedding": None,
                "label": "Cylindrical",
            },
            "complex.png": {
                "embedding": None,
                "label": "Complex Multi Axis Geometry",
            },
        }

    def compute_embedding(self, images):
        """Return the feature embedding of one image as a NumPy array.

        Args:
            images: Path (or anything ``PIL.Image.open`` accepts) of a single
                image, despite the plural parameter name.
        """
        # The context manager releases the file handle; convert() returns an
        # independent in-memory copy that stays usable after the file closes.
        with Image.open(images) as img:
            rgb_image = img.convert('RGB')
        img_tensor = self.transform(rgb_image).unsqueeze(0)

        with torch.no_grad():
            embedding = self.model(img_tensor)
        return embedding.squeeze().numpy()

    def initialize_reference_embeddings(self, reference_folder):
        """Compute and store the embedding of every reference image.

        Args:
            reference_folder: Directory containing the reference image files
                named by the keys of ``self.reference_embeddings``.
        """
        for image_name, ref_data in self.reference_embeddings.items():
            ref_data["embedding"] = self.compute_embedding(f"{reference_folder}/{image_name}")

    def find_closest_geometry(self, query_embedding):
        """Return the label of the reference most similar to ``query_embedding``.

        Raises:
            RuntimeError: If a reference embedding has not been computed yet
                (``initialize_reference_embeddings`` was not called).
        """
        # -inf rather than -1 so that a similarity of exactly -1 (or slightly
        # below, through rounding) can still win and a label is always chosen.
        best_similarity = float("-inf")
        best_label = None

        for image_name, ref_data in self.reference_embeddings.items():
            ref_embedding = ref_data["embedding"]
            if ref_embedding is None:
                # Fail with an actionable message instead of an AttributeError
                # on None.reshape further down.
                raise RuntimeError(
                    f"Reference embedding for '{image_name}' is not initialized; "
                    "call initialize_reference_embeddings() first."
                )

            similarity = cosine_similarity(
                query_embedding.reshape(1, -1),
                ref_embedding.reshape(1, -1),
            )[0][0]

            if similarity > best_similarity:
                best_similarity = similarity
                best_label = ref_data["label"]

        return best_label

    def process_image(self, images):
        """Embed the input image and return its closest reference label."""
        query_embedding = self.compute_embedding(images)
        return self.find_closest_geometry(query_embedding)
85
+
86
+
87
  # ✅ Use a strong sentence embedding model
88
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
89
 
 
176
 
177
  print(f"πŸ” Semantic Similarity Score: {similarity_score}")
178
 
179
+ return similarity_score >= 0.3 # πŸ”₯ Stricter threshold to ensure correctness
180
 
181
 
182
  def handle_query(query, history, retriever, qa_chain, embeddings):
 
236
  return retriever, qa_chain, embeddings
237
 
238
 
239
def process_image_and_generate_query(image, classifier=None, reference_folder="images"):
    """Classify an image's geometry and build the matching chatbot question.

    Args:
        image: Image path (or file-like object) handed to the classifier's
            ``process_image``.
        classifier: Optional, already initialized classifier to reuse. When
            omitted, a new ``GeometryImageClassifier`` is created and its
            reference embeddings are computed from ``reference_folder``.
        reference_folder: Folder holding the reference images; only used when
            ``classifier`` is not supplied.

    Returns:
        A ``(geometry_type, query)`` tuple.
    """
    if classifier is None:
        classifier = GeometryImageClassifier()
        # A fresh classifier has no reference embeddings yet; without this
        # step the similarity search fails on the None placeholders.
        classifier.initialize_reference_embeddings(reference_folder)

    geometry_type = classifier.process_image(image)

    query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
    return geometry_type, query
245
+
246
  def demo():
247
+ # Initialize classifier once at startup
248
+ classifier = GeometryImageClassifier()
249
+ classifier.initialize_reference_embeddings("images")
250
+
251
+ # Initialize chatbot components
252
  retriever, qa_chain, embeddings = initialize_chatbot(create_db(load_documents()))
253
+
254
  with gr.Blocks() as app:
255
+ gr.Markdown("### πŸ€– **Fastener Agent with Image Recognition** πŸ“š")
256
+
257
+ with gr.Row():
258
+ with gr.Column(scale=1):
259
+ image_input = gr.Image(type="filepath", label="Upload Geometry Image")
260
+ geometry_label = gr.Textbox(label="Detected Geometry Type", interactive=False)
261
+
262
+ with gr.Column(scale=2):
263
+ chatbot = gr.Chatbot()
264
+ query_input = gr.Textbox(label="Ask me a question")
265
+ query_btn = gr.Button("Ask")
266
+
267
+ def image_upload_handler(image):
268
+ if image is None:
269
+ return "", ""
270
+ # Use the initialized classifier
271
+ geometry_type = classifier.process_image(image)
272
+ suggested_query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
273
+ return geometry_type, suggested_query
274
 
275
  def user_query_handler(query, history):
276
  return handle_query(query, history, retriever, qa_chain, embeddings)
277
 
278
+ image_input.change(
279
+ image_upload_handler,
280
+ inputs=[image_input],
281
+ outputs=[geometry_label, query_input]
282
+ )
283
+
284
+ query_btn.click(
285
+ user_query_handler,
286
+ inputs=[query_input, chatbot],
287
+ outputs=[chatbot, query_input]
288
+ )
289
+
290
+ query_input.submit(
291
+ user_query_handler,
292
+ inputs=[query_input, chatbot],
293
+ outputs=[chatbot, query_input]
294
+ )
295
 
296
  app.launch()
297
 
 
298
  if __name__ == "__main__":
299
  demo()