Lazar Radojevic committed on
Commit 268c7f9 · 1 Parent(s): b9115ea

copy from other repo

Dockerfile ADDED
@@ -0,0 +1,23 @@
+ # Use the official Python image from the Docker Hub
+ FROM python:3.10-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install Poetry
+ RUN pip install poetry
+
+ # Copy only the pyproject.toml and poetry.lock files to install dependencies first
+ COPY pyproject.toml poetry.lock ./
+
+ # Install dependencies using Poetry
+ RUN poetry config virtualenvs.create false && poetry install --only=main
+
+ # Copy the rest of the application code to the working directory
+ COPY . .
+
+ # Expose the port FastAPI will run on
+ EXPOSE 7860
+
+ # Command to run the FastAPI application
+ CMD ["poetry", "run", "uvicorn", "run:app", "--host", "0.0.0.0", "--port", "7860", "--reload"]
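As a quick sanity check (not part of this commit), the containerized service can be exercised once the image is built and the container started with port 7860 published. This is a minimal sketch: it assumes the /most_similar endpoint defined in run.py below and uses the `requests` package, which is not among the project dependencies.

```python
# Minimal smoke test for the running container (illustrative sketch only).
# Assumes: image built from this Dockerfile, container started with -p 7860:7860,
# and the /most_similar endpoint from run.py is live. `requests` is an assumed
# extra dependency used only for this check.
import requests

payload = {"query": "a cat astronaut in space", "n": 3}
resp = requests.post("http://localhost:7860/most_similar", json=payload, timeout=60)
resp.raise_for_status()

for item in resp.json()["similar_prompts"]:
    print(f"{item['score']:.3f}  {item['prompt']}")
```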
__init__.py ADDED
File without changes
poe/common-tasks.toml ADDED
@@ -0,0 +1,52 @@
+ # This file defines common tasks that most Python projects can benefit from
+
+ [tool.poe.tasks.format-isort]
+ help = "Format code with isort"
+ cmd = "isort ."
+
+ [tool.poe.tasks.format-black]
+ help = "Format code with black"
+ cmd = "black ."
+
+ [tool.poe.tasks.format]
+ help = "Run code formatting tools"
+ sequence = ["format-isort", "format-black"]
+
+ [tool.poe.tasks.style-black]
+ help = "Validate black code style"
+ cmd = "black . --check --diff"
+
+ [tool.poe.tasks.style-isort]
+ help = "Validate isort code style"
+ cmd = "isort . --check --diff"
+
+ [tool.poe.tasks.style]
+ help = "Validate code style"
+ sequence = ["style-isort", "style-black"]
+
+ [tool.poe.tasks.types]
+ help = "Run the type checker"
+ cmd = "mypy . --ignore-missing-imports --check-untyped-defs --install-types --non-interactive"
+
+ [tool.poe.tasks.lint]
+ help = "Evaluate ruff rules"
+ cmd = "ruff check ."
+
+ [tool.poe.tasks.test]
+ help = "Run unit tests"
+ cmd = "pytest -p no:cacheprovider"
+
+ [tool.poe.tasks.clean]
+ help = "Remove automatically generated files"
+ cmd = """
+ rm -rf dist
+ .mypy_cache
+ .pytest_cache
+ .ruff_cache
+ ./**/__pycache__/
+ ./**/*.pyc
+ """
+
+ [tool.poe.tasks.check]
+ help = "Run all checks on the code base"
+ sequence = ["style", "types", "lint", "clean"]
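For orientation, these tasks are run through Poe the Poet from the Poetry environment (poethepoet is a dev dependency in pyproject.toml below). A hedged sketch of driving them from a helper script or CI step:

```python
# Sketch of invoking the Poe tasks above from Python, e.g. in a CI helper.
# Assumes Poetry and poethepoet are installed; task names match the definitions
# in poe/common-tasks.toml.
import subprocess

for task in ("format", "check"):
    # Equivalent to running `poetry run poe <task>` in a shell.
    subprocess.run(["poetry", "run", "poe", task], check=True)
```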
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [tool.poetry]
+ name = "smart-cat-assignment-backend"
+ version = "0.0.1"
+ description = "SmartCat Assignment"
+ authors = ["Lazar Radojevic <[email protected]>"]
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ mypy = "^1.8.0"
+ ruff = "^0.3.2"
+ datasets = "^2.20.0"
+ sentence-transformers = "^3.0.1"
+ numpy = "1.26.4"
+ fastapi = "^0.111.1"
+ uvicorn = "^0.30.3"
+
+ [tool.poetry.group.dev.dependencies]
+ black = "^24.1.1"
+ poethepoet = "^0.24.4"
+ isort = "^5.13.2"
+
+ [tool.isort]
+ profile = "black"
+
+ [tool.poe]
+ include = "./poe/common-tasks.toml"
run.py ADDED
@@ -0,0 +1,56 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import List
+ from src.search_engine import PromptSearchEngine
+ from src.prompt_loader import PromptLoader
+
+ # Constants
+ SEED = 42
+ DATA_SIZE = 100
+
+ # Initialize the prompt loader and search engine
+ prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
+ engine = PromptSearchEngine(prompts)
+
+ # Initialize FastAPI
+ app = FastAPI()
+
+
+ # Request and response models
+ class QueryRequest(BaseModel):
+     query: str
+     n: int = 5
+
+
+ class SimilarPrompt(BaseModel):
+     score: float
+     prompt: str
+
+
+ class QueryResponse(BaseModel):
+     similar_prompts: List[SimilarPrompt]
+
+
+ # API endpoint
+ @app.post("/most_similar", response_model=QueryResponse)
+ async def get_most_similar(query_request: QueryRequest):
+     try:
+         similar_prompts = engine.most_similar(
+             query=query_request.query, n=query_request.n
+         )
+         response = QueryResponse(
+             similar_prompts=[
+                 SimilarPrompt(score=score, prompt=prompt)
+                 for score, prompt in similar_prompts
+             ]
+         )
+         return response
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ # Run the server with: uvicorn run:app --reload
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run("run:app", host="0.0.0.0", port=8000, reload=True)
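The request/response contract above can also be exercised in-process with FastAPI's TestClient. This is a sketch only: it assumes the optional httpx dependency is installed, and importing run loads the dataset and the sentence-transformers model, so it is slow on first use.

```python
# In-process check of the /most_similar contract (illustrative sketch).
# Assumes httpx is installed (required by fastapi.testclient.TestClient).
from fastapi.testclient import TestClient

from run import app

client = TestClient(app)

response = client.post("/most_similar", json={"query": "a castle at sunset", "n": 2})
assert response.status_code == 200

body = response.json()
assert len(body["similar_prompts"]) == 2
for item in body["similar_prompts"]:
    print(item["score"], item["prompt"])
```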
src/__init__.py ADDED
File without changes
src/prompt_loader.py ADDED
@@ -0,0 +1,25 @@
+ from typing import Optional, List
+ from datasets import load_dataset
+ import random
+
+
+ class PromptLoader:
+     def __init__(self, seed: int = 42) -> None:
+         self.randomizer = random.Random(seed)
+         self.data: Optional[List[str]] = None
+
+     def _load_data(self) -> None:
+         self.data = load_dataset("daspartho/stable-diffusion-prompts")["train"][
+             "prompt"
+         ]
+
+     def load_data(self, size: Optional[int] = None) -> List[str]:
+         if not self.data:
+             self._load_data()
+
+         if size:
+             if size > len(self.data):
+                 raise ValueError("Not enough samples available!")
+             return self.randomizer.sample(self.data, size)
+         else:
+             return self.data
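A small usage sketch for PromptLoader, assuming the daspartho/stable-diffusion-prompts dataset can be downloaded on first use. Because each loader seeds its own random.Random in __init__, two loaders built with the same seed draw the same first sample (given the same dataset snapshot).

```python
# Usage sketch for PromptLoader (assumes the Hugging Face dataset
# "daspartho/stable-diffusion-prompts" is downloadable on first use).
from src.prompt_loader import PromptLoader

loader_a = PromptLoader(seed=42)
loader_b = PromptLoader(seed=42)

# Two loaders constructed with the same seed draw the same first sample.
assert loader_a.load_data(size=5) == loader_b.load_data(size=5)
print(loader_a.load_data()[0])
```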
src/search_engine.py ADDED
@@ -0,0 +1,50 @@
+ from typing import List, Sequence, Tuple
+
+ from sentence_transformers import SentenceTransformer
+
+ from src.vectorizer import Vectorizer
+ from src.similarity_scorer import SimilarityScorer
+
+
+ class PromptSearchEngine:
+     def __init__(self, prompts: Sequence[str]) -> None:
+         """Initialize the search engine by vectorizing the prompt corpus.
+         The vectorized prompt corpus is used to find the top n prompts
+         most similar to the user's input prompt.
+         Args:
+             prompts: The sequence of raw prompts from the dataset.
+         """
+         self.vectorizer = Vectorizer(SentenceTransformer("all-MiniLM-L6-v2"))
+         self.scorer = SimilarityScorer()
+
+         self.prompts = prompts
+         self.embeddings = self.vectorizer.transform(prompts)
+
+     def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
+         """Return the top n most similar prompts from the corpus.
+
+         The input query prompt is vectorized with the chosen Vectorizer;
+         the cosine_similarity function is then used to retrieve the top n
+         most similar prompts from the corpus.
+
+         Args:
+             query: The raw query prompt input from the user.
+             n: The number of similar prompts returned from the corpus.
+         Returns:
+             The list of the top n most similar prompts from the corpus
+             along with their similarity scores. Note that the returned
+             prompts are verbatim.
+         """
+         query_embedding = self.vectorizer.transform(query)
+
+         similarities = self.scorer.cosine_similarity(query_embedding, self.embeddings)
+
+         # Get the top n indices with the highest similarity scores
+         top_n_indices = similarities.argsort()[-n:][::-1]
+
+         # Retrieve the top n most similar prompts along with their similarity scores
+         top_n_similar_prompts = [
+             (similarities[i], self.prompts[i]) for i in top_n_indices
+         ]
+
+         return top_n_similar_prompts
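A toy illustration of the top-n selection step in most_similar: argsort sorts ascending, so the last n indices hold the highest scores and are reversed into descending order. The scores and prompts here are made up so the snippet runs without downloading a model.

```python
import numpy as np

similarities = np.array([0.12, 0.87, 0.45, 0.91, 0.33])
prompts = ["p0", "p1", "p2", "p3", "p4"]
n = 2

# Same indexing as in most_similar above.
top_n_indices = similarities.argsort()[-n:][::-1]
print(top_n_indices)  # [3 1]
print([(float(similarities[i]), prompts[i]) for i in top_n_indices])
# [(0.91, 'p3'), (0.87, 'p1')]
```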
src/similarity_scorer.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+
+
+ class SimilarityScorer:
+
+     def cosine_similarity(
+         self,
+         query_vector: np.ndarray,
+         corpus_vectors: np.ndarray,
+     ) -> np.ndarray:
+         """Calculate cosine similarity between prompt vectors.
+         Args:
+             query_vector: Vectorized prompt query of shape (D,).
+             corpus_vectors: Vectorized prompt corpus of shape (N, D).
+         Returns: The vector of shape (N,) with values in range [-1, 1] where 1
+             is max similarity, i.e. the two vectors are identical.
+         """
+
+         # Normalize the query vector
+         query_norm = np.linalg.norm(query_vector)
+         if query_norm == 0:
+             raise ValueError("The query vector cannot be zero.")
+         query_vector = query_vector / query_norm
+
+         # Normalize the corpus vectors
+         corpus_norms = np.linalg.norm(corpus_vectors, axis=1)
+         if np.any(corpus_norms == 0):
+             raise ValueError("The corpus contains zero vectors.")
+         normalized_corpus = corpus_vectors / corpus_norms[:, np.newaxis]
+
+         # Calculate cosine similarity
+         similarities = np.dot(normalized_corpus, query_vector.T)
+
+         return similarities
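A worked check of the cosine similarity on hand-picked 2-D vectors, confirming the [-1, 1] range described in the docstring: a parallel vector scores 1, an orthogonal one 0, an opposite one -1.

```python
import numpy as np

from src.similarity_scorer import SimilarityScorer

scorer = SimilarityScorer()
query = np.array([1.0, 0.0])
corpus = np.array([
    [2.0, 0.0],   # same direction -> similarity  1.0
    [0.0, 3.0],   # orthogonal     -> similarity  0.0
    [-1.0, 0.0],  # opposite       -> similarity -1.0
])

print(scorer.cosine_similarity(query, corpus))  # [ 1.  0. -1.]
```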
src/vectorizer.py ADDED
@@ -0,0 +1,23 @@
+ from typing import Sequence
+
+ import numpy as np
+
+
+ class Vectorizer:
+     def __init__(self, model) -> None:
+         """Initialize the vectorizer with a pre-trained embedding model.
+         Args:
+             model: The pre-trained embedding model to use for transforming
+                 prompts.
+         """
+         self.model = model
+
+     def transform(self, prompts: Sequence[str]) -> np.ndarray:
+         """Transform texts into numerical vectors using the
+         specified model.
+         Args:
+             prompts: The sequence of raw corpus prompts.
+         Returns:
+             Vectorized prompts as a numpy array.
+         """
+         return self.model.encode(prompts)
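Vectorizer only relies on the model exposing an encode(prompts) method that returns a numpy array, so any such object can be plugged in. The sketch below uses a tiny made-up stand-in model so it runs without downloading SentenceTransformer weights; the real code passes all-MiniLM-L6-v2, as in search_engine.py.

```python
import numpy as np

from src.vectorizer import Vectorizer


class ToyModel:
    """Hypothetical stand-in exposing the same encode() interface."""

    def encode(self, prompts):
        # One 4-dimensional pseudo-embedding per prompt (illustrative only).
        return np.array([[len(p), p.count(" "), hash(p) % 7, 1.0] for p in prompts])


vectorizer = Vectorizer(ToyModel())
embeddings = vectorizer.transform(["a red fox", "blue ocean waves"])
print(embeddings.shape)  # (2, 4)
```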