Spaces:
Sleeping
Sleeping
ezequiellopez
commited on
Commit
·
86e971c
1
Parent(s):
8ebe686
integration_tests fixes
Browse files- app/_models/README.md +44 -0
- app/_models/__init__.py +0 -0
- app/_models/fake.py +132 -0
- app/_models/fake_test.py +45 -0
- app/_models/request.py +118 -0
- app/_models/requirements.txt +4 -0
- app/_models/response.py +26 -0
- app/_models/survey.py +142 -0
- app/app.py +18 -12
- app/modules/classify.py +4 -2
app/_models/README.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pydantic models for the PRC API schema
|
2 |
+
|
3 |
+
You can use these models in your Python code, both to generate valid data, and to parse incoming data.
|
4 |
+
|
5 |
+
Using the models ensures that your data has been at least somewhat validated. If the schema changes and your code needs an update, you're more likely to be able to tell right away.
|
6 |
+
|
7 |
+
## Parsing a request
|
8 |
+
|
9 |
+
### With FastAPI
|
10 |
+
|
11 |
+
If you're using fastapi, you can use the models right in your server:
|
12 |
+
|
13 |
+
```python
|
14 |
+
from models.request import RankingRequest
|
15 |
+
from models.response import RankingResponse
|
16 |
+
|
17 |
+
@app.post("/rank")
|
18 |
+
def rank(ranking_request: RankingRequest) -> RankingResponse:
|
19 |
+
...
|
20 |
+
# You can return a RankingResponse here, or a dict with the correct keys and
|
21 |
+
# pydantic will figure it out.
|
22 |
+
```
|
23 |
+
|
24 |
+
If you specify `RankingResponse` as your reeturn type, you will get validation of your response for free.
|
25 |
+
|
26 |
+
For a complete example, check out `../fastapi_nltk/`
|
27 |
+
|
28 |
+
### Otherwise
|
29 |
+
|
30 |
+
If you'd like to parse a request directly, here is how:
|
31 |
+
|
32 |
+
```python
|
33 |
+
from models.request import RankingRequest
|
34 |
+
|
35 |
+
loaded_request = RankingRequest.model_validate_json(json_data)
|
36 |
+
```
|
37 |
+
|
38 |
+
## Generating fake data
|
39 |
+
|
40 |
+
There is a fake data generator in `fake.py`. If you run it directly it'll print some. You can also import it and run `fake_request()` or `fake_response()`. Take a look at the test for a usage example.
|
41 |
+
|
42 |
+
## More
|
43 |
+
|
44 |
+
[The pydantic docs](https://docs.pydantic.dev/latest/)
|
app/_models/__init__.py
ADDED
File without changes
|
app/_models/fake.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import inspect
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
from random import randint
|
7 |
+
from uuid import uuid4
|
8 |
+
|
9 |
+
parentdir = os.path.dirname( # make it possible to import from ../ in a reliable way
|
10 |
+
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
11 |
+
)
|
12 |
+
sys.path.insert(0, parentdir)
|
13 |
+
|
14 |
+
from faker import Faker
|
15 |
+
|
16 |
+
fake = Faker(locale="la") # remove locale to get rid of the fake latin
|
17 |
+
|
18 |
+
from models.request import ContentItem, RankingRequest, Session
|
19 |
+
from models.response import RankingResponse
|
20 |
+
from models.survey import SurveyResponse
|
21 |
+
|
22 |
+
|
23 |
+
def fake_request(n_posts=1, n_comments=0, platform="reddit"):
|
24 |
+
posts = [fake_item(platform=platform, type="post") for _ in range(n_posts)]
|
25 |
+
comments = []
|
26 |
+
for post in posts:
|
27 |
+
last_comment_id = None
|
28 |
+
for _ in range(n_comments):
|
29 |
+
comments.append(
|
30 |
+
fake_item(
|
31 |
+
platform=platform,
|
32 |
+
type="comment",
|
33 |
+
post_id=post.id,
|
34 |
+
parent_id=last_comment_id,
|
35 |
+
)
|
36 |
+
)
|
37 |
+
last_comment_id = comments[-1].id
|
38 |
+
|
39 |
+
return RankingRequest(
|
40 |
+
session=Session(
|
41 |
+
user_id=str(uuid4()),
|
42 |
+
user_name_hash=hashlib.sha256(fake.name().encode()).hexdigest(),
|
43 |
+
cohort="AB",
|
44 |
+
platform=platform,
|
45 |
+
current_time=time.time(),
|
46 |
+
),
|
47 |
+
survey=SurveyResponse(
|
48 |
+
party_id="democrat",
|
49 |
+
support="strong",
|
50 |
+
party_lean="democrat",
|
51 |
+
sex="female",
|
52 |
+
age=3,
|
53 |
+
education=4,
|
54 |
+
ideology=5,
|
55 |
+
income=6,
|
56 |
+
ethnicity="native_american",
|
57 |
+
socmed_use=7,
|
58 |
+
browser_perc=0.8,
|
59 |
+
mobile_perc=0.2,
|
60 |
+
),
|
61 |
+
items=posts + comments,
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
def fake_item(platform="reddit", type="post", post_id=None, parent_id=None):
|
66 |
+
if platform == "reddit":
|
67 |
+
engagements = {
|
68 |
+
"upvote": randint(0, 50),
|
69 |
+
"downvote": randint(0, 50),
|
70 |
+
"comment": randint(0, 50),
|
71 |
+
"award": randint(0, 50),
|
72 |
+
}
|
73 |
+
elif platform == "twitter":
|
74 |
+
engagements = {
|
75 |
+
"like": randint(0, 50),
|
76 |
+
"retweet": randint(0, 50),
|
77 |
+
"comment": randint(0, 50),
|
78 |
+
"share": randint(0, 50),
|
79 |
+
}
|
80 |
+
elif platform == "facebook":
|
81 |
+
engagements = {
|
82 |
+
"like": randint(0, 50),
|
83 |
+
"love": randint(0, 50),
|
84 |
+
"care": randint(0, 50),
|
85 |
+
"haha": randint(0, 50),
|
86 |
+
"wow": randint(0, 50),
|
87 |
+
"sad": randint(0, 50),
|
88 |
+
"angry": randint(0, 50),
|
89 |
+
"comment": randint(0, 50),
|
90 |
+
"share": randint(0, 50),
|
91 |
+
}
|
92 |
+
else:
|
93 |
+
raise ValueError(f"Unknown platform: {platform}")
|
94 |
+
|
95 |
+
return ContentItem(
|
96 |
+
id=str(uuid4()),
|
97 |
+
text=fake.text(),
|
98 |
+
post_id=post_id,
|
99 |
+
parent_id=parent_id,
|
100 |
+
author_name_hash=hashlib.sha256(fake.name().encode()).hexdigest(),
|
101 |
+
type=type,
|
102 |
+
created_at=time.time(),
|
103 |
+
embedded_urls=[fake.url() for _ in range(randint(0, 3))],
|
104 |
+
engagements=engagements,
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def fake_response(ids, n_new_items=1):
|
109 |
+
new_items = [fake_new_item() for _ in range(n_new_items)]
|
110 |
+
|
111 |
+
ids = list(ids) + [item["id"] for item in new_items]
|
112 |
+
|
113 |
+
return RankingResponse(ranked_ids=ids, new_items=new_items)
|
114 |
+
|
115 |
+
|
116 |
+
def fake_new_item():
|
117 |
+
return {
|
118 |
+
"id": str(uuid4()),
|
119 |
+
"url": fake.url(),
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
# if run from command line
|
124 |
+
if __name__ == "__main__":
|
125 |
+
request = fake_request(n_posts=1, n_comments=2)
|
126 |
+
print("Request:")
|
127 |
+
print(request.model_dump_json(indent=2))
|
128 |
+
|
129 |
+
# use ids from request
|
130 |
+
response = fake_response([item.id for item in request.items], 2)
|
131 |
+
print("\nResponse:")
|
132 |
+
print(response.model_dump_json(indent=2))
|
app/_models/fake_test.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
parentdir = os.path.dirname( # make it possible to import from ../ in a reliable way
|
6 |
+
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
7 |
+
)
|
8 |
+
sys.path.insert(0, parentdir)
|
9 |
+
|
10 |
+
from models import fake
|
11 |
+
from models.request import RankingRequest
|
12 |
+
|
13 |
+
|
14 |
+
def test_fake_request():
|
15 |
+
# this test's purpose is mostly to run the code to make sure it doesn't
|
16 |
+
# have any validation errors. pydantic will make sure it has the right fields.
|
17 |
+
request = fake.fake_request(n_posts=5)
|
18 |
+
assert len(request.items) == 5
|
19 |
+
|
20 |
+
# all ids are unique
|
21 |
+
assert len(set(item.id for item in request.items)) == 5
|
22 |
+
|
23 |
+
request = fake.fake_request(n_posts=5, n_comments=2, platform="twitter")
|
24 |
+
assert len(request.items) == 15
|
25 |
+
assert request.session.platform == "twitter"
|
26 |
+
|
27 |
+
|
28 |
+
def test_fake_response():
|
29 |
+
ids = [str(i) for i in range(5)]
|
30 |
+
|
31 |
+
response = fake.fake_response(ids, 2)
|
32 |
+
assert len(response.ranked_ids) == 7
|
33 |
+
|
34 |
+
# all ids are unique
|
35 |
+
assert len(set(id for id in response.ranked_ids)) == 7
|
36 |
+
|
37 |
+
|
38 |
+
def test_load_fake_data():
|
39 |
+
# This really just exercises pydantic, and is mostly an example
|
40 |
+
# of how to load json data
|
41 |
+
request = fake.fake_request(5)
|
42 |
+
json_data = request.model_dump_json()
|
43 |
+
|
44 |
+
loaded_request = RankingRequest.model_validate_json(json_data)
|
45 |
+
assert len(loaded_request.items) == 5
|
app/_models/request.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from typing import Literal, Optional, Union
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field, HttpUrl
|
5 |
+
from pydantic.types import NonNegativeInt
|
6 |
+
|
7 |
+
from .survey import SurveyResponse
|
8 |
+
|
9 |
+
|
10 |
+
class TwitterEngagements(BaseModel):
|
11 |
+
"""Engagement counts from Twitter"""
|
12 |
+
|
13 |
+
retweet: NonNegativeInt
|
14 |
+
like: NonNegativeInt
|
15 |
+
comment: NonNegativeInt
|
16 |
+
share: NonNegativeInt
|
17 |
+
|
18 |
+
|
19 |
+
class RedditEngagements(BaseModel):
|
20 |
+
"""Engagement counts from Reddit"""
|
21 |
+
|
22 |
+
upvote: NonNegativeInt
|
23 |
+
downvote: NonNegativeInt
|
24 |
+
comment: NonNegativeInt
|
25 |
+
award: NonNegativeInt
|
26 |
+
|
27 |
+
|
28 |
+
class FacebookEngagements(BaseModel):
|
29 |
+
"""Engagement counts from Facebook"""
|
30 |
+
|
31 |
+
like: NonNegativeInt
|
32 |
+
love: NonNegativeInt
|
33 |
+
care: NonNegativeInt
|
34 |
+
haha: NonNegativeInt
|
35 |
+
wow: NonNegativeInt
|
36 |
+
sad: NonNegativeInt
|
37 |
+
angry: NonNegativeInt
|
38 |
+
comment: NonNegativeInt
|
39 |
+
share: NonNegativeInt
|
40 |
+
|
41 |
+
|
42 |
+
class ContentItem(BaseModel):
|
43 |
+
"""A content item to be ranked"""
|
44 |
+
|
45 |
+
id: str = Field(
|
46 |
+
description="A unique ID describing a specific piece of content. We will do our best to make an ID for a given item persist between requests, but that property is not guaranteed."
|
47 |
+
)
|
48 |
+
|
49 |
+
post_id: Optional[str] = Field(
|
50 |
+
description="The ID of the post to which this comment belongs. Useful for linking comments to their post when comments are shown in a feed. Currently this UX only exists on Facebook.",
|
51 |
+
default=None,
|
52 |
+
)
|
53 |
+
|
54 |
+
parent_id: Optional[str] = Field(
|
55 |
+
description="For threaded comments, this identifies the comment to which this one is a reply. Blank for top-level comments.",
|
56 |
+
default=None,
|
57 |
+
)
|
58 |
+
|
59 |
+
title: Optional[str] = Field(
|
60 |
+
description="The post title, only available on reddit posts.", default=None
|
61 |
+
)
|
62 |
+
|
63 |
+
text: str = Field(
|
64 |
+
description="The text of the content item. Assume UTF-8, and that leading and trailing whitespace have been trimmed."
|
65 |
+
)
|
66 |
+
|
67 |
+
author_name_hash: str = Field(
|
68 |
+
description="A hash of the author's name (salted). Use this to determine which posts are by the same author. When the post is by the current user, this should match `session.user_name_hash`."
|
69 |
+
)
|
70 |
+
|
71 |
+
type: Literal["post", "comment"] = Field(
|
72 |
+
description="Whether the content item is a `post` or `comment`. On Twitter, tweets will be identified as `comment` when they are replies displayed on the page for a single tweet."
|
73 |
+
)
|
74 |
+
|
75 |
+
embedded_urls: Optional[list[HttpUrl]] = Field(
|
76 |
+
description="A list of URLs that are embedded in the content item. This could be links to images, videos, or other content. They may or may not also appear in the text of the item."
|
77 |
+
)
|
78 |
+
|
79 |
+
created_at: datetime = Field(
|
80 |
+
description="The time that the item was created in UTC, in `YYYY-MM-DD hh:mm:ss` format, at the highest resolution available (which may be as low as the hour)."
|
81 |
+
)
|
82 |
+
|
83 |
+
engagements: Union[TwitterEngagements, RedditEngagements, FacebookEngagements] = (
|
84 |
+
Field(description="Engagement counts for the content item.")
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
class Session(BaseModel):
|
89 |
+
"""Data that is scoped to the user's browsing session (generally a single page view)"""
|
90 |
+
|
91 |
+
user_id: str = Field(
|
92 |
+
description="A unique id for this study participant. Will remain fixed for the duration of the experiment."
|
93 |
+
)
|
94 |
+
user_name_hash: str = Field(
|
95 |
+
description="A (salted) hash of the user's username. We'll do our best to make it match the `item.author_name_hash` on posts authored by the current user."
|
96 |
+
)
|
97 |
+
cohort: str = Field(
|
98 |
+
description="The cohort to which the user has been assigned. You can safely ignore this. It is used by the PRC request router."
|
99 |
+
)
|
100 |
+
platform: Literal["twitter", "reddit", "facebook"] = Field(
|
101 |
+
description="The platform on which the user is viewing content."
|
102 |
+
)
|
103 |
+
current_time: datetime = Field(
|
104 |
+
description="The current time according to the user's browser, in UTC, in `YYYY-MM-DD hh:mm:ss` format."
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
class RankingRequest(BaseModel):
|
109 |
+
"""A complete ranking request"""
|
110 |
+
|
111 |
+
session: Session = Field(
|
112 |
+
description="Data that is scoped to the user's browsing session"
|
113 |
+
)
|
114 |
+
survey: Optional[SurveyResponse] = Field(
|
115 |
+
description="Responses to PRC survey. Added by the request router.",
|
116 |
+
default=None,
|
117 |
+
)
|
118 |
+
items: list[ContentItem] = Field(description="The content items to be ranked.")
|
app/_models/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faker
|
2 |
+
pydantic
|
3 |
+
pytest
|
4 |
+
|
app/_models/response.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
from uuid import uuid4
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field, HttpUrl
|
5 |
+
|
6 |
+
|
7 |
+
class NewItem(BaseModel):
|
8 |
+
"""A new item to be inserted into the feed"""
|
9 |
+
|
10 |
+
id: str = Field(
|
11 |
+
description="A unique ID for the content item. You can generate this.",
|
12 |
+
default_factory=uuid4,
|
13 |
+
)
|
14 |
+
url: HttpUrl = Field(description="The publicly-accessible URL of the content item.")
|
15 |
+
|
16 |
+
|
17 |
+
class RankingResponse(BaseModel):
|
18 |
+
"""A response to a ranking request"""
|
19 |
+
|
20 |
+
ranked_ids: list[str] = Field(
|
21 |
+
description="The IDs of the content items, in the order they should be displayed."
|
22 |
+
)
|
23 |
+
new_items: Optional[list[NewItem]] = Field(
|
24 |
+
description="New publicly-accessible items to be inserted into the feed.",
|
25 |
+
default=None,
|
26 |
+
)
|
app/_models/survey.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import IntEnum
|
2 |
+
from typing import Annotated, Literal, Optional
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
|
7 |
+
class IdeologyEnum(IntEnum):
|
8 |
+
"""Enum for ideology"""
|
9 |
+
|
10 |
+
extremely_liberal = 1
|
11 |
+
liberal = 2
|
12 |
+
slightly_liberal = 3
|
13 |
+
moderate = 4
|
14 |
+
slightly_conservative = 5
|
15 |
+
conservative = 6
|
16 |
+
extremely_conservative = 7
|
17 |
+
|
18 |
+
|
19 |
+
class AgeEnum(IntEnum):
|
20 |
+
"""Enum for age brackets"""
|
21 |
+
|
22 |
+
eighteen_to_twenty_four = 1
|
23 |
+
twenty_five_to_thirty_four = 2
|
24 |
+
thirty_five_to_forty_four = 3
|
25 |
+
forty_five_to_fifty_four = 4
|
26 |
+
fifty_five_to_sixty_four = 5
|
27 |
+
sixty_five_and_older = 6
|
28 |
+
|
29 |
+
|
30 |
+
class EducationEnum(IntEnum):
|
31 |
+
"""Enum for education levels"""
|
32 |
+
|
33 |
+
some_high_school = 1
|
34 |
+
graduated_high_school = 2
|
35 |
+
some_college = 3
|
36 |
+
associate_degree = 4
|
37 |
+
bachelor_degree = 5
|
38 |
+
graduate_degree = 6
|
39 |
+
|
40 |
+
|
41 |
+
class IncomeEnum(IntEnum):
|
42 |
+
"""Enum for income brackets"""
|
43 |
+
|
44 |
+
less_than_20k = 1
|
45 |
+
twenty_to_thirty_five_k = 2
|
46 |
+
thirty_five_to_fifty_k = 3
|
47 |
+
fifty_to_seventy_five_k = 4
|
48 |
+
seventy_five_to_one_hundred_k = 5
|
49 |
+
one_hundred_to_one_hundred_fifty_k = 6
|
50 |
+
one_hundred_fifty_k_and_up = 7
|
51 |
+
|
52 |
+
|
53 |
+
class SocmedUseEnum(IntEnum):
|
54 |
+
"""Enum for social media usage"""
|
55 |
+
|
56 |
+
zero_to_thirty_minutes = 1
|
57 |
+
thirty_to_sixty_minutes = 2
|
58 |
+
sixty_to_ninety_minutes = 3
|
59 |
+
ninety_to_one_twenty_minutes = 4
|
60 |
+
two_to_three_hours = 5
|
61 |
+
three_to_four_hours = 6
|
62 |
+
more_than_four_hours = 7
|
63 |
+
|
64 |
+
|
65 |
+
class SurveyResponse(BaseModel):
|
66 |
+
"""
|
67 |
+
Response to PRC survey.
|
68 |
+
|
69 |
+
Scalar quantities are represented as IntEnums, while categorical questions are represented as strings constrained with Literal.
|
70 |
+
|
71 |
+
These will be added to the request by the PRC request router, since this data is not available to the browser extension.
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Demographic metrics
|
75 |
+
party_id: Literal[
|
76 |
+
"democrat", "republican", "independent", "other", "no_preference"
|
77 |
+
] = Field(
|
78 |
+
description="Generally speaking, do you usually think of yourself as a Republican, Democrat, Independent, etc?"
|
79 |
+
)
|
80 |
+
|
81 |
+
party_write_in: Optional[str] = Field(
|
82 |
+
description="If you selected 'other' for your party identification, please specify.",
|
83 |
+
default=None,
|
84 |
+
)
|
85 |
+
|
86 |
+
support: Literal["strong", "not_strong"] = Field(
|
87 |
+
description="Would you call yourself a strong or not a very strong supporter of your party?"
|
88 |
+
)
|
89 |
+
|
90 |
+
party_lean: Literal["democrat", "republican"] = Field(
|
91 |
+
description="Do you think of yourself as closer to the Republican or Democratic party?"
|
92 |
+
)
|
93 |
+
|
94 |
+
sex: Literal["female", "male", "nonbinary", "prefer_not_to_say"]
|
95 |
+
|
96 |
+
age: AgeEnum = Field(description="What age are you?")
|
97 |
+
|
98 |
+
education: EducationEnum = Field(
|
99 |
+
description="What is the highest level of education you have completed?"
|
100 |
+
)
|
101 |
+
|
102 |
+
ideology: IdeologyEnum = Field(
|
103 |
+
description="Here is a scale on which the political views that people might hold are arranged from liberal to conservative. Where would you place yourself on this scale?"
|
104 |
+
)
|
105 |
+
|
106 |
+
income: IncomeEnum = Field(description="What is your annual household income?")
|
107 |
+
|
108 |
+
ethnicity: Literal[
|
109 |
+
"native_american",
|
110 |
+
"asian_or_pacific_islander",
|
111 |
+
"black_or_african_american",
|
112 |
+
"hispanic_or_latino",
|
113 |
+
"white_or_caucasian",
|
114 |
+
"multiple_or_other",
|
115 |
+
] = Field(description="Which race or ethnicity best describes you?")
|
116 |
+
|
117 |
+
ethnicity_write_in: Optional[str] = Field(
|
118 |
+
description="If you selected 'multiple' or 'other' for your ethnicity, please specify.",
|
119 |
+
default=None,
|
120 |
+
)
|
121 |
+
|
122 |
+
socmed_use: SocmedUseEnum = Field(
|
123 |
+
description="Think of the past two weeks. How much time did you spend on social media, on average, per day?"
|
124 |
+
)
|
125 |
+
|
126 |
+
browser_perc: Annotated[
|
127 |
+
float,
|
128 |
+
Field(
|
129 |
+
ge=0,
|
130 |
+
le=1,
|
131 |
+
description="In the last two weeks, what percentage of your social media [twitter/facebook/reddit] has been on a desktop device or laptop?",
|
132 |
+
),
|
133 |
+
]
|
134 |
+
|
135 |
+
mobile_perc: Annotated[
|
136 |
+
float,
|
137 |
+
Field(
|
138 |
+
ge=0,
|
139 |
+
le=1,
|
140 |
+
description="In the last two weeks, what percentage of your social media [twitter/facebook/reddit] has been on mobile device?",
|
141 |
+
),
|
142 |
+
]
|
app/app.py
CHANGED
@@ -3,11 +3,14 @@ from fastapi import FastAPI, HTTPException
|
|
3 |
#import redis
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
|
|
6 |
|
7 |
from modules.redistribute import redistribute, insert_element_at_position
|
8 |
-
from modules.models.api import Input, Output, NewItem, UUID
|
9 |
from modules.database import BoostDatabase, UserDatabase, User
|
10 |
-
|
|
|
|
|
11 |
|
12 |
# Load environment variables from .env file
|
13 |
load_dotenv('../.env')
|
@@ -17,7 +20,10 @@ redis_port = os.getenv("REDIS_PORT")
|
|
17 |
fastapi_port = os.getenv("FASTAPI_PORT")
|
18 |
|
19 |
|
20 |
-
print("
|
|
|
|
|
|
|
21 |
print("FastAPI port:", fastapi_port)
|
22 |
|
23 |
app = FastAPI()
|
@@ -31,12 +37,12 @@ async def health_check():
|
|
31 |
return {"status": "ok"}
|
32 |
|
33 |
# Define FastAPI routes and logic
|
34 |
-
@app.post("/
|
35 |
-
async def rerank_items(input_data:
|
36 |
# who is the user?
|
37 |
user = input_data.session.user_id
|
38 |
date = input_data.session.current_time
|
39 |
-
platform = input_data.session.platform
|
40 |
items = input_data.items
|
41 |
# TODO consider sampling them?
|
42 |
|
@@ -54,7 +60,7 @@ async def rerank_items(input_data: Input) -> Output:
|
|
54 |
print(user_in_db)
|
55 |
if user_in_db.is_boosted_today():
|
56 |
# return only reranked items, no insertion
|
57 |
-
return
|
58 |
# user exists and not boosted today yet
|
59 |
else:
|
60 |
new_items = []
|
@@ -72,11 +78,11 @@ async def rerank_items(input_data: Input) -> Output:
|
|
72 |
element=UUID(fetched_boost['id']),
|
73 |
position=insertion_pos)
|
74 |
|
75 |
-
return
|
76 |
|
77 |
# no civic content to boost on
|
78 |
else:
|
79 |
-
return
|
80 |
|
81 |
# user doesn't exist
|
82 |
else:
|
@@ -90,14 +96,14 @@ async def rerank_items(input_data: Input) -> Output:
|
|
90 |
|
91 |
# insert boost before first civic in batch
|
92 |
reranked_ids = insert_element_at_position(lst=reranked_ids,
|
93 |
-
element=
|
94 |
position=insertion_pos)
|
95 |
|
96 |
|
97 |
-
return
|
98 |
|
99 |
# no civic content to boost on
|
100 |
else:
|
101 |
print("there")
|
102 |
-
return
|
103 |
|
|
|
3 |
#import redis
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
6 |
+
import torch
|
7 |
|
8 |
from modules.redistribute import redistribute, insert_element_at_position
|
9 |
+
#from modules.models.api import Input, Output, NewItem, UUID
|
10 |
from modules.database import BoostDatabase, UserDatabase, User
|
11 |
+
from _models.request import RankingRequest
|
12 |
+
from _models.response import RankingResponse, NewItem
|
13 |
+
from modules.models.api import UUID
|
14 |
|
15 |
# Load environment variables from .env file
|
16 |
load_dotenv('../.env')
|
|
|
20 |
fastapi_port = os.getenv("FASTAPI_PORT")
|
21 |
|
22 |
|
23 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
24 |
+
#print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
25 |
+
|
26 |
+
#print("Redis port:", redis_port)
|
27 |
print("FastAPI port:", fastapi_port)
|
28 |
|
29 |
app = FastAPI()
|
|
|
37 |
return {"status": "ok"}
|
38 |
|
39 |
# Define FastAPI routes and logic
|
40 |
+
@app.post("/rank")
|
41 |
+
async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
42 |
# who is the user?
|
43 |
user = input_data.session.user_id
|
44 |
date = input_data.session.current_time
|
45 |
+
platform = input_data.session.platform
|
46 |
items = input_data.items
|
47 |
# TODO consider sampling them?
|
48 |
|
|
|
60 |
print(user_in_db)
|
61 |
if user_in_db.is_boosted_today():
|
62 |
# return only reranked items, no insertion
|
63 |
+
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
64 |
# user exists and not boosted today yet
|
65 |
else:
|
66 |
new_items = []
|
|
|
78 |
element=UUID(fetched_boost['id']),
|
79 |
position=insertion_pos)
|
80 |
|
81 |
+
return RankingResponse(ranked_ids=reranked_ids, new_items=[NewItem(id=fetched_boost["id"], url=fetched_boost["url"])])
|
82 |
|
83 |
# no civic content to boost on
|
84 |
else:
|
85 |
+
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
86 |
|
87 |
# user doesn't exist
|
88 |
else:
|
|
|
96 |
|
97 |
# insert boost before first civic in batch
|
98 |
reranked_ids = insert_element_at_position(lst=reranked_ids,
|
99 |
+
element=fetched_boost['id'],
|
100 |
position=insertion_pos)
|
101 |
|
102 |
|
103 |
+
return RankingResponse(ranked_ids=reranked_ids, new_items=[NewItem(id=fetched_boost["id"], url=fetched_boost["url"])])
|
104 |
|
105 |
# no civic content to boost on
|
106 |
else:
|
107 |
print("there")
|
108 |
+
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
109 |
|
app/modules/classify.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
-
from transformers import
|
2 |
from typing import List
|
3 |
|
4 |
-
model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
|
|
|
|
5 |
|
6 |
|
7 |
label_map = {
|
|
|
1 |
+
from transformers import pipeline
|
2 |
from typing import List
|
3 |
|
4 |
+
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
5 |
+
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
|
6 |
+
|
7 |
|
8 |
|
9 |
label_map = {
|