Skip to content

Commit 4221cf7

Browse files
fix(nodes): fix schema generation for output classes
All output classes need to have their properties flagged as `required` for the schema generation to work as needed.
1 parent c34ac91 commit 4221cf7

3 files changed

Lines changed: 46 additions & 1 deletion

File tree

invokeai/app/invocations/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
2828
image: ImageField = Field(default=None, description="The output image")
2929
#fmt: on
3030

31+
class Config:
32+
schema_extra = {
33+
'required': [
34+
'type',
35+
'image',
36+
]
37+
}
38+
3139
class MaskOutput(BaseInvocationOutput):
3240
"""Base class for invocations that output a mask"""
3341
#fmt: off
3442
type: Literal["mask"] = "mask"
3543
mask: ImageField = Field(default=None, description="The output mask")
36-
#fomt: on
44+
#fmt: on
45+
46+
class Config:
47+
schema_extra = {
48+
'required': [
49+
'type',
50+
'mask',
51+
]
52+
}
3753

3854
# TODO: this isn't really necessary anymore
3955
class LoadImageInvocation(BaseInvocation):

invokeai/app/invocations/prompt.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
1212

1313
prompt: str = Field(default=None, description="The output prompt")
1414
#fmt: on
15+
16+
class Config:
17+
schema_extra = {
18+
'required': [
19+
'type',
20+
'prompt',
21+
]
22+
}

invokeai/app/services/graph.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
127127
class GraphInvocationOutput(BaseInvocationOutput):
128128
type: Literal["graph_output"] = "graph_output"
129129

130+
class Config:
131+
schema_extra = {
132+
'required': [
133+
'type',
134+
'image',
135+
]
136+
}
130137

131138
# TODO: Fill this out and move to invocations
132139
class GraphInvocation(BaseInvocation):
@@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
147154

148155
item: Any = Field(description="The item being iterated over")
149156

157+
class Config:
158+
schema_extra = {
159+
'required': [
160+
'type',
161+
'item',
162+
]
163+
}
150164

151165
# TODO: Fill this out and move to invocations
152166
class IterateInvocation(BaseInvocation):
@@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
169183

170184
collection: list[Any] = Field(description="The collection of input items")
171185

186+
class Config:
187+
schema_extra = {
188+
'required': [
189+
'type',
190+
'collection',
191+
]
192+
}
172193

173194
class CollectInvocation(BaseInvocation):
174195
"""Collects values into a collection"""

0 commit comments

Comments
 (0)