Commit
·
885e17b
1
Parent(s):
90a5303
Add: python package
Browse files- .cog/openapi_schema.json +1 -0
- .dockerignore +17 -0
- .github/workflows/push.yaml +54 -0
- __pycache__/predict.cpython-311.pyc +0 -0
- cog.yaml +20 -0
- grammarcorrector/.DS_Store +0 -0
- grammarcorrector/README.md +23 -0
- grammarcorrector/build/lib/grammarcorrector/__init__.py +3 -0
- grammarcorrector/build/lib/grammarcorrector/corrector.py +12 -0
- grammarcorrector/build/lib/tests/__init__.py +0 -0
- grammarcorrector/build/lib/tests/test_corrector.py +14 -0
- grammarcorrector/dist/grammarcorrector-0.1.0-py3-none-any.whl +0 -0
- grammarcorrector/dist/grammarcorrector-0.1.0.tar.gz +0 -0
- grammarcorrector/grammarcorrector.egg-info/PKG-INFO +46 -0
- grammarcorrector/grammarcorrector.egg-info/SOURCES.txt +11 -0
- grammarcorrector/grammarcorrector.egg-info/dependency_links.txt +1 -0
- grammarcorrector/grammarcorrector.egg-info/requires.txt +1 -0
- grammarcorrector/grammarcorrector.egg-info/top_level.txt +2 -0
- grammarcorrector/grammarcorrector/__init__.py +3 -0
- grammarcorrector/grammarcorrector/__pycache__/__init__.cpython-311.pyc +0 -0
- grammarcorrector/grammarcorrector/__pycache__/corrector.cpython-311.pyc +0 -0
- grammarcorrector/grammarcorrector/corrector.py +12 -0
- grammarcorrector/requirements.txt +1 -0
- grammarcorrector/setup.py +22 -0
- grammarcorrector/tests/__init__.py +0 -0
- grammarcorrector/tests/__pycache__/__init__.cpython-311.pyc +0 -0
- grammarcorrector/tests/__pycache__/test_corrector.cpython-311.pyc +0 -0
- grammarcorrector/tests/test_corrector.py +14 -0
- predict.py +16 -0
.cog/openapi_schema.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"components":{"schemas":{"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail","type":"array"}},"title":"HTTPValidationError","type":"object"},"Input":{"properties":{"text":{"description":"Text to correct","title":"Text","type":"string","x-order":0}},"required":["text"],"title":"Input","type":"object"},"Output":{"title":"Output","type":"string"},"PredictionRequest":{"properties":{"created_at":{"format":"date-time","title":"Created At","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"output_file_prefix":{"title":"Output File Prefix","type":"string"},"webhook":{"format":"uri","maxLength":65536,"minLength":1,"title":"Webhook","type":"string"},"webhook_events_filter":{"default":["start","output","logs","completed"],"items":{"$ref":"#/components/schemas/WebhookEvent"},"type":"array"}},"title":"PredictionRequest","type":"object"},"PredictionResponse":{"properties":{"completed_at":{"format":"date-time","title":"Completed At","type":"string"},"created_at":{"format":"date-time","title":"Created At","type":"string"},"error":{"title":"Error","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"logs":{"default":"","title":"Logs","type":"string"},"metrics":{"title":"Metrics","type":"object"},"output":{"$ref":"#/components/schemas/Output"},"started_at":{"format":"date-time","title":"Started At","type":"string"},"status":{"$ref":"#/components/schemas/Status"},"version":{"title":"Version","type":"string"}},"title":"PredictionResponse","type":"object"},"Status":{"description":"An enumeration.","enum":["starting","processing","succeeded","canceled","failed"],"title":"Status","type":"string"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location","type":"array"},"msg":{"title":"Message","type":"string"},"type":{"title":"Error Type","type":"string"}},"required":["loc","msg","type"],"title":"ValidationError","type":"object"},"WebhookEvent":{"description":"An enumeration.","enum":["start","output","logs","completed"],"title":"WebhookEvent","type":"string"}}},"info":{"title":"Cog","version":"0.1.0"},"openapi":"3.0.2","paths":{"/":{"get":{"operationId":"root__get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Root Get"}}},"description":"Successful Response"}},"summary":"Root"}},"/health-check":{"get":{"operationId":"healthcheck_health_check_get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Healthcheck Health Check Get"}}},"description":"Successful Response"}},"summary":"Healthcheck"}},"/predictions":{"post":{"description":"Run a single prediction on the model","operationId":"predict_predictions_post","parameters":[{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict"}},"/predictions/{prediction_id}":{"put":{"description":"Run a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}},{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction Request"}}},"required":true},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict Idempotent"}},"/predictions/{prediction_id}/cancel":{"post":{"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}}],"responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Cancel"}},"/shutdown":{"post":{"operationId":"start_shutdown_shutdown_post","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"summary":"Start Shutdown"}}}}
|
.dockerignore
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The .dockerignore file excludes files from the container build process.
|
2 |
+
#
|
3 |
+
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
|
4 |
+
|
5 |
+
# Exclude Git files
|
6 |
+
**/.git
|
7 |
+
**/.github
|
8 |
+
**/.gitignore
|
9 |
+
|
10 |
+
# Exclude Python cache files
|
11 |
+
__pycache__
|
12 |
+
.mypy_cache
|
13 |
+
.pytest_cache
|
14 |
+
.ruff_cache
|
15 |
+
|
16 |
+
# Exclude Python virtual environment
|
17 |
+
/venv
|
.github/workflows/push.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Push to Replicate
|
2 |
+
|
3 |
+
on:
|
4 |
+
# Workflow dispatch allows you to manually trigger the workflow from GitHub.com
|
5 |
+
# Go to your repo, click "Actions", click "Push to Replicate", click "Run workflow"
|
6 |
+
workflow_dispatch:
|
7 |
+
inputs:
|
8 |
+
model_name:
|
9 |
+
description: 'Enter the model name, like "alice/bunny-detector". If unset, this will default to the value of `image` in cog.yaml.'
|
10 |
+
# # Uncomment these lines to trigger the workflow on every push to the main branch
|
11 |
+
# push:
|
12 |
+
# branches:
|
13 |
+
# - main
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
push_to_replicate:
|
17 |
+
name: Push to Replicate
|
18 |
+
|
19 |
+
# If your model is large, the default GitHub Actions runner may not
|
20 |
+
# have enough disk space. If you need more space you can set up a
|
21 |
+
# bigger runner on GitHub.
|
22 |
+
runs-on: ubuntu-latest
|
23 |
+
|
24 |
+
steps:
|
25 |
+
# This action cleans up disk space to make more room for your
|
26 |
+
# model code, weights, etc.
|
27 |
+
- name: Free disk space
|
28 |
+
uses: jlumbroso/[email protected]
|
29 |
+
with:
|
30 |
+
tool-cache: false
|
31 |
+
docker-images: false
|
32 |
+
|
33 |
+
- name: Checkout
|
34 |
+
uses: actions/checkout@v4
|
35 |
+
|
36 |
+
# This action installs Docker buildx and Cog (and optionally CUDA)
|
37 |
+
- name: Setup Cog
|
38 |
+
uses: replicate/setup-cog@v2
|
39 |
+
with:
|
40 |
+
# If you set REPLICATE_API_TOKEN in your GitHub repository secrets,
|
41 |
+
# the action will authenticate with Replicate automatically so you
|
42 |
+
# can push your model
|
43 |
+
token: ${{ secrets.REPLICATE_API_TOKEN }}
|
44 |
+
|
45 |
+
# If you trigger the workflow manually, you can specify the model name.
|
46 |
+
# If you leave it blank (or if the workflow is triggered by a push), the
|
47 |
+
# model name will be derived from the `image` value in cog.yaml.
|
48 |
+
- name: Push to Replicate
|
49 |
+
run: |
|
50 |
+
if [ -n "${{ inputs.model_name }}" ]; then
|
51 |
+
cog push r8.im/${{ inputs.model_name }}
|
52 |
+
else
|
53 |
+
cog push
|
54 |
+
fi
|
__pycache__/predict.cpython-311.pyc
ADDED
Binary file (1.74 kB). View file
|
|
cog.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://cog.run/yaml
|
3 |
+
|
4 |
+
build:
|
5 |
+
# Set to true if your model requires a GPU
|
6 |
+
gpu: false
|
7 |
+
|
8 |
+
# Python version in the form '3.11' or '3.11.4'
|
9 |
+
python_version: "3.11"
|
10 |
+
|
11 |
+
# A list of packages in the format <package-name>==<version>
|
12 |
+
python_packages:
|
13 |
+
- "transformers==4.48.2"
|
14 |
+
- "datasets==3.2.0"
|
15 |
+
- "torch==2.5.1"
|
16 |
+
- "accelerate==1.3.0"
|
17 |
+
- "sentencepiece==0.2.0"
|
18 |
+
|
19 |
+
# predict.py defines how predictions are run on your model
|
20 |
+
predict: "predict.py:Predictor"
|
grammarcorrector/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
grammarcorrector/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GrammarCorrector
|
2 |
+
|
3 |
+
GrammarCorrector is a Python package that uses a remote model hosted on Replicate to correct grammatical errors in text.
|
4 |
+
|
5 |
+
## Installation
|
6 |
+
|
7 |
+
```bash
|
8 |
+
pip install grammarcorrector
|
9 |
+
```
|
10 |
+
|
11 |
+
## Usage
|
12 |
+
|
13 |
+
```bash
|
14 |
+
export REPLICATE_API_TOKEN='your_replicate_api_token'
|
15 |
+
```
|
16 |
+
|
17 |
+
```python
|
18 |
+
from grammarcorrector import GrammarCorrector
|
19 |
+
|
20 |
+
corrector = GrammarCorrector()
|
21 |
+
corrected_text = corrector.correct("This are bad grammar sentence.")
|
22 |
+
print(corrected_text) # Output: "This is a bad grammar sentence."
|
23 |
+
```
|
grammarcorrector/build/lib/grammarcorrector/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .corrector import GrammarCorrector
|
2 |
+
|
3 |
+
__all__ = ["GrammarCorrector"]
|
grammarcorrector/build/lib/grammarcorrector/corrector.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import replicate
|
2 |
+
|
3 |
+
class GrammarCorrector:
|
4 |
+
def __init__(self):
|
5 |
+
self.model = "aaurelions/t5-grammar-corrector:4502e6d4714acf2152cd417a63792f3a59e983d440b9aa4c8df9b5d84a72931f"
|
6 |
+
|
7 |
+
def correct(self, text):
|
8 |
+
output = replicate.run(
|
9 |
+
self.model,
|
10 |
+
input={"text": text}
|
11 |
+
)
|
12 |
+
return output
|
grammarcorrector/build/lib/tests/__init__.py
ADDED
File without changes
|
grammarcorrector/build/lib/tests/test_corrector.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from grammarcorrector import GrammarCorrector
|
3 |
+
|
4 |
+
class TestGrammarCorrector(unittest.TestCase):
|
5 |
+
def setUp(self):
|
6 |
+
self.corrector = GrammarCorrector()
|
7 |
+
|
8 |
+
def test_correction(self):
|
9 |
+
input_text = "This are bad grammar sentence."
|
10 |
+
expected_output = "This is a bad grammar sentence."
|
11 |
+
self.assertEqual(self.corrector.correct(input_text), expected_output)
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
unittest.main()
|
grammarcorrector/dist/grammarcorrector-0.1.0-py3-none-any.whl
ADDED
Binary file (2.7 kB). View file
|
|
grammarcorrector/dist/grammarcorrector-0.1.0.tar.gz
ADDED
Binary file (2.06 kB). View file
|
|
grammarcorrector/grammarcorrector.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.2
|
2 |
+
Name: grammarcorrector
|
3 |
+
Version: 0.1.0
|
4 |
+
Summary: A Python package for grammar correction using a remote model on Replicate.
|
5 |
+
Home-page: https://github.com/aaurelions/grammarcorrector
|
6 |
+
Author: D.
|
7 |
+
Author-email: [email protected]
|
8 |
+
Classifier: Programming Language :: Python :: 3
|
9 |
+
Classifier: License :: OSI Approved :: MIT License
|
10 |
+
Classifier: Operating System :: OS Independent
|
11 |
+
Requires-Python: >=3.7
|
12 |
+
Description-Content-Type: text/markdown
|
13 |
+
Requires-Dist: replicate>=1.0.4
|
14 |
+
Dynamic: author
|
15 |
+
Dynamic: author-email
|
16 |
+
Dynamic: classifier
|
17 |
+
Dynamic: description
|
18 |
+
Dynamic: description-content-type
|
19 |
+
Dynamic: home-page
|
20 |
+
Dynamic: requires-dist
|
21 |
+
Dynamic: requires-python
|
22 |
+
Dynamic: summary
|
23 |
+
|
24 |
+
# GrammarCorrector
|
25 |
+
|
26 |
+
GrammarCorrector is a Python package that uses a remote model hosted on Replicate to correct grammatical errors in text.
|
27 |
+
|
28 |
+
## Installation
|
29 |
+
|
30 |
+
```bash
|
31 |
+
pip install grammarcorrector
|
32 |
+
```
|
33 |
+
|
34 |
+
## Usage
|
35 |
+
|
36 |
+
```bash
|
37 |
+
export REPLICATE_API_TOKEN='your_replicate_api_token'
|
38 |
+
```
|
39 |
+
|
40 |
+
```python
|
41 |
+
from grammarcorrector import GrammarCorrector
|
42 |
+
|
43 |
+
corrector = GrammarCorrector()
|
44 |
+
corrected_text = corrector.correct("This are bad grammar sentence.")
|
45 |
+
print(corrected_text) # Output: "This is a bad grammar sentence."
|
46 |
+
```
|
grammarcorrector/grammarcorrector.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
grammarcorrector/__init__.py
|
4 |
+
grammarcorrector/corrector.py
|
5 |
+
grammarcorrector.egg-info/PKG-INFO
|
6 |
+
grammarcorrector.egg-info/SOURCES.txt
|
7 |
+
grammarcorrector.egg-info/dependency_links.txt
|
8 |
+
grammarcorrector.egg-info/requires.txt
|
9 |
+
grammarcorrector.egg-info/top_level.txt
|
10 |
+
tests/__init__.py
|
11 |
+
tests/test_corrector.py
|
grammarcorrector/grammarcorrector.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
grammarcorrector/grammarcorrector.egg-info/requires.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
replicate>=1.0.4
|
grammarcorrector/grammarcorrector.egg-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
grammarcorrector
|
2 |
+
tests
|
grammarcorrector/grammarcorrector/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .corrector import GrammarCorrector
|
2 |
+
|
3 |
+
__all__ = ["GrammarCorrector"]
|
grammarcorrector/grammarcorrector/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (288 Bytes). View file
|
|
grammarcorrector/grammarcorrector/__pycache__/corrector.cpython-311.pyc
ADDED
Binary file (1.01 kB). View file
|
|
grammarcorrector/grammarcorrector/corrector.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import replicate
|
2 |
+
|
3 |
+
class GrammarCorrector:
|
4 |
+
def __init__(self):
|
5 |
+
self.model = "aaurelions/t5-grammar-corrector:4502e6d4714acf2152cd417a63792f3a59e983d440b9aa4c8df9b5d84a72931f"
|
6 |
+
|
7 |
+
def correct(self, text):
|
8 |
+
output = replicate.run(
|
9 |
+
self.model,
|
10 |
+
input={"text": text}
|
11 |
+
)
|
12 |
+
return output
|
grammarcorrector/requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
replicate>=1.0.4
|
grammarcorrector/setup.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name="grammarcorrector",
|
5 |
+
version="0.1.0",
|
6 |
+
packages=find_packages(),
|
7 |
+
install_requires=[
|
8 |
+
"replicate>=1.0.4",
|
9 |
+
],
|
10 |
+
author="D.",
|
11 |
+
author_email="[email protected]",
|
12 |
+
description="A Python package for grammar correction using a remote model on Replicate.",
|
13 |
+
long_description=open("README.md").read(),
|
14 |
+
long_description_content_type="text/markdown",
|
15 |
+
url="https://github.com/aaurelions/grammarcorrector",
|
16 |
+
classifiers=[
|
17 |
+
"Programming Language :: Python :: 3",
|
18 |
+
"License :: OSI Approved :: MIT License",
|
19 |
+
"Operating System :: OS Independent",
|
20 |
+
],
|
21 |
+
python_requires='>=3.7',
|
22 |
+
)
|
grammarcorrector/tests/__init__.py
ADDED
File without changes
|
grammarcorrector/tests/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (187 Bytes). View file
|
|
grammarcorrector/tests/__pycache__/test_corrector.cpython-311.pyc
ADDED
Binary file (1.3 kB). View file
|
|
grammarcorrector/tests/test_corrector.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
from grammarcorrector import GrammarCorrector
|
3 |
+
|
4 |
+
class TestGrammarCorrector(unittest.TestCase):
|
5 |
+
def setUp(self):
|
6 |
+
self.corrector = GrammarCorrector()
|
7 |
+
|
8 |
+
def test_correction(self):
|
9 |
+
input_text = "This are bad grammar sentence."
|
10 |
+
expected_output = "This is a bad grammar sentence."
|
11 |
+
self.assertEqual(self.corrector.correct(input_text), expected_output)
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
unittest.main()
|
predict.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
2 |
+
from cog import BasePredictor, Input
|
3 |
+
|
4 |
+
class Predictor(BasePredictor):
|
5 |
+
def setup(self):
|
6 |
+
"""Load the model and tokenizer into memory to make running multiple predictions efficient"""
|
7 |
+
self.model = T5ForConditionalGeneration.from_pretrained("aaurelions/t5-grammar-corrector")
|
8 |
+
self.tokenizer = T5Tokenizer.from_pretrained("aaurelions/t5-grammar-corrector")
|
9 |
+
|
10 |
+
def predict(self, text: str = Input(description="Text to correct")) -> str:
|
11 |
+
"""Run a single prediction on the model"""
|
12 |
+
input_text = "fix grammar: " + text
|
13 |
+
input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
|
14 |
+
output_ids = self.model.generate(input_ids, max_length=128)
|
15 |
+
corrected_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
16 |
+
return corrected_text
|