pavi156 commited on
Commit
9f8b30f
·
1 Parent(s): 425b965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -3,37 +3,65 @@ import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
 
6
- # Load the trained model
7
- model_path = "cifar_net.pth"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- state_dict = torch.load(model_path)
10
 
11
- model = torch.nn.Module()
12
- model.load_state_dict(state_dict)
13
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Prepare the image for prediction
16
- image_path = 'download.jpg'
17
- image = Image.open(image_path)
18
 
19
- # Transform the image to match CIFAR-10 format
20
- transform = transforms.Compose([
21
- transforms.Resize((32, 32)),
22
- transforms.ToTensor(),
23
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize with CIFAR-10 mean and std
24
- ])
25
- input_image = transform(image).unsqueeze(0)
26
-
27
- # Make predictions
28
- with torch.no_grad():
29
- outputs = model(input_image)
30
-
31
- # Retrieve the predicted class label
32
- _, predicted = torch.max(outputs, 1)
33
- class_index = predicted.item()
34
-
35
- # Load the CIFAR-10 class labels
36
- classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
37
-
38
- # Print the predicted class label
39
- print('Predicted class label:', classes[class_index])
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
 
6
+ file_urls = [
7
+ "https://www.bing.com/images/search?view=detailV2&ccid=YaFiK%2bN6&id=D84622E2396A39F168D279F32AC31F05096187AB&thid=OIP.YaFiK-N6iDdJR6B6DMBHpgHaFj&mediaurl=https%3a%2f%2fwww.practicalcaravan.com%2fwp-content%2fuploads%2f2016%2f03%2f5907569-scaled.jpg&exph=1921&expw=2560&q=audi+a4+car+image&simid=608053806389945942&FORM=IRPRST&ck=7DDB4BC7AA27F8E3EDA4433E669D3CC4&selectedIndex=6&ajaxhist=0&ajaxserp=0","https://www.bing.com/images/search?view=detailV2&ccid=CHONQxwQ&id=B8BCD1A5420658017C772CF149AFB7D24F2F8322&thid=OIP.CHONQxwQrclsFp-VXh4aOQHaFD&mediaurl=https%3a%2f%2fs3-eu-west-1.amazonaws.com%2feurekar-v2%2fuploads%2fimages%2foriginal%2fa4salfront.jpg&exph=1025&expw=1500&q=audi+a4+car+image&simid=608024308599848180&FORM=IRPRST&ck=3A2EA226332024ECB13B2F27682C15CA&selectedIndex=3&ajaxhist=0&ajaxserp=0"
8
+ ]
9
+
10
+ def download_file(url, save_name):
11
+ url = url
12
+ if not os.path.exists(save_name):
13
+ file = requests.get(url)
14
+ open(save_name, 'wb').write(file.content)
15
+
16
+ for i, url in enumerate(file_urls):
17
+ if 'mp4' in file_urls[i]:
18
+ download_file(
19
+ file_urls[i],
20
+ f"video.mp4"
21
+ )
22
+ else:
23
+ download_file(
24
+ file_urls[i],
25
+ f"image_{i}.jpg"
26
+ )
27
+
28
+ model = 'cifar_net.pth'
29
+ path = [['image_0.jpg'], ['image_1.jpg']]
30
+ video_path = [['video.mp4']]
31
 
 
32
 
33
+ def show_preds_image(image_path):
34
+ image = cv2.imread(image_path)
35
+ outputs = model.predict(source=image_path)
36
+ results = outputs[0].cpu().numpy()
37
+ for i, det in enumerate(results.boxes.xyxy):
38
+ cv2.rectangle(
39
+ image,
40
+ (int(det[0]), int(det[1])),
41
+ (int(det[2]), int(det[3])),
42
+ color=(0, 0, 255),
43
+ thickness=2,
44
+ lineType=cv2.LINE_AA
45
+ )
46
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
47
+
48
+ inputs_image = [
49
+ gr.components.Image(type="filepath", label="Input Image"),
50
+ ]
51
+ outputs_image = [
52
+ gr.components.Image(type="numpy", label="Output Image"),
53
+ ]
54
+ interface_image = gr.Interface(
55
+ fn=show_preds_image,
56
+ inputs=inputs_image,
57
+ outputs=outputs_image,
58
+ title="Car detector",
59
+ examples=path,
60
+ cache_examples=False,
61
+ )
62
 
 
 
 
63
 
64
+ gr.TabbedInterface(
65
+ [interface_image],
66
+ tab_names=['Image inference']
67
+ ).queue().launch()