| import json | |
| import numpy as np | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import torch | |
| from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the generated ids end with one of the stop sequences.

    Args:
        stops: iterable of 1-D token-id sequences (tensors or lists of ints).
            Generation stops when, for some stop sequence, the trailing tokens
            of *every* row of ``input_ids`` equal it.
        encounters: accepted for interface compatibility; currently unused.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` default avoids one mutable list being shared by all instances.
        self.stops = stops if stops is not None else []
        # Stored for transparency only; no logic below reads it.
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True if the generated ids end with any stop sequence."""
        for stop in self.stops:
            # Put the stop ids on the same device as the generated ids so the
            # comparison cannot fail on a CPU/GPU mismatch.
            stop = torch.as_tensor(stop, device=input_ids.device)
            n = len(stop)
            # Skip empty stops, and stops longer than what has been generated:
            # slicing fewer than ``n`` tokens would broadcast or raise.
            if n == 0 or input_ids.shape[-1] < n:
                continue
            if torch.all(input_ids[:, -n:] == stop).item():
                return True
        return False
class Chat:
    """Image-driven recipe lookup with a keyword-based question interface.

    An input image is embedded with ``model``'s visual encoder and projection
    head, then scored against ``tar_img_feats`` by dot product. The
    best-matching row of ``dataframe`` becomes the current recipe, which
    :meth:`ask` answers questions about by keyword matching.
    """

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        """Create a chat session.

        Args:
            model: object exposing ``visual_encoder`` and ``vision_proj``
                callables (presumably a BLIP-style model -- TODO confirm).
            transform: callable mapping a PIL image to a tensor; a batch
                dimension is added before encoding.
            dataframe: recipe table indexed positionally by :meth:`get_target`.
                It must have the columns ``recipe_nutrients``,
                ``recipe_instructions``, ``recipe_ingredients`` and ``tags``.
            tar_img_feats: 2-D tensor of candidate image features, assumed
                ``(num_candidates, dim)`` -- TODO confirm.
            device: device the query image is moved to before encoding.
            stopping_criteria: optional ``StoppingCriteriaList``. By default,
                generation stops on token id 2 (presumably EOS -- confirm
                against the tokenizer).
        """
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None       # query features from the last encode_image call
        self.target_recipe = None   # dataframe row selected by get_target
        self.messages = []          # not read anywhere in this class
        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        """Embed an image and select the best-matching recipe.

        Args:
            image_path: despite its name, an array accepted by
                ``PIL.Image.fromarray`` (e.g. an HxWxC uint8 array), not a
                file path.
        """
        img = Image.fromarray(image_path).convert("RGB")
        img = self.transform(img).unsqueeze(0).to(self.device)
        # Inference only: don't build an autograd graph for the encoder pass.
        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            # Project the first token's embedding (presumably the CLS token)
            # and L2-normalise it.
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()
        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        """Store in ``self.target_recipe`` the row whose feature best matches ``img_feats``.

        Args:
            img_feats: query features, shape ``(1, dim)``.
            tar_img_feats: candidate features, one row per candidate.
        """
        import logging

        # Align devices so a GPU-resident feature bank works with a CPU query.
        img_feats = img_feats.to(tar_img_feats.device)
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        logging.getLogger(__name__).debug("Similarity scores: %s", score)
        # NOTE(review): feature row i maps to dataframe row i + 1. This
        # presumably compensates for a leading row in ``self.df``. Confirm
        # this, otherwise the best match on the last feature row raises
        # IndexError.
        index = int(np.argmax(score)) + 1
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        """Answer a question about the current recipe by keyword matching.

        Matching is case-insensitive, and the first matching keyword group
        wins. The selected recipe field is returned as indented JSON. Any
        other message gets a fixed fallback reply.
        """
        if self.target_recipe is None:
            return "Please upload an image first so a recipe can be identified."
        text = msg.lower()
        if "nutrition" in text or "nutrients" in text:
            column = "recipe_nutrients"
        elif "instruction" in text:
            column = "recipe_instructions"
        elif "ingredients" in text:
            column = "recipe_ingredients"
        elif "tag" in text or "class" in text:
            column = "tags"
        else:
            return "Conversational capabilities will be included later."
        return json.dumps(self.target_recipe[column], indent=4)