Skip to content

Commit 2e1cd04

Browse files
committed
Allow import of models without project.json
1 parent 8c15533 commit 2e1cd04

7 files changed

+207
-37
lines changed

lib/importers/geti_deployment.dart

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import 'package:archive/archive_io.dart';
1010
import 'package:inference/interop/utils.dart';
1111
import 'package:inference/importers/importer.dart';
1212
import 'package:inference/project.dart';
13+
import 'package:inference/utils.dart';
1314
import 'package:path/path.dart';
1415
import 'package:path_provider/path_provider.dart';
1516
import 'package:uuid/uuid.dart';
@@ -63,7 +64,7 @@ class GetiDeploymentProcessor extends Importer {
6364
await processTask(task);
6465
}
6566
const encoder = JsonEncoder.withIndent(" ");
66-
project!.size = project!.calculateDiskUsage();
67+
project!.size = calculateDiskUsage(project!.storagePath);
6768
File(platformContext.join(project!.storagePath, "project.json"))
6869
.writeAsString(encoder.convert(project!.toMap()));
6970
project!.loaded.complete();

lib/importers/model_directory_importer.dart

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class ModelDirImporter extends Importer {
3333
return project!;
3434
}
3535

36+
bool get containsProjectJson => findProjectJsonFile() != null;
37+
3638
String? findProjectJsonFile() {
3739
final files = Directory(directory).listSync();
3840
return files.firstWhereOrNull((file) => basename(file.path) == "project.json")?.path;

lib/pages/import/widgets/directory_import.dart

+165-19
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,15 @@ import 'package:file_picker/file_picker.dart';
99
import 'package:fluent_ui/fluent_ui.dart';
1010
import 'package:flutter_svg/svg.dart';
1111
import 'package:inference/importers/model_directory_importer.dart';
12+
import 'package:inference/importers/model_manifest.dart';
13+
import 'package:inference/project.dart';
1214
import 'package:inference/providers/project_provider.dart';
15+
import 'package:inference/public_models.dart';
1316
import 'package:inference/theme_fluent.dart';
17+
import 'package:inference/utils.dart';
1418
import 'package:path/path.dart' show dirname;
1519
import 'package:provider/provider.dart';
20+
import 'package:uuid/uuid.dart';
1621

1722
class DirectoryImport extends StatefulWidget {
1823
const DirectoryImport({super.key});
@@ -23,6 +28,7 @@ class DirectoryImport extends StatefulWidget {
2328

2429
class _DirectoryImportState extends State<DirectoryImport> {
2530
bool _showReleaseMessage = false;
31+
String? selectedFolder;
2632

2733
void showReleaseMessage() {
2834
setState(() => _showReleaseMessage = true);
@@ -43,7 +49,25 @@ class _DirectoryImportState extends State<DirectoryImport> {
4349

4450
Future<void> processFolder(String directory) async {
4551
final importer = ModelDirImporter(directory);
46-
final project = await importer.generateProject();
52+
if (importer.containsProjectJson) {
53+
final project = await importer.generateProject();
54+
importer.setupFiles();
55+
if (mounted) {
56+
final projectsProvider = Provider.of<ProjectProvider>(context, listen: false);
57+
projectsProvider.addProject(project);
58+
Navigator.pop(context, [project]);
59+
}
60+
} else {
61+
setState(() {
62+
selectedFolder = directory;
63+
});
64+
}
65+
}
66+
67+
Future<void> importProject(Project project) async {
68+
final importer = ModelDirImporter(project.storagePath);
69+
importer.project = project;
70+
writeProjectJson(project);
4771
importer.setupFiles();
4872
if (mounted) {
4973
final projectsProvider = Provider.of<ProjectProvider>(context, listen: false);
@@ -78,31 +102,153 @@ class _DirectoryImportState extends State<DirectoryImport> {
78102
},
79103
onDragExited: (val) => hideReleaseMessage(),
80104
onDragEntered: (val) => showReleaseMessage(),
105+
child: Builder(
106+
builder: (context) {
107+
if (selectedFolder != null && !_showReleaseMessage) {
108+
return ModelImportPropertiesForm(
109+
storagePath: selectedFolder!,
110+
onProjectImport: importProject,
111+
onCancel: () {
112+
setState(() => selectedFolder = null);
113+
},
114+
);
115+
}
116+
return Column(
117+
mainAxisAlignment: MainAxisAlignment.center,
118+
crossAxisAlignment: CrossAxisAlignment.center,
119+
children: [
120+
Text(text,
121+
style: const TextStyle(
122+
fontSize: 28,
123+
)
124+
),
125+
Padding(
126+
padding: const EdgeInsets.all(24),
127+
child: SvgPicture.asset('images/drop_geti.svg'),
128+
),
129+
Padding(
130+
padding: const EdgeInsets.only(top: 24),
131+
child: FilledButton(
132+
onPressed: () => selectFolder(),
133+
child: const Text("Select folder"),
134+
),
135+
)
136+
],
137+
);
138+
}
139+
),
140+
);
141+
}
142+
)
143+
);
144+
}
145+
}
146+
147+
class ModelImportPropertiesForm extends StatefulWidget {
148+
final String storagePath;
149+
final Function? onCancel;
150+
final Function(Project)? onProjectImport;
151+
const ModelImportPropertiesForm({
152+
super.key,
153+
required this.storagePath,
154+
this.onProjectImport,
155+
this.onCancel,
156+
});
157+
158+
@override
159+
State<ModelImportPropertiesForm> createState() => _ModelImportPropertiesFormState();
160+
}
161+
162+
class _ModelImportPropertiesFormState extends State<ModelImportPropertiesForm> {
163+
ProjectType projectType = ProjectType.text;
164+
final TextEditingController _controller = TextEditingController();
165+
166+
void generateModelManifest() async {
167+
final name = _controller.text;
168+
final manifest = ModelManifest(
169+
name: name,
170+
id: const Uuid().v4().toString(),
171+
fileSize: calculateDiskUsage(widget.storagePath),
172+
task: projectTypeToString(projectType),
173+
author: "Unknown",
174+
collection: "",
175+
description: "Try out $name",
176+
npuEnabled: false,
177+
contextWindow: 0,
178+
optimizationPrecision: "",
179+
);
180+
181+
widget.onProjectImport?.call(await PublicProject.fromModelManifest(manifest, widget.storagePath));
182+
}
183+
184+
bool get validForm => _controller.text.isNotEmpty;
185+
186+
@override
187+
void dispose() {
188+
_controller.dispose();
189+
super.dispose();
190+
}
191+
192+
@override
193+
Widget build(BuildContext context) {
194+
return Padding(
195+
padding: const EdgeInsets.all(20),
196+
child: Column(
197+
crossAxisAlignment: CrossAxisAlignment.start,
198+
children: [
199+
Expanded(
81200
child: Column(
82-
mainAxisAlignment: MainAxisAlignment.center,
83-
crossAxisAlignment: CrossAxisAlignment.center,
201+
crossAxisAlignment: CrossAxisAlignment.start,
84202
children: [
85-
Text(text,
86-
style: const TextStyle(
87-
fontSize: 28,
203+
const SizedBox(height: 10),
204+
InfoLabel(
205+
label: "Model name",
206+
child: TextBox(
207+
placeholder: "Name",
208+
controller: _controller,
209+
onChanged: (_) => setState(() {}),
88210
)
89211
),
90-
Padding(
91-
padding: const EdgeInsets.all(24),
92-
child: SvgPicture.asset('images/drop_geti.svg'),
212+
const SizedBox(height: 10),
213+
InfoLabel(
214+
label: "Task",
215+
child: ComboBox<ProjectType>(
216+
value: projectType,
217+
items: List<ComboBoxItem<ProjectType>>.from([
218+
ProjectType.text,
219+
ProjectType.vlm,
220+
ProjectType.textToImage,
221+
ProjectType.speech,
222+
].map((t) {
223+
return ComboBoxItem<ProjectType>(
224+
value: t,
225+
child: Text(projectTypeToName(t)),
226+
);
227+
})),
228+
placeholder: const Text("Select model task"),
229+
onChanged: (t) => setState(() => projectType = t!),
230+
)
93231
),
94-
Padding(
95-
padding: const EdgeInsets.only(top: 24),
96-
child: FilledButton(
97-
onPressed: () => selectFolder(),
98-
child: const Text("Select folder"),
99-
),
100-
)
101232
],
102233
),
103-
);
104-
}
105-
)
234+
),
235+
Row(
236+
mainAxisAlignment: MainAxisAlignment.end,
237+
children: [
238+
Button(
239+
onPressed: () => widget.onCancel?.call(),
240+
child: const Text("Cancel"),
241+
),
242+
const SizedBox(width: 10),
243+
FilledButton(
244+
onPressed: validForm ? () => generateModelManifest() : null,
245+
child: const Text("Import"),
246+
),
247+
],
248+
)
249+
250+
],
251+
),
106252
);
107253
}
108254
}

lib/pages/import/widgets/import_model_dialog.dart

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class _ImportModelDialogState extends State<ImportModelDialog> {
4848
child: Row(
4949
mainAxisAlignment: MainAxisAlignment.spaceBetween,
5050
children: [
51-
const Text('Import model from your local disk'),
51+
const Text('Import model from your local disk', style: TextStyle(fontSize: 20)),
5252
IconButton(
5353
icon: const Icon(FluentIcons.clear),
5454
onPressed: () { Navigator.pop(context, List<Project>.from([])); },
@@ -66,12 +66,12 @@ class _ImportModelDialogState extends State<ImportModelDialog> {
6666
displayMode: PaneDisplayMode.top,
6767
items: [
6868
PaneItem(
69-
icon: Icon(FluentIcons.car),
69+
icon: Icon(FluentIcons.add_to),
7070
title: const Text("Geti"),
7171
body: GetiImport(),
7272
),
7373
PaneItem(
74-
icon: Icon(FluentIcons.car),
74+
icon: Icon(FluentIcons.fabric_folder_upload),
7575
title: const Text("Directory"),
7676
body: DirectoryImport(),
7777
),

lib/project.dart

+26-13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import 'dart:io';
77

88
import 'package:fluent_ui/fluent_ui.dart';
99
import 'package:inference/importers/model_manifest.dart';
10+
import 'package:inference/utils.dart';
1011
import 'package:inference/utils/get_public_thumbnail.dart';
1112
import 'package:path/path.dart';
1213
import 'package:collection/collection.dart';
@@ -147,6 +148,21 @@ String projectTypeToString(ProjectType type) {
147148
}
148149
}
149150

151+
String projectTypeToName(ProjectType type) {
152+
switch(type){
153+
case ProjectType.text:
154+
return "Text Generation";
155+
case ProjectType.textToImage:
156+
return "Image generation";
157+
case ProjectType.vlm:
158+
return "VLM";
159+
case ProjectType.image:
160+
return "Computer Vision";
161+
case ProjectType.speech:
162+
return "Transcription";
163+
}
164+
}
165+
150166
class Project {
151167
String id;
152168
String modelId;
@@ -217,15 +233,7 @@ class GetiProject extends Project {
217233

218234
GetiProject(String id, String modelId, String applicationVersion, String name, String creationTime, ProjectType type, String storagePath)
219235
: super(id, modelId, applicationVersion, name, creationTime, type, storagePath, false) {
220-
size = calculateDiskUsage();
221-
}
222-
223-
int calculateDiskUsage() {
224-
final dir = Directory(storagePath);
225-
if (dir.existsSync()) {
226-
return dir.listSync(recursive: true).fold(0, (acc, m) => acc + m.statSync().size);
227-
}
228-
return 0;
236+
size = calculateDiskUsage(storagePath);
229237
}
230238

231239
List<Score?> scores() {
@@ -347,7 +355,7 @@ class PublicProject extends Project {
347355
}
348356

349357
Image get thumbnail {
350-
return getThumbnail(modelId);
358+
return getThumbnail(name);
351359
}
352360

353361
@override
@@ -368,11 +376,16 @@ class PublicProject extends Project {
368376
@override
369377
bool get isDownloaded => loaded.isCompleted;
370378

371-
static Future<PublicProject> fromModelManifest(ModelManifest manifest) async {
379+
static Future<PublicProject> fromModelManifest(ModelManifest manifest, [String? path]) async {
372380
final directory = await getApplicationSupportDirectory();
373381
final projectId = manifest.id;
374-
final storagePath = platformContext.join(directory.path, projectId.toString());
375-
await Directory(storagePath).create(recursive: true);
382+
String storagePath;
383+
if (path == null) {
384+
storagePath = platformContext.join(directory.path, projectId.toString());
385+
await Directory(storagePath).create(recursive: true);
386+
} else {
387+
storagePath = path;
388+
}
376389
final projectType = parseProjectType(manifest.task);
377390

378391
return PublicProject(

lib/public_models.dart

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void createDirectory(PublicProject project) {
1919
Directory(project.storagePath).createSync();
2020
}
2121

22-
void writeProjectJson(PublicProject project) {
22+
void writeProjectJson(Project project) {
2323
final projectFile = platformContext.join(project.storagePath, "project.json");
2424
const encoder = JsonEncoder.withIndent(" ");
2525
File(projectFile).writeAsStringSync(encoder.convert((project.toMap())));

lib/utils.dart

+8
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,11 @@ class Envvars {
126126
return '';
127127
}
128128
}
129+
130+
int calculateDiskUsage(path) {
131+
final dir = Directory(path);
132+
if (dir.existsSync()) {
133+
return dir.listSync(recursive: true).fold(0, (acc, m) => acc + m.statSync().size);
134+
}
135+
return 0;
136+
}

0 commit comments

Comments
 (0)