Shilpaj committed
Commit 4fd7030 · 1 Parent(s): 5ec13ed

Feat: Huggingface App

Files changed (6)
  1. .gitignore +0 -1
  2. Dockerfile +39 -0
  3. README.md +36 -0
  4. app.py +4 -0
  5. requirements.txt +11 -0
  6. templates/inference.html +49 -1
.gitignore CHANGED
@@ -9,7 +9,6 @@ __pycache__
  data/
  models/
  *.log
- *.txt
  *.csv
  *.json
  *.pickle
Dockerfile ADDED
@@ -0,0 +1,39 @@
+ # Use an official Python runtime as the base image
+ FROM python:3.9-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the requirements first to leverage Docker cache
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application files
+ COPY . .
+
+ # Create necessary directories and ensure proper permissions
+ RUN mkdir -p data/MNIST/raw scripts/training/models \
+     && chmod -R 755 static \
+     && chmod -R 755 templates \
+     && chown -R nobody:nogroup static templates
+
+ # Make port 8000 available (FastAPI default port)
+ EXPOSE 8000
+
+ # Set environment variable for FastAPI to listen on 0.0.0.0
+ ENV HOST=0.0.0.0
+ ENV PORT=8000
+
+ # Switch to non-root user
+ USER nobody
+
+ # Command to run the application using uvicorn
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md ADDED
@@ -0,0 +1,36 @@
+ ---
+ title: MnistStudio
+ emoji: 🐨
+ colorFrom: red
+ colorTo: indigo
+ sdk: docker
+ app_port: 8000
+ pinned: false
+ license: mit
+ short_description: Train and perform inference on MNIST dataset
+ ---
+ # MNIST Application
+
+ ## Overview
+
+ This is a simple application that can be used to train a convolutional neural network model to classify images of handwritten digits. The same application can also be used to perform inference on digits drawn by the user.
+
+ ## Application Description
+ - The landing page consists of two buttons, one for training the model and one for performing inference.
+ - On clicking the inference button, a new page is loaded where the user can draw a digit on the canvas and select the model used to perform inference.
+ - The inference results are displayed on the same page.
+ - On clicking the training button, a new page is loaded with two buttons, one for training a single model and another for comparing multiple models.
+ - On clicking the train single model button, a new page is loaded where the user can select the following options:
+   - Number of kernels in the three blocks of the network
+   - Optimizer [Adam, SGD]
+   - Batch Size [32, 64, 128]
+   - Number of Epochs [1, 2, 3]
+ - Once these parameters are selected, the user can click the train button to start training. Training and validation loss and accuracy are displayed on the same page.
+ - On clicking the train and compare models button, a new page is loaded where the user can select the following options for both models:
+   - Number of kernels in the three blocks of the network for each model
+   - Optimizer [Adam, SGD] for each model
+   - Batch Size [32, 64, 128] for each model
+   - Number of Epochs [1, 2, 3] for each model
+ - Once these parameters are selected, the user can click the train button to start training. Training and validation loss and accuracy are displayed on the same page.
app.py CHANGED
@@ -13,6 +13,7 @@ from fastapi import BackgroundTasks
  import warnings
  import asyncio
  import json
+ import numpy as np
 
  warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
 
@@ -300,6 +301,9 @@ async def perform_inference(data: dict):
  # Resize using PIL directly with LANCZOS
  image = image.resize((28, 28), Image.LANCZOS)
 
+ # Invert the image (subtract from 255 to invert grayscale)
+ image = Image.fromarray(255 - np.array(image))
+
  # Preprocess image
  transform = transforms.Compose([
  transforms.ToTensor(),
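
The new inversion step is presumably needed because the drawing canvas produces dark strokes on a light background, whereas MNIST digits are light on a dark background. Below is a minimal standalone sketch of the resulting preprocessing path, assuming the decoded canvas image is available as a grayscale PNG named digit.png and that the normalization constants (not visible in this hunk) are the usual MNIST statistics:

import numpy as np
from PIL import Image
from torchvision import transforms

image = Image.open("digit.png").convert("L")     # hypothetical input file
image = image.resize((28, 28), Image.LANCZOS)    # match the MNIST resolution

# Canvas drawings are dark-on-light; MNIST digits are light-on-dark,
# so the grayscale values are inverted before inference.
image = Image.fromarray(255 - np.array(image))

transform = transforms.Compose([
    transforms.ToTensor(),                       # scale pixels to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # assumed MNIST mean/std
])
tensor = transform(image).unsqueeze(0)           # add a batch dimension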
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ fastapi
+ uvicorn
+ torch
+ torchvision
+ numpy
+ plotly
+ tqdm
+ python-multipart
+ jinja2
+ aiofiles
+ websockets
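
Outside the container, the same stack can be installed with "pip install -r requirements.txt" and served with "uvicorn app:app --host 0.0.0.0 --port 8000", mirroring the Dockerfile's CMD. No versions are pinned here, so each build resolves to the latest release of every package.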
templates/inference.html CHANGED
@@ -5,7 +5,7 @@
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Test Model - MNIST</title>
  <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
- <link rel="stylesheet" href="{{ url_for('static', path='/css/buttons.css') }}">
+ <!-- <link rel="stylesheet" href="{{ url_for('static', path='/css/buttons.css') }}"> -->
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
  </head>
  <body>
@@ -69,6 +69,54 @@
 
  <script src="{{ url_for('static', path='/js/inference.js') }}"></script>
  <script>
+ async function predict() {
+     const modelSelect = document.getElementById('model-select');
+     const selectedModel = modelSelect.value;
+
+     if (!selectedModel) {
+         alert('Please train a model first');
+         return;
+     }
+
+     const imageData = canvas.toDataURL('image/png');
+
+     try {
+         const response = await fetch('/api/inference', {
+             method: 'POST',
+             headers: {
+                 'Content-Type': 'application/json',
+             },
+             body: JSON.stringify({
+                 image: imageData,
+                 model_name: selectedModel
+             })
+         });
+
+         if (!response.ok) {
+             const error = await response.json();
+             throw new Error(error.detail || 'Prediction failed');
+         }
+
+         const data = await response.json();
+         displayPrediction(data);
+     } catch (error) {
+         console.error('Error:', error);
+         alert(error.message || 'Error during prediction');
+     }
+ }
+
+ function displayPrediction(data) {
+     const resultDiv = document.getElementById('prediction-result');
+     resultDiv.classList.remove('hidden');
+     resultDiv.innerHTML = `
+         <h2>Prediction Result</h2>
+         <p class="prediction-text">Predicted Digit: ${data.prediction}</p>
+         <p class="model-info">Model Architecture: ${data.model_config.architecture}</p>
+         <p class="model-info">Optimizer: ${data.model_config.optimizer}</p>
+         <p class="model-info">Batch Size: ${data.model_config.batch_size}</p>
+     `;
+ }
+
  // Update the displayPrediction function
  function displayPrediction(result) {
  const resultDiv = document.getElementById('prediction-result');
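
The predict() helper added above posts the canvas as a base64 data URL together with the selected model name to /api/inference and reads the prediction and model_config fields from the JSON response. As a rough illustration of the same request made outside the browser, here is a hedged Python sketch using the requests library; the base URL, file name, and model name are placeholders, not values taken from this commit:

import base64
import requests

# Encode a saved PNG the same way canvas.toDataURL('image/png') does.
with open("digit.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

response = requests.post(
    "http://localhost:8000/api/inference",              # assumed local server
    json={"image": data_url, "model_name": "model_1"},  # hypothetical model name
    timeout=30,
)
response.raise_for_status()
result = response.json()
print("Predicted digit:", result["prediction"])
print("Optimizer used:", result["model_config"]["optimizer"])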