sam749 commited on
Commit
13e8440
·
verified ·
1 Parent(s): c3db0e5

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. Dockerfile +21 -0
  2. README.md +11 -10
  3. main.py +69 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ ENV HF_HOME=/hf
4
+
5
+ WORKDIR /code
6
+
7
+ RUN chmod -R 777 /code
8
+ RUN mkdir -p /hf && chmod -R 777 /hf
9
+ RUN mkdir -p /code/client-data && chmod -R 777 /code/client-data
10
+
11
+ COPY ./requirements.txt /code/requirements.txt
12
+ COPY ./util.py /code/util.py
13
+ COPY ./main.py /code/main.py
14
+
15
+ # Expose the secret SECRET_EXAMPLE at buildtime and use its value as git remote URL
16
+ RUN --mount=type=secret,id=FILE_URL,mode=0444,required=true \
17
+ curl -o /code/util.py $(cat /run/secrets/FILE_URL)
18
+
19
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
20
+
21
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,11 @@
1
- ---
2
- title: Tbot
3
- emoji: 📉
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: TransactSort
3
+ emoji: 🤖
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from util import get_client_id, get_trained_models, train_client_model, download_dataset_locally, predict_vendor_category
4
+ from typing import Optional
5
+
6
+ download_dataset_locally()
7
+ app = FastAPI()
8
+
9
+ # Models
10
+ class TrainInput(BaseModel):
11
+ client_id: str
12
+ data: list[list[str]]
13
+ ignore_value: Optional[str] = 'Need help from accountant'
14
+
15
+ class PredictInput(BaseModel):
16
+ client_id: str
17
+ data: list[list[str]]
18
+
19
+ class UserInput(BaseModel):
20
+ client_name: str
21
+
22
+ # Endpoints
23
+ @app.get("/models")
24
+ def get_models():
25
+ trained_models = get_trained_models()
26
+ if len(trained_models) == 0:
27
+ return {"models": trained_models, "message": "No models trained yet."}
28
+ return {"models": trained_models, "message": "List of trained models."}
29
+
30
+ @app.post("/create-client")
31
+ def create_username(user_input: UserInput):
32
+ client_name = user_input.client_name
33
+ trained_models = get_trained_models()
34
+ client_ids = [m['client_id'] for m in trained_models]
35
+ client_id = get_client_id(client_name)
36
+ if client_id in client_ids:
37
+ raise HTTPException(status_code=400, detail=f"Model for {client_name}, {client_id} already exists.")
38
+ return {"client_id": client_id, "message": "client created successfully."}
39
+
40
+ @app.post("/train")
41
+ def train_model(train_input: TrainInput):
42
+ # check if client_id contains space
43
+ if ' ' in train_input.client_id:
44
+ raise HTTPException(status_code=400, detail="client_id cannot contain space.")
45
+ # check if every entry in rows is contains exactly 4 items
46
+ for row in train_input.data:
47
+ if len(row) != 4:
48
+ raise HTTPException(status_code=400, detail="Each row must contain exactly 4 items.")
49
+ training_result = train_client_model(client_id=train_input.client_id,
50
+ rows=train_input.data,
51
+ ignore_value=train_input.ignore_value)
52
+ return {"message": f"Model '{train_input.client_id}' trained successfully.",
53
+ "result": training_result}
54
+
55
+ @app.post("/predict")
56
+ def predict(predict_input: PredictInput):
57
+ # check if client_id contains space
58
+ if ' ' in predict_input.client_id:
59
+ raise HTTPException(status_code=400, detail="client_id cannot contain space.")
60
+ # check if every entry in rows is contains exactly 4 items
61
+ for row in predict_input.data:
62
+ if len(row) != 2:
63
+ raise HTTPException(status_code=400, detail="Each row must contain exactly 2 items.")
64
+ predictions = predict_vendor_category(client_id=predict_input.client_id,
65
+ data=predict_input.data)
66
+ return {"result": predictions,
67
+ 'message': 'Predictions generated successfully.'
68
+ }
69
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ python-multipart
2
+ fastapi
3
+ pydantic
4
+ uvicorn
5
+ requests
6
+ torch
7
+ transformers
8
+ datasets
9
+ sentence-transformers