Spaces:
Sleeping
Sleeping
File size: 3,378 Bytes
05140a3 7eaefa4 5f079a9 7eaefa4 c0bb523 7eaefa4 cf17b13 7eaefa4 cf17b13 c31c636 7eaefa4 c31c636 7eaefa4 c31c636 7eaefa4 c31c636 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 8f687fb 7eaefa4 c31c636 7eaefa4 7b9061c ee7bc16 24cc048 ee7bc16 41d1888 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import gradio as gr
import json
from transformers import pipeline
# Load the fine-tuned "manhan/GPT-Tour" model once, at import time, as a
# transformers pipeline; getTour() invokes it for each request and reads the
# "generated_text" field of its first result.
# NOTE(review): device=0 appears to pin the model to the first GPU; a
# CPU-only host would presumably need device=-1 — confirm against the
# Space's hardware.
tourModel = pipeline(model="manhan/GPT-Tour", device=0)
def _parse_int(value):
    """Convert a numeric form field to ``int``.

    Strings may carry surrounding whitespace, a leading ``$`` and ``,``
    thousands separators (e.g. ``"$65,000"``), because the Gradio inputs are
    free-text boxes. Anything ``int()`` still rejects raises ``ValueError``.
    Non-string values are passed to ``int()`` unchanged.
    """
    if isinstance(value, str):
        value = value.strip().lstrip("$").replace(",", "")
    return int(value)


def _income_group(hh_income):
    """Map annual household income in dollars to the model's ``hh_inc`` label."""
    if hh_income < 25000:
        return "poor"
    if hh_income < 50000:
        return "low"
    if hh_income < 75000:
        return "medium"
    if hh_income < 125000:
        return "high"
    return "affluent"


def _size_group(hh_size):
    """Map household size (number of people) to the model's ``hh_size`` label.

    Raises:
        ValueError: if *hh_size* is below 1 — a household needs at least one
            member, and such values would otherwise land in the "small" bucket.
    """
    if hh_size < 1:
        raise ValueError(f"Household size must be at least 1, got {hh_size}")
    if hh_size == 1:
        return "single"
    if hh_size == 2:
        return "couple"
    if hh_size <= 4:
        return "small"
    return "large"


def _age_group(age):
    """Map traveler age in years to the model's ``age_grp`` label."""
    if age < 18:
        return "child"
    if age < 45:
        return "younger"
    if age < 65:
        return "older"
    return "senior"


def _extract_activity_list(generated):
    """Pull the first ``[...]`` JSON list out of raw model output.

    Returns the list only when it parses, is non-empty and ends with
    ``"Home"``; otherwise returns ``None`` so the caller can draw another
    sample. Progress is printed to stdout as debugging output.
    """
    start = generated.find("[")
    # Look for the closing bracket *after* the opening one, so a stray "]"
    # earlier in the text does not produce an empty slice.
    end = generated.find("]", start) if start != -1 else -1
    candidate = generated[start:end + 1] if end != -1 else ""
    print(f"Extracted: '{candidate}'")
    if not candidate:
        print("Nothing extracted!")
        return None
    try:
        activities = json.loads(candidate)
    except json.JSONDecodeError as err:
        print("Error parsing activity list")
        print(err)
        return None
    # Accept only sequences that end at "Home" (presumably a complete
    # home-based tour); anything else is treated as a bad sample.
    if not activities or activities[-1] != "Home":
        return None
    return activities


def getTour(income, size, years, sex, edu, wrk, *, max_attempts=10, generator=None):
    """Generate an activity sequence for one traveler.

    Args:
        income: annual household income in dollars (number, or string such as
            ``"65000"`` / ``"$65,000"``).
        size: household size in people (number or numeric string).
        years: traveler age in years (number or numeric string).
        sex: sex label, placed into the prompt unchanged.
        edu: education label, placed into the prompt unchanged.
        wrk: worker-status label, placed into the prompt unchanged.
        max_attempts: maximum number of samples to draw before giving up.
        generator: text-generation callable with the transformers pipeline
            calling convention; defaults to the module-level ``tourModel``.

    Returns:
        list: the generated activities; the last element is always ``"Home"``.

    Raises:
        ValueError: if income, size or age is not an integer, or size < 1.
        RuntimeError: if no valid sequence is produced in *max_attempts* tries.
    """
    if generator is None:
        generator = tourModel

    # Key order is significant: it fixes the serialized prompt below.
    person = {
        "hh_inc": _income_group(_parse_int(income)),
        "hh_size": _size_group(_parse_int(size)),
        "age_grp": _age_group(_parse_int(years)),
        "sex": sex,
        "edu": edu,
        "wrk": wrk,
    }
    # The prompt is the person JSON with its closing brace swapped for a
    # ``pattern:`` key, leaving the model to complete the activity list.
    # NOTE(review): this presumably mirrors the fine-tuning format — keep it
    # byte-exact.
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)

    # NOTE: samples are not cross-checked against ``wrk`` (e.g. a worker with
    # no "Work" activity); add such a filter to _extract_activity_list if needed.
    for _ in range(max_attempts):
        # temperature only takes effect when sampling is enabled; with greedy
        # decoding every retry would regenerate the same rejected text.
        generated = generator(
            prompt,
            return_full_text=False,
            max_length=250,
            temperature=0.9,
            do_sample=True,
        )[0]["generated_text"]
        activities = _extract_activity_list(generated)
        if activities is not None:
            return activities
    raise RuntimeError(
        f"No valid activity sequence (ending in 'Home') after {max_attempts} attempts"
    )
def _tour_from_form(_heading, _intro, income, size, years, sex, edu, wrk):
    """Gradio callback: discard the two Markdown values and delegate to getTour().

    gr.Interface calls ``fn`` with one positional argument per component in
    ``inputs``. The two gr.Markdown components below are static page text,
    but they still occupy the first two argument slots, so the six-parameter
    getTour() cannot be used as ``fn`` directly.
    """
    return getTour(income, size, years, sex, edu, wrk)


# Build the demo UI. The first two "inputs" are static Markdown text shown
# above the form; the remaining six feed getTour() in order.
iface = gr.Interface(
    fn=_tour_from_form,
    inputs=[
        gr.Markdown("# Activity Sequence Modeling Using Transformers"),
        gr.Markdown("This demo uses a GPT-2 model fine-tuned on activity sequences from the 2017 National Household Travel Survey (NHTS) to generate activity sequences for a given person. The model is trained to predict the next activity in a sequence given the previous activities, and is conditioned on person-level and household-level attributes."),
        gr.Textbox(label="Annual Household Income (in dollars)"),
        gr.Textbox(label="Household Size (number of people)"),
        gr.Textbox(label="Traveler Age (years)"),
        gr.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
        gr.Dropdown(["unknown", "grade school","highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
        gr.Dropdown(["unknown", "yes","no"], label="Worker status"),
    ],
    outputs=["json"],
    title="GPT-Travel",
    # NOTE(review): "[email protected]" looks like an email-obfuscation
    # placeholder left by a web scrape rather than a real address — restore
    # the intended contact.
    description="Author: Colby Brown, Manhan ([email protected])",
    allow_flagging='never',
)
# gr.Interface assembles its layout in the constructor, so it can be launched
# directly without re-entering it as a ``with`` context.
iface.launch()