Spaces:
Running
Running
File size: 3,147 Bytes
05140a3 7eaefa4 69889de 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 cf17b13 7eaefa4 7b9061c 7eaefa4 7b9061c 7fe3a25 05140a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
import json
from transformers import pipeline
# Hugging Face Hub id of the fine-tuned tour-generation model.
_TOUR_MODEL_ID = "manhan/GPT-Tour"

# Build the inference pipeline once, at import time, and share it through the
# module-level name `tourModel`. No task is passed, so transformers infers it
# from the model (presumably text-generation — confirm against the model card).
tourModel = pipeline(model=_TOUR_MODEL_ID)
# Income inputs below this value are treated as the category codes the
# original thresholds were written against (< 4 ~ under $25,000, < 6 ~ under
# $50,000, < 7 ~ under $75,000, < 9 ~ under $125,000). Larger values are read
# as annual dollars, which is what the UI asks the user to enter.
_INCOME_CODE_LIMIT = 100

# (exclusive upper bound, label) pairs, checked in order; anything at or above
# the last bound is 'affluent'.
_INCOME_CODE_BRACKETS = ((4, 'poor'), (6, 'low'), (7, 'medium'), (9, 'high'))
_INCOME_DOLLAR_BRACKETS = (
    (25_000, 'poor'),
    (50_000, 'low'),
    (75_000, 'medium'),
    (125_000, 'high'),
)

# How many generations getTour tries before giving up on a valid tour.
_DEFAULT_MAX_ATTEMPTS = 10


def _parse_number(text):
    """Parse user-entered text such as '75000', '$75,000' or '3.0' into an int.

    Raises:
        ValueError: if the text, after removing '$', ',' and surrounding
            whitespace, is not a finite number.
    """
    cleaned = str(text).strip().replace('$', '').replace(',', '')
    try:
        return int(float(cleaned))
    except (ValueError, OverflowError) as err:
        raise ValueError(f"Expected a number, got {text!r}") from err


def _income_group(value):
    """Return the household-income label for *value*, or None for negative (bad) data."""
    if value < 0:
        return None
    if value < _INCOME_CODE_LIMIT:
        brackets = _INCOME_CODE_BRACKETS
    else:
        brackets = _INCOME_DOLLAR_BRACKETS
    for upper, label in brackets:
        if value < upper:
            return label
    return 'affluent'


def _size_group(value):
    """Return the household-size label, or None when fewer than one person is given."""
    if value < 1:
        return None
    if value == 1:
        return 'single'
    if value == 2:
        return 'couple'
    if value <= 4:
        return 'small'
    return 'large'  # more than four people


def _age_group(value):
    """Return the age-group label, or None for a negative age."""
    if value < 0:
        return None
    if value < 18:
        return 'child'
    if value < 45:
        return 'younger'
    if value < 65:
        return 'older'
    return 'senior'


def _extract_activity_list(generated, wrk):
    """Pull the first bracketed JSON list out of *generated* and validate it.

    Returns the list when it parses, is non-empty, ends with 'Home', and is
    consistent with the worker status *wrk* ('yes' requires 'Work', 'no'
    forbids it). Returns None otherwise.
    """
    start = generated.find('[')
    end = generated.find(']', start) if start != -1 else -1
    if end == -1:
        print("Nothing extracted!")
        return None
    snippet = generated[start:end + 1]
    print(f"Extracted: '{snippet}'")

    # Reject tours that contradict the stated worker status.
    mentions_work = 'Work' in snippet
    if (wrk == 'yes' and not mentions_work) or (wrk == 'no' and mentions_work):
        return None

    try:
        activities = json.loads(snippet)
    except json.JSONDecodeError as err:
        print("Error parsing activity list")
        print(err)
        return None

    # A usable tour is a non-empty list that ends back at home.
    if not isinstance(activities, list) or not activities or activities[-1] != 'Home':
        return None
    return activities


def getTour(income, size, years, sex, edu, wrk, *, max_attempts=_DEFAULT_MAX_ATTEMPTS, model=None):
    """Generate an activity tour for one traveler with the GPT-Tour model.

    The numeric inputs are bucketed into categorical labels and serialized,
    together with the pass-through attributes, into a JSON-like prompt. The
    model is then sampled until it returns a usable activity list.

    Args:
        income: Annual household income in dollars (e.g. "75000" or
            "$75,000"). Values below 100 are treated as legacy category codes.
            Negative values are left out of the prompt.
        size: Number of people in the household; values below 1 are left out.
        years: Traveler age in years; negative values are left out.
        sex, edu, wrk: Copied into the prompt unchanged. When ``wrk`` is 'yes'
            the tour must contain 'Work'; when it is 'no' it must not.
        max_attempts: Maximum number of generations to try.
        model: Optional text-generation callable used instead of the
            module-level ``tourModel``. It must have the same call signature
            and return shape.

    Returns:
        The activity list (a JSON list ending in 'Home'), or [] if no attempt
        produced one.

    Raises:
        ValueError: if ``income``, ``size`` or ``years`` is not a number.
    """
    person = {}
    # Key order matters: it fixes the order of fields in the prompt.
    for key, raw, bucket in (
        ('hh_inc', income, _income_group),
        ('hh_size', size, _size_group),
        ('age_grp', years, _age_group),
    ):
        label = bucket(_parse_number(raw))
        if label is not None:  # None means bad data: omit the attribute
            person[key] = label
    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk

    # Drop the closing brace and append a "pattern" field for the model to
    # complete (presumably the format the model was fine-tuned on).
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)

    generate = tourModel if model is None else model
    for _ in range(max_attempts):
        # temperature only takes effect when sampling; with greedy decoding a
        # rejected output would simply repeat on every retry.
        generated = generate(prompt, return_full_text=False, max_length=250,
                             do_sample=True, temperature=0.9)[0]['generated_text']
        activities = _extract_activity_list(generated, wrk)
        if activities is not None:
            return activities
    print(f"No valid tour after {max_attempts} attempts")
    return []
# Web front end: the six inputs are passed positionally to getTour in the
# order listed, and its return value is rendered as JSON.
# NOTE(review): gr.inputs.* and allow_screenshot/allow_embedding are older
# Gradio APIs; migrate to gr.Textbox/gr.Dropdown once the pinned Gradio
# version is confirmed.
with gr.Interface(fn=getTour, inputs=[
    gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
    gr.inputs.Textbox(label="Household Size (number of people)"),
    gr.inputs.Textbox(label="Traveler Age (years)"),
    gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
    gr.inputs.Dropdown(["unknown", "grade school", "highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
    gr.inputs.Dropdown(["unknown", "yes", "no"], label="Worker status"),
], outputs="json", title="GPT-Tour", description="Generate a tour for a person",
        allow_flagging=False, allow_screenshot=False, allow_embedding=False) as iface:
    # With default arguments, launch() blocks the script's main thread while
    # the server runs. Only this one app is launched: a second launch placed
    # after it would run only once this server had shut down, and would then
    # start a different app.
    iface.launch()